From fba18e8b17cd8c8db9c72aabc53fa850d885bbec Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 29 Jul 2021 14:35:42 +0100 Subject: [PATCH] Include signatures in key query --- clientapi/routing/keys.go | 3 +- clientapi/routing/routing.go | 2 +- keyserver/api/api.go | 7 +++ keyserver/internal/cross_signing.go | 27 ++++++++++- keyserver/storage/interface.go | 1 + .../postgres/cross_signing_sigs_table.go | 47 +++++++++++++++++-- keyserver/storage/shared/storage.go | 5 ++ .../sqlite3/cross_signing_sigs_table.go | 47 +++++++++++++++++-- keyserver/storage/tables/interface.go | 4 +- 9 files changed, 128 insertions(+), 15 deletions(-) diff --git a/clientapi/routing/keys.go b/clientapi/routing/keys.go index 8b8fcc27..2d65ac35 100644 --- a/clientapi/routing/keys.go +++ b/clientapi/routing/keys.go @@ -100,7 +100,7 @@ func (r *queryKeysRequest) GetTimeout() time.Duration { return time.Duration(r.Timeout) * time.Millisecond } -func QueryKeys(req *http.Request, keyAPI api.KeyInternalAPI) util.JSONResponse { +func QueryKeys(req *http.Request, keyAPI api.KeyInternalAPI, device *userapi.Device) util.JSONResponse { var r queryKeysRequest resErr := httputil.UnmarshalJSONRequest(req, &r) if resErr != nil { @@ -108,6 +108,7 @@ func QueryKeys(req *http.Request, keyAPI api.KeyInternalAPI) util.JSONResponse { } queryRes := api.QueryKeysResponse{} keyAPI.QueryKeys(req.Context(), &api.QueryKeysRequest{ + UserID: device.UserID, UserToDevices: r.DeviceKeys, Timeout: r.GetTimeout(), // TODO: Token? diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 25baec1a..874ebde6 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -1097,7 +1097,7 @@ func Setup( ).Methods(http.MethodPost, http.MethodOptions) r0mux.Handle("/keys/query", httputil.MakeAuthAPI("keys_query", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return QueryKeys(req, keyAPI) + return QueryKeys(req, keyAPI, device) }), ).Methods(http.MethodPost, http.MethodOptions) r0mux.Handle("/keys/claim", diff --git a/keyserver/api/api.go b/keyserver/api/api.go index b7a0d7c8..878af47c 100644 --- a/keyserver/api/api.go +++ b/keyserver/api/api.go @@ -40,8 +40,12 @@ type KeyInternalAPI interface { QueryDeviceMessages(ctx context.Context, req *QueryDeviceMessagesRequest, res *QueryDeviceMessagesResponse) } +// Map of purpose -> public key type CrossSigningKeyMap map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.Base64Bytes +// Map of user ID -> key ID -> signature +type CrossSigningSigMap map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes + // KeyError is returned if there was a problem performing/querying the server type KeyError struct { Err string `json:"error"` @@ -180,6 +184,9 @@ type PerformUploadDeviceSignaturesResponse struct { } type QueryKeysRequest struct { + // The user ID asking for the keys, e.g. if from a client API request. + // Will not be populated if the key request came from federation. + UserID string // Maps user IDs to a list of devices UserToDevices map[string][]string Timeout time.Duration diff --git a/keyserver/internal/cross_signing.go b/keyserver/internal/cross_signing.go index 8d21d30b..4aff9bb9 100644 --- a/keyserver/internal/cross_signing.go +++ b/keyserver/internal/cross_signing.go @@ -177,17 +177,40 @@ func (a *KeyInternalAPI) crossSigningKeys( for keyType, keyData := range keys { b64 := keyData.Encode() + keyID := gomatrixserverlib.KeyID("ed25519:" + b64) key := gomatrixserverlib.CrossSigningKey{ UserID: userID, Usage: []gomatrixserverlib.CrossSigningKeyPurpose{ keyType, }, Keys: map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{ - gomatrixserverlib.KeyID("ed25519:" + b64): keyData, + keyID: keyData, }, } - // TODO: populate signatures + sigs, err := a.DB.CrossSigningSigsForTarget(ctx, userID, keyID) + if err != nil { + logrus.WithError(err).Errorf("Failed to get cross-signing signatures for user %q key %q", userID, keyID) + return fmt.Errorf("a.DB.CrossSigningSigsForTarget (%q key %q): %w", userID, keyID, err) + } + + appendSignature := func(originUserID string, originKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) { + if _, ok := key.Signatures[originUserID]; !ok { + key.Signatures[originUserID] = make(map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes) + } + key.Signatures[originUserID][originKeyID] = signature + } + + for originUserID, forOrigin := range sigs { + for originKeyID, signature := range forOrigin { + switch { + case req.UserID != "" && originUserID == req.UserID: + appendSignature(originUserID, originKeyID, signature) + case originUserID == userID: + appendSignature(originUserID, originKeyID, signature) + } + } + } switch keyType { case gomatrixserverlib.CrossSigningKeyPurposeMaster: diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go index 4382960a..e0eeefcf 100644 --- a/keyserver/storage/interface.go +++ b/keyserver/storage/interface.go @@ -78,5 +78,6 @@ type Database interface { MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error CrossSigningKeysForUser(ctx context.Context, userID string) (api.CrossSigningKeyMap, error) + CrossSigningSigsForTarget(ctx context.Context, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (api.CrossSigningSigMap, error) StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap api.CrossSigningKeyMap, streamID int64) error } diff --git a/keyserver/storage/postgres/cross_signing_sigs_table.go b/keyserver/storage/postgres/cross_signing_sigs_table.go index be10ccf8..40d6d96f 100644 --- a/keyserver/storage/postgres/cross_signing_sigs_table.go +++ b/keyserver/storage/postgres/cross_signing_sigs_table.go @@ -15,25 +15,35 @@ package postgres import ( + "context" "database/sql" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/storage/tables" + "github.com/matrix-org/gomatrixserverlib" ) var crossSigningSigsSchema = ` CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs ( - user_id TEXT NOT NULL, - key_id TEXT NOT NULL, + origin_user_id TEXT NOT NULL, + origin_key_id TEXT NOT NULL, target_user_id TEXT NOT NULL, - target_device_id TEXT NOT NULL, + target_key_id TEXT NOT NULL, signature TEXT NOT NULL ); -CREATE UNIQUE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs(user_id, target_user_id, target_device_id); +CREATE UNIQUE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs(origin_user_id, target_user_id, target_key_id); ` +const selectCrossSigningSigsForTargetSQL = "" + + "SELECT origin_user_id, origin_key_id, signature FROM keyserver_cross_signing_sigs" + + " WHERE target_user_id = $1 AND target_key_id = $2" + type crossSigningSigsStatements struct { - db *sql.DB + db *sql.DB + selectCrossSigningSigsForTargetStmt *sql.Stmt } func NewPostgresCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, error) { @@ -44,5 +54,32 @@ func NewPostgresCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, erro if err != nil { return nil, err } + if s.selectCrossSigningSigsForTargetStmt, err = db.Prepare(selectCrossSigningSigsForTargetSQL); err != nil { + return nil, err + } return s, nil } + +func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget( + ctx context.Context, txn *sql.Tx, targetUserID string, targetKeyID gomatrixserverlib.KeyID, +) (r api.CrossSigningSigMap, err error) { + rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningSigsForTargetStmt).QueryContext(ctx, targetUserID, targetKeyID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningSigsForTargetStmt: rows.close() failed") + r = api.CrossSigningSigMap{} + for rows.Next() { + var userID string + var keyID gomatrixserverlib.KeyID + var signature gomatrixserverlib.Base64Bytes + if err := rows.Scan(&userID, &keyID, &signature); err != nil { + return nil, err + } + if _, ok := r[userID]; !ok { + r[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + } + r[userID][keyID] = signature + } + return +} diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go index e8e73135..6a63f3c2 100644 --- a/keyserver/storage/shared/storage.go +++ b/keyserver/storage/shared/storage.go @@ -163,6 +163,11 @@ func (d *Database) CrossSigningKeysForUser(ctx context.Context, userID string) ( return d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID) } +// CrossSigningSigsForTarget returns the signatures for a given user's key ID, if any. +func (d *Database) CrossSigningSigsForTarget(ctx context.Context, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (api.CrossSigningSigMap, error) { + return d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, targetUserID, targetKeyID) +} + // StoreCrossSigningKeysForUser stores the latest known cross-signing keys for a user. func (d *Database) StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap api.CrossSigningKeyMap, streamID int64) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { diff --git a/keyserver/storage/sqlite3/cross_signing_sigs_table.go b/keyserver/storage/sqlite3/cross_signing_sigs_table.go index db8e1912..b6ffc3c9 100644 --- a/keyserver/storage/sqlite3/cross_signing_sigs_table.go +++ b/keyserver/storage/sqlite3/cross_signing_sigs_table.go @@ -15,25 +15,35 @@ package sqlite3 import ( + "context" "database/sql" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/storage/tables" + "github.com/matrix-org/gomatrixserverlib" ) var crossSigningSigsSchema = ` CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs ( - user_id TEXT NOT NULL, - key_id TEXT NOT NULL, + origin_user_id TEXT NOT NULL, + origin_key_id TEXT NOT NULL, target_user_id TEXT NOT NULL, - target_device_id TEXT NOT NULL, + target_key_id TEXT NOT NULL, signature TEXT NOT NULL ); -CREATE UNIQUE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs(user_id, target_user_id, target_device_id); +CREATE UNIQUE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs(origin_user_id, target_user_id, target_key_id); ` +const selectCrossSigningSigsForTargetSQL = "" + + "SELECT origin_user_id, origin_key_id, signature FROM keyserver_cross_signing_sigs" + + " WHERE target_user_id = $1 AND target_key_id = $2" + type crossSigningSigsStatements struct { - db *sql.DB + db *sql.DB + selectCrossSigningSigsForTargetStmt *sql.Stmt } func NewSqliteCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, error) { @@ -44,5 +54,32 @@ func NewSqliteCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, error) if err != nil { return nil, err } + if s.selectCrossSigningSigsForTargetStmt, err = db.Prepare(selectCrossSigningSigsForTargetSQL); err != nil { + return nil, err + } return s, nil } + +func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget( + ctx context.Context, txn *sql.Tx, targetUserID string, targetKeyID gomatrixserverlib.KeyID, +) (r api.CrossSigningSigMap, err error) { + rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningSigsForTargetStmt).QueryContext(ctx, targetUserID, targetKeyID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningSigsForTargetStmt: rows.close() failed") + r = api.CrossSigningSigMap{} + for rows.Next() { + var userID string + var keyID gomatrixserverlib.KeyID + var signature gomatrixserverlib.Base64Bytes + if err := rows.Scan(&userID, &keyID, &signature); err != nil { + return nil, err + } + if _, ok := r[userID]; !ok { + r[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + } + r[userID][keyID] = signature + } + return +} diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go index 1ea3b98c..516de652 100644 --- a/keyserver/storage/tables/interface.go +++ b/keyserver/storage/tables/interface.go @@ -58,6 +58,8 @@ type CrossSigningKeys interface { InsertCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes, streamID int64) error } -type CrossSigningSigs interface{} +type CrossSigningSigs interface { + SelectCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (r api.CrossSigningSigMap, err error) +} type CrossSigningStreams interface{}