Process inbound device list updates from federation (#1240)

* Add InputDeviceListUpdate

* Unbreak unit tests

* Process inbound device list updates from federation

- Persist the keys in the keyserver and produce key changes
- Does not currently fetch keys from the remote server if the prev IDs are missing

* Linting
This commit is contained in:
Kegsay 2020-08-05 13:41:16 +01:00 committed by GitHub
parent 15dc1f4d03
commit 642f9cb964
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 220 additions and 21 deletions

View file

@ -19,6 +19,7 @@ import (
"database/sql"
"time"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
@ -56,12 +57,16 @@ const selectBatchDeviceKeysSQL = "" +
const selectMaxStreamForUserSQL = "" +
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
const countStreamIDsForUserSQL = "" +
"SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id = ANY($2)"
type deviceKeysStatements struct {
db *sql.DB
upsertDeviceKeysStmt *sql.Stmt
selectDeviceKeysStmt *sql.Stmt
selectBatchDeviceKeysStmt *sql.Stmt
selectMaxStreamForUserStmt *sql.Stmt
countStreamIDsForUserStmt *sql.Stmt
}
func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
@ -84,6 +89,9 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
return nil, err
}
if s.countStreamIDsForUserStmt, err = db.Prepare(countStreamIDsForUserSQL); err != nil {
return nil, err
}
return s, nil
}
@ -115,6 +123,19 @@ func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn
return
}
func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) {
// nullable if there are no results
var count sql.NullInt32
err := s.countStreamIDsForUserStmt.QueryRowContext(ctx, userID, pq.Int64Array(streamIDs)).Scan(&count)
if err != nil {
return 0, err
}
if count.Valid {
return int(count.Int32), nil
}
return 0, nil
}
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
for _, key := range keys {
now := time.Now().Unix()