mirror of
https://github.com/hoernschen/dendrite.git
synced 2025-07-31 21:32:46 +00:00
Merge both user API databases into one (#2186)
* Merge user API databases into one * Remove DeviceDatabase from config * Fix tests * Try that again * Clean up keyserver device keys when the devices no longer exist in the user API * Tweak ordering * Fix UserExists flag, device check * Allow including empty entries so we can clean them up * Remove logging
This commit is contained in:
parent
0a7dea4450
commit
153bfbbea5
76 changed files with 727 additions and 899 deletions
|
@ -53,7 +53,7 @@ type Database interface {
|
|||
|
||||
// DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected.
|
||||
// If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice.
|
||||
DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error)
|
||||
DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error)
|
||||
|
||||
// DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying
|
||||
// cross-signing signatures relating to that device.
|
||||
|
|
|
@ -56,6 +56,9 @@ const selectDeviceKeysSQL = "" +
|
|||
const selectBatchDeviceKeysSQL = "" +
|
||||
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''"
|
||||
|
||||
const selectBatchDeviceKeysWithEmptiesSQL = "" +
|
||||
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1"
|
||||
|
||||
const selectMaxStreamForUserSQL = "" +
|
||||
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
|
||||
|
||||
|
@ -69,14 +72,15 @@ const deleteAllDeviceKeysSQL = "" +
|
|||
"DELETE FROM keyserver_device_keys WHERE user_id=$1"
|
||||
|
||||
type deviceKeysStatements struct {
|
||||
db *sql.DB
|
||||
upsertDeviceKeysStmt *sql.Stmt
|
||||
selectDeviceKeysStmt *sql.Stmt
|
||||
selectBatchDeviceKeysStmt *sql.Stmt
|
||||
selectMaxStreamForUserStmt *sql.Stmt
|
||||
countStreamIDsForUserStmt *sql.Stmt
|
||||
deleteDeviceKeysStmt *sql.Stmt
|
||||
deleteAllDeviceKeysStmt *sql.Stmt
|
||||
db *sql.DB
|
||||
upsertDeviceKeysStmt *sql.Stmt
|
||||
selectDeviceKeysStmt *sql.Stmt
|
||||
selectBatchDeviceKeysStmt *sql.Stmt
|
||||
selectBatchDeviceKeysWithEmptiesStmt *sql.Stmt
|
||||
selectMaxStreamForUserStmt *sql.Stmt
|
||||
countStreamIDsForUserStmt *sql.Stmt
|
||||
deleteDeviceKeysStmt *sql.Stmt
|
||||
deleteAllDeviceKeysStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
||||
|
@ -96,6 +100,9 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
|||
if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectBatchDeviceKeysWithEmptiesStmt, err = db.Prepare(selectBatchDeviceKeysWithEmptiesSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -180,8 +187,14 @@ func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql
|
|||
return err
|
||||
}
|
||||
|
||||
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
|
||||
rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID)
|
||||
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
|
||||
var stmt *sql.Stmt
|
||||
if includeEmpty {
|
||||
stmt = s.selectBatchDeviceKeysWithEmptiesStmt
|
||||
} else {
|
||||
stmt = s.selectBatchDeviceKeysStmt
|
||||
}
|
||||
rows, err := stmt.QueryContext(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -108,8 +108,8 @@ func (d *Database) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMe
|
|||
})
|
||||
}
|
||||
|
||||
func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
|
||||
return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs)
|
||||
func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
|
||||
return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs, includeEmpty)
|
||||
}
|
||||
|
||||
func (d *Database) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) {
|
||||
|
|
|
@ -52,6 +52,9 @@ const selectDeviceKeysSQL = "" +
|
|||
const selectBatchDeviceKeysSQL = "" +
|
||||
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''"
|
||||
|
||||
const selectBatchDeviceKeysWithEmptiesSQL = "" +
|
||||
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1"
|
||||
|
||||
const selectMaxStreamForUserSQL = "" +
|
||||
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
|
||||
|
||||
|
@ -65,13 +68,14 @@ const deleteAllDeviceKeysSQL = "" +
|
|||
"DELETE FROM keyserver_device_keys WHERE user_id=$1"
|
||||
|
||||
type deviceKeysStatements struct {
|
||||
db *sql.DB
|
||||
upsertDeviceKeysStmt *sql.Stmt
|
||||
selectDeviceKeysStmt *sql.Stmt
|
||||
selectBatchDeviceKeysStmt *sql.Stmt
|
||||
selectMaxStreamForUserStmt *sql.Stmt
|
||||
deleteDeviceKeysStmt *sql.Stmt
|
||||
deleteAllDeviceKeysStmt *sql.Stmt
|
||||
db *sql.DB
|
||||
upsertDeviceKeysStmt *sql.Stmt
|
||||
selectDeviceKeysStmt *sql.Stmt
|
||||
selectBatchDeviceKeysStmt *sql.Stmt
|
||||
selectBatchDeviceKeysWithEmptiesStmt *sql.Stmt
|
||||
selectMaxStreamForUserStmt *sql.Stmt
|
||||
deleteDeviceKeysStmt *sql.Stmt
|
||||
deleteAllDeviceKeysStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
||||
|
@ -91,6 +95,9 @@ func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
|||
if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectBatchDeviceKeysWithEmptiesStmt, err = db.Prepare(selectBatchDeviceKeysWithEmptiesSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -113,12 +120,18 @@ func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql
|
|||
return err
|
||||
}
|
||||
|
||||
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
|
||||
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
|
||||
deviceIDMap := make(map[string]bool)
|
||||
for _, d := range deviceIDs {
|
||||
deviceIDMap[d] = true
|
||||
}
|
||||
rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID)
|
||||
var stmt *sql.Stmt
|
||||
if includeEmpty {
|
||||
stmt = s.selectBatchDeviceKeysWithEmptiesStmt
|
||||
} else {
|
||||
stmt = s.selectBatchDeviceKeysStmt
|
||||
}
|
||||
rows, err := stmt.QueryContext(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -173,7 +173,7 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) {
|
|||
}
|
||||
|
||||
// Querying for device keys returns the latest stream IDs
|
||||
msgs, err = db.DeviceKeysForUser(ctx, alice, []string{"AAA", "another_device"})
|
||||
msgs, err = db.DeviceKeysForUser(ctx, alice, []string{"AAA", "another_device"}, false)
|
||||
if err != nil {
|
||||
t.Fatalf("DeviceKeysForUser returned error: %s", err)
|
||||
}
|
||||
|
|
|
@ -38,7 +38,7 @@ type DeviceKeys interface {
|
|||
InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error
|
||||
SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error)
|
||||
CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error)
|
||||
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error)
|
||||
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error)
|
||||
DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error
|
||||
DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue