Various refactoring

This commit is contained in:
Neil Alexander 2021-07-30 16:27:55 +01:00
parent ed4097825b
commit 62bcd5ad4b
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
13 changed files with 178 additions and 182 deletions

View file

@ -24,6 +24,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
)
var deviceKeysSchema = `
@ -105,19 +106,23 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
}
func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
for i, key := range keys {
var keyJSONStr string
for _, key := range keys {
var streamID int
var displayName sql.NullString
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName)
if key.DeviceKeys == nil {
key.DeviceKeys = &gomatrixserverlib.DeviceKeys{}
}
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&key.DeviceKeys, &streamID, &displayName)
if err != nil && err != sql.ErrNoRows {
return err
}
// this will be '' when there is no device
keys[i].KeyJSON = []byte(keyJSONStr)
keys[i].StreamID = streamID
key.StreamID = streamID
if displayName.Valid {
keys[i].DisplayName = displayName.String
if key.DeviceKeys.Unsigned == nil {
key.DeviceKeys.Unsigned = make(map[string]interface{})
}
key.DeviceKeys.Unsigned["device_display_name"] = displayName.String
}
}
return nil
@ -153,7 +158,7 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx
for _, key := range keys {
now := time.Now().Unix()
_, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext(
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName,
ctx, key.UserID, key.DeviceID, now, key.DeviceKeys, key.StreamID, key.DisplayName,
)
if err != nil {
return err
@ -179,18 +184,21 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
}
var result []api.DeviceMessage
for rows.Next() {
var dk api.DeviceMessage
dk := api.DeviceMessage{
DeviceKeys: &gomatrixserverlib.DeviceKeys{},
}
dk.UserID = userID
var keyJSON string
var streamID int
var displayName sql.NullString
if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil {
if err := rows.Scan(&dk.DeviceID, &dk.DeviceKeys, &streamID, &displayName); err != nil {
return nil, err
}
dk.KeyJSON = []byte(keyJSON)
dk.StreamID = streamID
if displayName.Valid {
dk.DisplayName = displayName.String
if dk.DeviceKeys.Unsigned == nil {
dk.DeviceKeys.Unsigned = make(map[string]interface{})
}
dk.DeviceKeys.Unsigned["device_display_name"] = displayName.String
}
// include the key if we want all keys (no device) or it was asked
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {

View file

@ -24,6 +24,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
)
var deviceKeysSchema = `
@ -113,18 +114,21 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed")
var result []api.DeviceMessage
for rows.Next() {
var dk api.DeviceMessage
dk := api.DeviceMessage{
DeviceKeys: &gomatrixserverlib.DeviceKeys{},
}
dk.UserID = userID
var keyJSON string
var streamID int
var displayName sql.NullString
if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil {
if err := rows.Scan(&dk.DeviceID, &dk.DeviceKeys, &streamID, &displayName); err != nil {
return nil, err
}
dk.KeyJSON = []byte(keyJSON)
dk.StreamID = streamID
if displayName.Valid {
dk.DisplayName = displayName.String
if dk.DeviceKeys.Unsigned == nil {
dk.DeviceKeys.Unsigned = make(map[string]interface{})
}
dk.DeviceKeys.Unsigned["device_display_name"] = displayName.String
}
// include the key if we want all keys (no device) or it was asked
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
@ -135,19 +139,23 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
}
func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
for i, key := range keys {
var keyJSONStr string
for _, key := range keys {
var streamID int
var displayName sql.NullString
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName)
if key.DeviceKeys == nil {
key.DeviceKeys = &gomatrixserverlib.DeviceKeys{}
}
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&key.DeviceKeys, &streamID, &displayName)
if err != nil && err != sql.ErrNoRows {
return err
}
// this will be '' when there is no device
keys[i].KeyJSON = []byte(keyJSONStr)
keys[i].StreamID = streamID
key.StreamID = streamID
if displayName.Valid {
keys[i].DisplayName = displayName.String
if key.DeviceKeys.Unsigned == nil {
key.DeviceKeys.Unsigned = make(map[string]interface{})
}
key.DeviceKeys.Unsigned["device_display_name"] = displayName.String
}
}
return nil
@ -189,7 +197,7 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx
for _, key := range keys {
now := time.Now().Unix()
_, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext(
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName,
ctx, key.UserID, key.DeviceID, now, key.DeviceKeys, key.StreamID, key.DisplayName(),
)
if err != nil {
return err

View file

@ -12,6 +12,7 @@ import (
"github.com/Shopify/sarama"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
)
var ctx = context.Background()
@ -105,28 +106,34 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) {
bob := "@bob:TestDeviceKeysStreamIDGeneration"
msgs := []api.DeviceMessage{
{
DeviceKeys: api.DeviceKeys{
DeviceID: "AAA",
UserID: alice,
KeyJSON: []byte(`{"key":"v1"}`),
DeviceKeys: &gomatrixserverlib.DeviceKeys{
RespUserDeviceKeys: gomatrixserverlib.RespUserDeviceKeys{
DeviceID: "AAA",
UserID: alice,
Algorithms: []string{"v1"},
},
},
// StreamID: 1
},
{
DeviceKeys: api.DeviceKeys{
DeviceID: "AAA",
UserID: bob,
KeyJSON: []byte(`{"key":"v1"}`),
DeviceKeys: &gomatrixserverlib.DeviceKeys{
RespUserDeviceKeys: gomatrixserverlib.RespUserDeviceKeys{
DeviceID: "AAA",
UserID: bob,
Algorithms: []string{"v1"},
},
},
// StreamID: 1 as this is a different user
},
{
DeviceKeys: api.DeviceKeys{
DeviceID: "another_device",
UserID: alice,
KeyJSON: []byte(`{"key":"v1"}`),
DeviceKeys: &gomatrixserverlib.DeviceKeys{
RespUserDeviceKeys: gomatrixserverlib.RespUserDeviceKeys{
DeviceID: "another_device",
UserID: alice,
Algorithms: []string{"v2"},
},
},
// StreamID: 2 as this is a 2nd device key
// StreamID: 2 as this is a 2nd key
},
}
MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs))
@ -143,10 +150,12 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) {
// updating a device sets the next stream ID for that user
msgs = []api.DeviceMessage{
{
DeviceKeys: api.DeviceKeys{
DeviceID: "AAA",
UserID: alice,
KeyJSON: []byte(`{"key":"v2"}`),
DeviceKeys: &gomatrixserverlib.DeviceKeys{
RespUserDeviceKeys: gomatrixserverlib.RespUserDeviceKeys{
DeviceID: "AAA",
UserID: alice,
Algorithms: []string{"v3"},
},
},
// StreamID: 3
},