Use Writer in shared package (#1296)

This commit is contained in:
Neil Alexander 2020-08-25 10:29:45 +01:00 committed by GitHub
parent 3b14119aff
commit 720ddce0a8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 98 additions and 109 deletions

View file

@ -27,6 +27,7 @@ import (
type Database struct {
DB *sql.DB
Writer sqlutil.Writer
OneTimeKeysTable tables.OneTimeKeys
DeviceKeysTable tables.DeviceKeys
KeyChangesTable tables.KeyChanges
@ -37,8 +38,12 @@ func (d *Database) ExistingOneTimeKeys(ctx context.Context, userID, deviceID str
return d.OneTimeKeysTable.SelectOneTimeKeys(ctx, userID, deviceID, keyIDsWithAlgorithms)
}
func (d *Database) StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) {
return d.OneTimeKeysTable.InsertOneTimeKeys(ctx, keys)
func (d *Database) StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (counts *api.OneTimeKeysCount, err error) {
_ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
counts, err = d.OneTimeKeysTable.InsertOneTimeKeys(ctx, txn, keys)
return nil
})
return
}
func (d *Database) OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) {
@ -62,7 +67,7 @@ func (d *Database) PrevIDsExists(ctx context.Context, userID string, prevIDs []i
}
func (d *Database) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error {
return sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
for _, userID := range clearUserIDs {
err := d.DeviceKeysTable.DeleteAllDeviceKeys(ctx, txn, userID)
if err != nil {
@ -79,7 +84,7 @@ func (d *Database) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMe
for _, k := range keys {
userIDToStreamID[k.UserID] = 0
}
return sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
for userID := range userIDToStreamID {
streamID, err := d.DeviceKeysTable.SelectMaxStreamIDForUser(ctx, txn, userID)
if err != nil {
@ -104,7 +109,7 @@ func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceI
func (d *Database) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) {
var result []api.OneTimeKeys
err := sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error {
err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
for userID, deviceToAlgo := range userToDeviceToAlgorithm {
for deviceID, algo := range deviceToAlgo {
keyJSON, err := d.OneTimeKeysTable.SelectAndDeleteOneTimeKey(ctx, txn, userID, deviceID, algo)
@ -126,7 +131,9 @@ func (d *Database) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[st
}
func (d *Database) StoreKeyChange(ctx context.Context, partition int32, offset int64, userID string) error {
return d.KeyChangesTable.InsertKeyChange(ctx, partition, offset, userID)
return d.Writer.Do(nil, nil, func(_ *sql.Tx) error {
return d.KeyChangesTable.InsertKeyChange(ctx, partition, offset, userID)
})
}
func (d *Database) KeyChanges(ctx context.Context, partition int32, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) {
@ -141,5 +148,7 @@ func (d *Database) StaleDeviceLists(ctx context.Context, domains []gomatrixserve
// MarkDeviceListStale sets the stale bit for this user to isStale.
func (d *Database) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error {
return d.StaleDeviceListsTable.InsertStaleDeviceList(ctx, userID, isStale)
return d.Writer.Do(nil, nil, func(_ *sql.Tx) error {
return d.StaleDeviceListsTable.InsertStaleDeviceList(ctx, userID, isStale)
})
}