Various refactoring

This commit is contained in:
Neil Alexander 2021-07-30 16:27:55 +01:00
parent ed4097825b
commit 62bcd5ad4b
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
13 changed files with 178 additions and 182 deletions

View file

@ -24,6 +24,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
)
var deviceKeysSchema = `
@ -113,18 +114,21 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed")
var result []api.DeviceMessage
for rows.Next() {
var dk api.DeviceMessage
dk := api.DeviceMessage{
DeviceKeys: &gomatrixserverlib.DeviceKeys{},
}
dk.UserID = userID
var keyJSON string
var streamID int
var displayName sql.NullString
if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil {
if err := rows.Scan(&dk.DeviceID, &dk.DeviceKeys, &streamID, &displayName); err != nil {
return nil, err
}
dk.KeyJSON = []byte(keyJSON)
dk.StreamID = streamID
if displayName.Valid {
dk.DisplayName = displayName.String
if dk.DeviceKeys.Unsigned == nil {
dk.DeviceKeys.Unsigned = make(map[string]interface{})
}
dk.DeviceKeys.Unsigned["device_display_name"] = displayName.String
}
// include the key if we want all keys (no device) or it was asked
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
@ -135,19 +139,23 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
}
func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
for i, key := range keys {
var keyJSONStr string
for _, key := range keys {
var streamID int
var displayName sql.NullString
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName)
if key.DeviceKeys == nil {
key.DeviceKeys = &gomatrixserverlib.DeviceKeys{}
}
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&key.DeviceKeys, &streamID, &displayName)
if err != nil && err != sql.ErrNoRows {
return err
}
// this will be '' when there is no device
keys[i].KeyJSON = []byte(keyJSONStr)
keys[i].StreamID = streamID
key.StreamID = streamID
if displayName.Valid {
keys[i].DisplayName = displayName.String
if key.DeviceKeys.Unsigned == nil {
key.DeviceKeys.Unsigned = make(map[string]interface{})
}
key.DeviceKeys.Unsigned["device_display_name"] = displayName.String
}
}
return nil
@ -189,7 +197,7 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx
for _, key := range keys {
now := time.Now().Unix()
_, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext(
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName,
ctx, key.UserID, key.DeviceID, now, key.DeviceKeys, key.StreamID, key.DisplayName(),
)
if err != nil {
return err