Component-wide TransactionWriters (#1290)

* Offset updates take place using TransactionWriter

* Refactor TransactionWriter in current state server

* Refactor TransactionWriter in federation sender

* Refactor TransactionWriter in key server

* Refactor TransactionWriter in media API

* Refactor TransactionWriter in server key API

* Refactor TransactionWriter in sync API

* Refactor TransactionWriter in user API

* Fix deadlocking Sync API tests

* Un-deadlock device database

* Fix appservice API

* Rename TransactionWriters to Writers

* Move writers up a layer in sync API

* Document sqlutil.Writer interface

* Add note to Writer documentation
This commit is contained in:
Neil Alexander 2020-08-21 10:42:08 +01:00 committed by GitHub
parent 5aaf32bbed
commit 9d53351dc2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
56 changed files with 483 additions and 483 deletions

View file

@ -34,7 +34,8 @@ import (
// Database represents an account database
type Database struct {
db *sql.DB
db *sql.DB
writer sqlutil.Writer
sqlutil.PartitionOffsetStatements
accounts accountsStatements
profiles profilesStatements
@ -49,27 +50,27 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err != nil {
return nil, err
}
partitions := sqlutil.PartitionOffsetStatements{}
if err = partitions.Prepare(db, "account"); err != nil {
d := &Database{
serverName: serverName,
db: db,
writer: sqlutil.NewDummyWriter(),
}
if err = d.PartitionOffsetStatements.Prepare(db, d.writer, "account"); err != nil {
return nil, err
}
a := accountsStatements{}
if err = a.prepare(db, serverName); err != nil {
if err = d.accounts.prepare(db, serverName); err != nil {
return nil, err
}
p := profilesStatements{}
if err = p.prepare(db); err != nil {
if err = d.profiles.prepare(db); err != nil {
return nil, err
}
ac := accountDataStatements{}
if err = ac.prepare(db); err != nil {
if err = d.accountDatas.prepare(db); err != nil {
return nil, err
}
t := threepidStatements{}
if err = t.prepare(db); err != nil {
if err = d.threepids.prepare(db); err != nil {
return nil, err
}
return &Database{db, partitions, a, p, ac, t, serverName}, nil
return d, nil
}
// GetAccountByPassword returns the account associated with the given localpart and password.

View file

@ -51,15 +51,15 @@ const selectAccountDataByTypeSQL = "" +
type accountDataStatements struct {
db *sql.DB
writer sqlutil.TransactionWriter
writer sqlutil.Writer
insertAccountDataStmt *sql.Stmt
selectAccountDataStmt *sql.Stmt
selectAccountDataByTypeStmt *sql.Stmt
}
func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
func (s *accountDataStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
s.db = db
s.writer = sqlutil.NewTransactionWriter()
s.writer = writer
_, err = db.Exec(accountDataSchema)
if err != nil {
return

View file

@ -59,7 +59,7 @@ const selectNewNumericLocalpartSQL = "" +
type accountsStatements struct {
db *sql.DB
writer sqlutil.TransactionWriter
writer sqlutil.Writer
insertAccountStmt *sql.Stmt
selectAccountByLocalpartStmt *sql.Stmt
selectPasswordHashStmt *sql.Stmt
@ -67,9 +67,9 @@ type accountsStatements struct {
serverName gomatrixserverlib.ServerName
}
func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
func (s *accountsStatements) prepare(db *sql.DB, writer sqlutil.Writer, server gomatrixserverlib.ServerName) (err error) {
s.db = db
s.writer = sqlutil.NewTransactionWriter()
s.writer = writer
_, err = db.Exec(accountsSchema)
if err != nil {
return

View file

@ -53,7 +53,7 @@ const selectProfilesBySearchSQL = "" +
type profilesStatements struct {
db *sql.DB
writer sqlutil.TransactionWriter
writer sqlutil.Writer
insertProfileStmt *sql.Stmt
selectProfileByLocalpartStmt *sql.Stmt
setAvatarURLStmt *sql.Stmt
@ -61,9 +61,9 @@ type profilesStatements struct {
selectProfilesBySearchStmt *sql.Stmt
}
func (s *profilesStatements) prepare(db *sql.DB) (err error) {
func (s *profilesStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
s.db = db
s.writer = sqlutil.NewTransactionWriter()
s.writer = writer
_, err = db.Exec(profilesSchema)
if err != nil {
return

View file

@ -33,7 +33,9 @@ import (
// Database represents an account database
type Database struct {
db *sql.DB
db *sql.DB
writer sqlutil.Writer
sqlutil.PartitionOffsetStatements
accounts accountsStatements
profiles profilesStatements
@ -53,35 +55,28 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err != nil {
return nil, err
}
d := &Database{
serverName: serverName,
db: db,
writer: sqlutil.NewExclusiveWriter(),
}
partitions := sqlutil.PartitionOffsetStatements{}
if err = partitions.Prepare(db, "account"); err != nil {
if err = partitions.Prepare(db, d.writer, "account"); err != nil {
return nil, err
}
a := accountsStatements{}
if err = a.prepare(db, serverName); err != nil {
if err = d.accounts.prepare(db, d.writer, serverName); err != nil {
return nil, err
}
p := profilesStatements{}
if err = p.prepare(db); err != nil {
if err = d.profiles.prepare(db, d.writer); err != nil {
return nil, err
}
ac := accountDataStatements{}
if err = ac.prepare(db); err != nil {
if err = d.accountDatas.prepare(db, d.writer); err != nil {
return nil, err
}
t := threepidStatements{}
if err = t.prepare(db); err != nil {
if err = d.threepids.prepare(db, d.writer); err != nil {
return nil, err
}
return &Database{
db: db,
PartitionOffsetStatements: partitions,
accounts: a,
profiles: p,
accountDatas: ac,
threepids: t,
serverName: serverName,
}, nil
return d, nil
}
// GetAccountByPassword returns the account associated with the given localpart and password.

View file

@ -54,16 +54,16 @@ const deleteThreePIDSQL = "" +
type threepidStatements struct {
db *sql.DB
writer sqlutil.TransactionWriter
writer sqlutil.Writer
selectLocalpartForThreePIDStmt *sql.Stmt
selectThreePIDsForLocalpartStmt *sql.Stmt
insertThreePIDStmt *sql.Stmt
deleteThreePIDStmt *sql.Stmt
}
func (s *threepidStatements) prepare(db *sql.DB) (err error) {
func (s *threepidStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
s.db = db
s.writer = sqlutil.NewTransactionWriter()
s.writer = writer
_, err = db.Exec(threepidSchema)
if err != nil {
return

View file

@ -78,7 +78,7 @@ const selectDevicesByIDSQL = "" +
type devicesStatements struct {
db *sql.DB
writer sqlutil.TransactionWriter
writer sqlutil.Writer
insertDeviceStmt *sql.Stmt
selectDevicesCountStmt *sql.Stmt
selectDeviceByTokenStmt *sql.Stmt
@ -91,9 +91,9 @@ type devicesStatements struct {
serverName gomatrixserverlib.ServerName
}
func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
func (s *devicesStatements) prepare(db *sql.DB, writer sqlutil.Writer, server gomatrixserverlib.ServerName) (err error) {
s.db = db
s.writer = sqlutil.NewTransactionWriter()
s.writer = writer
_, err = db.Exec(devicesSchema)
if err != nil {
return
@ -138,19 +138,13 @@ func (s *devicesStatements) insertDevice(
) (*api.Device, error) {
createdTimeMS := time.Now().UnixNano() / 1000000
var sessionID int64
err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
countStmt := sqlutil.TxStmt(txn, s.selectDevicesCountStmt)
insertStmt := sqlutil.TxStmt(txn, s.insertDeviceStmt)
if err := countStmt.QueryRowContext(ctx).Scan(&sessionID); err != nil {
return err
}
sessionID++
if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID); err != nil {
return err
}
return nil
})
if err != nil {
countStmt := sqlutil.TxStmt(txn, s.selectDevicesCountStmt)
insertStmt := sqlutil.TxStmt(txn, s.insertDeviceStmt)
if err := countStmt.QueryRowContext(ctx).Scan(&sessionID); err != nil {
return nil, err
}
sessionID++
if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID); err != nil {
return nil, err
}
return &api.Device{
@ -164,11 +158,9 @@ func (s *devicesStatements) insertDevice(
func (s *devicesStatements) deleteDevice(
ctx context.Context, txn *sql.Tx, id, localpart string,
) error {
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
_, err := stmt.ExecContext(ctx, id, localpart)
return err
})
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
_, err := stmt.ExecContext(ctx, id, localpart)
return err
}
func (s *devicesStatements) deleteDevices(
@ -179,36 +171,30 @@ func (s *devicesStatements) deleteDevices(
if err != nil {
return err
}
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, prep)
params := make([]interface{}, len(devices)+1)
params[0] = localpart
for i, v := range devices {
params[i+1] = v
}
_, err = stmt.ExecContext(ctx, params...)
return err
})
stmt := sqlutil.TxStmt(txn, prep)
params := make([]interface{}, len(devices)+1)
params[0] = localpart
for i, v := range devices {
params[i+1] = v
}
_, err = stmt.ExecContext(ctx, params...)
return err
}
func (s *devicesStatements) deleteDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart string,
) error {
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
_, err := stmt.ExecContext(ctx, localpart)
return err
})
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
_, err := stmt.ExecContext(ctx, localpart)
return err
}
func (s *devicesStatements) updateDeviceName(
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
) error {
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
_, err := stmt.ExecContext(ctx, displayName, localpart, deviceID)
return err
})
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
_, err := stmt.ExecContext(ctx, displayName, localpart, deviceID)
return err
}
func (s *devicesStatements) selectDeviceByToken(

View file

@ -34,6 +34,7 @@ var deviceIDByteLength = 6
// Database represents a device database.
type Database struct {
db *sql.DB
writer sqlutil.Writer
devices devicesStatements
}
@ -43,11 +44,12 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err != nil {
return nil, err
}
writer := sqlutil.NewExclusiveWriter()
d := devicesStatements{}
if err = d.prepare(db, serverName); err != nil {
if err = d.prepare(db, writer, serverName); err != nil {
return nil, err
}
return &Database{db, d}, nil
return &Database{db, writer, d}, nil
}
// GetDeviceByAccessToken returns the device matching the given access token.
@ -88,7 +90,7 @@ func (d *Database) CreateDevice(
displayName *string,
) (dev *api.Device, returnErr error) {
if deviceID != nil {
returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
var err error
// Revoke existing tokens for this device
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
@ -108,7 +110,7 @@ func (d *Database) CreateDevice(
return
}
returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
var err error
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName)
return err
@ -138,7 +140,7 @@ func generateDeviceID() (string, error) {
func (d *Database) UpdateDevice(
ctx context.Context, localpart, deviceID string, displayName *string,
) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
})
}
@ -150,7 +152,7 @@ func (d *Database) UpdateDevice(
func (d *Database) RemoveDevice(
ctx context.Context, deviceID, localpart string,
) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
return err
}
@ -165,7 +167,7 @@ func (d *Database) RemoveDevice(
func (d *Database) RemoveDevices(
ctx context.Context, localpart string, devices []string,
) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
return err
}
@ -179,7 +181,7 @@ func (d *Database) RemoveDevices(
func (d *Database) RemoveAllDevices(
ctx context.Context, localpart string,
) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows {
return err
}