mirror of
https://github.com/hoernschen/dendrite.git
synced 2025-08-02 22:22:46 +00:00
Various refactoring
This commit is contained in:
parent
ed4097825b
commit
62bcd5ad4b
13 changed files with 178 additions and 182 deletions
|
@ -24,6 +24,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/keyserver/api"
|
||||
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
var deviceKeysSchema = `
|
||||
|
@ -105,19 +106,23 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
|||
}
|
||||
|
||||
func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
|
||||
for i, key := range keys {
|
||||
var keyJSONStr string
|
||||
for _, key := range keys {
|
||||
var streamID int
|
||||
var displayName sql.NullString
|
||||
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName)
|
||||
if key.DeviceKeys == nil {
|
||||
key.DeviceKeys = &gomatrixserverlib.DeviceKeys{}
|
||||
}
|
||||
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&key.DeviceKeys, &streamID, &displayName)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return err
|
||||
}
|
||||
// this will be '' when there is no device
|
||||
keys[i].KeyJSON = []byte(keyJSONStr)
|
||||
keys[i].StreamID = streamID
|
||||
key.StreamID = streamID
|
||||
if displayName.Valid {
|
||||
keys[i].DisplayName = displayName.String
|
||||
if key.DeviceKeys.Unsigned == nil {
|
||||
key.DeviceKeys.Unsigned = make(map[string]interface{})
|
||||
}
|
||||
key.DeviceKeys.Unsigned["device_display_name"] = displayName.String
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
@ -153,7 +158,7 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx
|
|||
for _, key := range keys {
|
||||
now := time.Now().Unix()
|
||||
_, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext(
|
||||
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName,
|
||||
ctx, key.UserID, key.DeviceID, now, key.DeviceKeys, key.StreamID, key.DisplayName,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -179,18 +184,21 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
|
|||
}
|
||||
var result []api.DeviceMessage
|
||||
for rows.Next() {
|
||||
var dk api.DeviceMessage
|
||||
dk := api.DeviceMessage{
|
||||
DeviceKeys: &gomatrixserverlib.DeviceKeys{},
|
||||
}
|
||||
dk.UserID = userID
|
||||
var keyJSON string
|
||||
var streamID int
|
||||
var displayName sql.NullString
|
||||
if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil {
|
||||
if err := rows.Scan(&dk.DeviceID, &dk.DeviceKeys, &streamID, &displayName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dk.KeyJSON = []byte(keyJSON)
|
||||
dk.StreamID = streamID
|
||||
if displayName.Valid {
|
||||
dk.DisplayName = displayName.String
|
||||
if dk.DeviceKeys.Unsigned == nil {
|
||||
dk.DeviceKeys.Unsigned = make(map[string]interface{})
|
||||
}
|
||||
dk.DeviceKeys.Unsigned["device_display_name"] = displayName.String
|
||||
}
|
||||
// include the key if we want all keys (no device) or it was asked
|
||||
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
|
||||
|
|
|
@ -24,6 +24,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/keyserver/api"
|
||||
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
var deviceKeysSchema = `
|
||||
|
@ -113,18 +114,21 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
|
|||
defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed")
|
||||
var result []api.DeviceMessage
|
||||
for rows.Next() {
|
||||
var dk api.DeviceMessage
|
||||
dk := api.DeviceMessage{
|
||||
DeviceKeys: &gomatrixserverlib.DeviceKeys{},
|
||||
}
|
||||
dk.UserID = userID
|
||||
var keyJSON string
|
||||
var streamID int
|
||||
var displayName sql.NullString
|
||||
if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil {
|
||||
if err := rows.Scan(&dk.DeviceID, &dk.DeviceKeys, &streamID, &displayName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dk.KeyJSON = []byte(keyJSON)
|
||||
dk.StreamID = streamID
|
||||
if displayName.Valid {
|
||||
dk.DisplayName = displayName.String
|
||||
if dk.DeviceKeys.Unsigned == nil {
|
||||
dk.DeviceKeys.Unsigned = make(map[string]interface{})
|
||||
}
|
||||
dk.DeviceKeys.Unsigned["device_display_name"] = displayName.String
|
||||
}
|
||||
// include the key if we want all keys (no device) or it was asked
|
||||
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
|
||||
|
@ -135,19 +139,23 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
|
|||
}
|
||||
|
||||
func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
|
||||
for i, key := range keys {
|
||||
var keyJSONStr string
|
||||
for _, key := range keys {
|
||||
var streamID int
|
||||
var displayName sql.NullString
|
||||
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName)
|
||||
if key.DeviceKeys == nil {
|
||||
key.DeviceKeys = &gomatrixserverlib.DeviceKeys{}
|
||||
}
|
||||
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&key.DeviceKeys, &streamID, &displayName)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return err
|
||||
}
|
||||
// this will be '' when there is no device
|
||||
keys[i].KeyJSON = []byte(keyJSONStr)
|
||||
keys[i].StreamID = streamID
|
||||
key.StreamID = streamID
|
||||
if displayName.Valid {
|
||||
keys[i].DisplayName = displayName.String
|
||||
if key.DeviceKeys.Unsigned == nil {
|
||||
key.DeviceKeys.Unsigned = make(map[string]interface{})
|
||||
}
|
||||
key.DeviceKeys.Unsigned["device_display_name"] = displayName.String
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
@ -189,7 +197,7 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx
|
|||
for _, key := range keys {
|
||||
now := time.Now().Unix()
|
||||
_, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext(
|
||||
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName,
|
||||
ctx, key.UserID, key.DeviceID, now, key.DeviceKeys, key.StreamID, key.DisplayName(),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
"github.com/Shopify/sarama"
|
||||
"github.com/matrix-org/dendrite/keyserver/api"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
var ctx = context.Background()
|
||||
|
@ -105,28 +106,34 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) {
|
|||
bob := "@bob:TestDeviceKeysStreamIDGeneration"
|
||||
msgs := []api.DeviceMessage{
|
||||
{
|
||||
DeviceKeys: api.DeviceKeys{
|
||||
DeviceID: "AAA",
|
||||
UserID: alice,
|
||||
KeyJSON: []byte(`{"key":"v1"}`),
|
||||
DeviceKeys: &gomatrixserverlib.DeviceKeys{
|
||||
RespUserDeviceKeys: gomatrixserverlib.RespUserDeviceKeys{
|
||||
DeviceID: "AAA",
|
||||
UserID: alice,
|
||||
Algorithms: []string{"v1"},
|
||||
},
|
||||
},
|
||||
// StreamID: 1
|
||||
},
|
||||
{
|
||||
DeviceKeys: api.DeviceKeys{
|
||||
DeviceID: "AAA",
|
||||
UserID: bob,
|
||||
KeyJSON: []byte(`{"key":"v1"}`),
|
||||
DeviceKeys: &gomatrixserverlib.DeviceKeys{
|
||||
RespUserDeviceKeys: gomatrixserverlib.RespUserDeviceKeys{
|
||||
DeviceID: "AAA",
|
||||
UserID: bob,
|
||||
Algorithms: []string{"v1"},
|
||||
},
|
||||
},
|
||||
// StreamID: 1 as this is a different user
|
||||
},
|
||||
{
|
||||
DeviceKeys: api.DeviceKeys{
|
||||
DeviceID: "another_device",
|
||||
UserID: alice,
|
||||
KeyJSON: []byte(`{"key":"v1"}`),
|
||||
DeviceKeys: &gomatrixserverlib.DeviceKeys{
|
||||
RespUserDeviceKeys: gomatrixserverlib.RespUserDeviceKeys{
|
||||
DeviceID: "another_device",
|
||||
UserID: alice,
|
||||
Algorithms: []string{"v2"},
|
||||
},
|
||||
},
|
||||
// StreamID: 2 as this is a 2nd device key
|
||||
// StreamID: 2 as this is a 2nd key
|
||||
},
|
||||
}
|
||||
MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs))
|
||||
|
@ -143,10 +150,12 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) {
|
|||
// updating a device sets the next stream ID for that user
|
||||
msgs = []api.DeviceMessage{
|
||||
{
|
||||
DeviceKeys: api.DeviceKeys{
|
||||
DeviceID: "AAA",
|
||||
UserID: alice,
|
||||
KeyJSON: []byte(`{"key":"v2"}`),
|
||||
DeviceKeys: &gomatrixserverlib.DeviceKeys{
|
||||
RespUserDeviceKeys: gomatrixserverlib.RespUserDeviceKeys{
|
||||
DeviceID: "AAA",
|
||||
UserID: alice,
|
||||
Algorithms: []string{"v3"},
|
||||
},
|
||||
},
|
||||
// StreamID: 3
|
||||
},
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue