Use TransactionWriter in other component SQLite (#1209)

* Use TransactionWriter on other component SQLites

* Fix sync API tests

* Fix panic in media API

* Fix a couple of transactions

* Fix wrong query, add some logging output

* Add debug logging into StoreEvent

* Adjust InsertRoomNID

* Update logging
This commit is contained in:
Neil Alexander 2020-07-21 15:48:21 +01:00 committed by GitHub
parent 1d72ce8b7a
commit b6bc132485
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
27 changed files with 439 additions and 245 deletions

View file

@ -18,6 +18,8 @@ import (
"context"
"database/sql"
"encoding/json"
"github.com/matrix-org/dendrite/internal/sqlutil"
)
const accountDataSchema = `
@ -48,12 +50,16 @@ const selectAccountDataByTypeSQL = "" +
"SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3"
type accountDataStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
insertAccountDataStmt *sql.Stmt
selectAccountDataStmt *sql.Stmt
selectAccountDataByTypeStmt *sql.Stmt
}
func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
s.db = db
s.writer = sqlutil.NewTransactionWriter()
_, err = db.Exec(accountDataSchema)
if err != nil {
return
@ -73,8 +79,10 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
func (s *accountDataStatements) insertAccountData(
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
) (err error) {
_, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content)
return
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
_, err := txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content)
return err
})
}
func (s *accountDataStatements) selectAccountData(

View file

@ -20,6 +20,7 @@ import (
"time"
"github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
@ -57,6 +58,8 @@ const selectNewNumericLocalpartSQL = "" +
// TODO: Update password
type accountsStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
insertAccountStmt *sql.Stmt
selectAccountByLocalpartStmt *sql.Stmt
selectPasswordHashStmt *sql.Stmt
@ -65,6 +68,8 @@ type accountsStatements struct {
}
func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
s.db = db
s.writer = sqlutil.NewTransactionWriter()
_, err = db.Exec(accountsSchema)
if err != nil {
return
@ -94,12 +99,15 @@ func (s *accountsStatements) insertAccount(
createdTimeMS := time.Now().UnixNano() / 1000000
stmt := s.insertAccountStmt
var err error
if appserviceID == "" {
_, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil)
} else {
_, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID)
}
err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
var err error
if appserviceID == "" {
_, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil)
} else {
_, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID)
}
return err
})
if err != nil {
return nil, err
}

View file

@ -19,6 +19,7 @@ import (
"database/sql"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal/sqlutil"
)
const profilesSchema = `
@ -46,6 +47,8 @@ const setDisplayNameSQL = "" +
"UPDATE account_profiles SET display_name = $1 WHERE localpart = $2"
type profilesStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
insertProfileStmt *sql.Stmt
selectProfileByLocalpartStmt *sql.Stmt
setAvatarURLStmt *sql.Stmt
@ -53,6 +56,8 @@ type profilesStatements struct {
}
func (s *profilesStatements) prepare(db *sql.DB) (err error) {
s.db = db
s.writer = sqlutil.NewTransactionWriter()
_, err = db.Exec(profilesSchema)
if err != nil {
return
@ -75,8 +80,10 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) {
func (s *profilesStatements) insertProfile(
ctx context.Context, txn *sql.Tx, localpart string,
) (err error) {
_, err = txn.Stmt(s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
return
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
_, err := txn.Stmt(s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
return err
})
}
func (s *profilesStatements) selectProfileByLocalpart(

View file

@ -53,6 +53,8 @@ const deleteThreePIDSQL = "" +
"DELETE FROM account_threepid WHERE threepid = $1 AND medium = $2"
type threepidStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
selectLocalpartForThreePIDStmt *sql.Stmt
selectThreePIDsForLocalpartStmt *sql.Stmt
insertThreePIDStmt *sql.Stmt
@ -60,6 +62,8 @@ type threepidStatements struct {
}
func (s *threepidStatements) prepare(db *sql.DB) (err error) {
s.db = db
s.writer = sqlutil.NewTransactionWriter()
_, err = db.Exec(threepidSchema)
if err != nil {
return
@ -118,13 +122,18 @@ func (s *threepidStatements) selectThreePIDsForLocalpart(
func (s *threepidStatements) insertThreePID(
ctx context.Context, txn *sql.Tx, threepid, medium, localpart string,
) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt)
_, err = stmt.ExecContext(ctx, threepid, medium, localpart)
return
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt)
_, err := stmt.ExecContext(ctx, threepid, medium, localpart)
return err
})
}
func (s *threepidStatements) deleteThreePID(
ctx context.Context, threepid string, medium string) (err error) {
_, err = s.deleteThreePIDStmt.ExecContext(ctx, threepid, medium)
return
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.deleteThreePIDStmt)
_, err := stmt.ExecContext(ctx, threepid, medium)
return err
})
}

View file

@ -74,6 +74,7 @@ const deleteDevicesSQL = "" +
type devicesStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
insertDeviceStmt *sql.Stmt
selectDevicesCountStmt *sql.Stmt
selectDeviceByTokenStmt *sql.Stmt
@ -87,6 +88,7 @@ type devicesStatements struct {
func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
s.db = db
s.writer = sqlutil.NewTransactionWriter()
_, err = db.Exec(devicesSchema)
if err != nil {
return
@ -128,13 +130,19 @@ func (s *devicesStatements) insertDevice(
) (*api.Device, error) {
createdTimeMS := time.Now().UnixNano() / 1000000
var sessionID int64
countStmt := sqlutil.TxStmt(txn, s.selectDevicesCountStmt)
insertStmt := sqlutil.TxStmt(txn, s.insertDeviceStmt)
if err := countStmt.QueryRowContext(ctx).Scan(&sessionID); err != nil {
return nil, err
}
sessionID++
if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID); err != nil {
err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
countStmt := sqlutil.TxStmt(txn, s.selectDevicesCountStmt)
insertStmt := sqlutil.TxStmt(txn, s.insertDeviceStmt)
if err := countStmt.QueryRowContext(ctx).Scan(&sessionID); err != nil {
return err
}
sessionID++
if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID); err != nil {
return err
}
return nil
})
if err != nil {
return nil, err
}
return &api.Device{
@ -148,9 +156,11 @@ func (s *devicesStatements) insertDevice(
func (s *devicesStatements) deleteDevice(
ctx context.Context, txn *sql.Tx, id, localpart string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
_, err := stmt.ExecContext(ctx, id, localpart)
return err
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
_, err := stmt.ExecContext(ctx, id, localpart)
return err
})
}
func (s *devicesStatements) deleteDevices(
@ -161,31 +171,37 @@ func (s *devicesStatements) deleteDevices(
if err != nil {
return err
}
stmt := sqlutil.TxStmt(txn, prep)
params := make([]interface{}, len(devices)+1)
params[0] = localpart
for i, v := range devices {
params[i+1] = v
}
params = append(params, params...)
_, err = stmt.ExecContext(ctx, params...)
return err
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, prep)
params := make([]interface{}, len(devices)+1)
params[0] = localpart
for i, v := range devices {
params[i+1] = v
}
params = append(params, params...)
_, err = stmt.ExecContext(ctx, params...)
return err
})
}
func (s *devicesStatements) deleteDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
_, err := stmt.ExecContext(ctx, localpart)
return err
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
_, err := stmt.ExecContext(ctx, localpart)
return err
})
}
func (s *devicesStatements) updateDeviceName(
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
) error {
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
_, err := stmt.ExecContext(ctx, displayName, localpart, deviceID)
return err
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
_, err := stmt.ExecContext(ctx, displayName, localpart, deviceID)
return err
})
}
func (s *devicesStatements) selectDeviceByToken(