From a6bb3fd0ac37eda558c6f873d495c258689b6a3b Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 28 Jul 2021 15:52:52 +0100 Subject: [PATCH] Retrieve cross-signing keys sorta --- keyserver/api/api.go | 6 +++ keyserver/internal/internal.go | 46 +++++++++++++++++++ keyserver/storage/interface.go | 2 + .../postgres/cross_signing_keys_table.go | 36 ++++++++++++++- keyserver/storage/shared/storage.go | 5 ++ .../sqlite3/cross_signing_keys_table.go | 36 ++++++++++++++- keyserver/storage/tables/interface.go | 4 +- 7 files changed, 132 insertions(+), 3 deletions(-) diff --git a/keyserver/api/api.go b/keyserver/api/api.go index b0cdd3b7..fd35d26d 100644 --- a/keyserver/api/api.go +++ b/keyserver/api/api.go @@ -40,6 +40,8 @@ type KeyInternalAPI interface { QueryDeviceMessages(ctx context.Context, req *QueryDeviceMessagesRequest, res *QueryDeviceMessagesResponse) } +type CrossSigningKeyMap map[gomatrixserverlib.CrossSigningKeyPurpose]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes + // KeyError is returned if there was a problem performing/querying the server type KeyError struct { Err string `json:"error"` @@ -182,6 +184,10 @@ type QueryKeysResponse struct { Failures map[string]interface{} // Map of user_id to device_id to device_key DeviceKeys map[string]map[string]json.RawMessage + // Maps of user_id to cross signing key + MasterKeys map[string]gomatrixserverlib.CrossSigningKey + SelfSigningKeys map[string]gomatrixserverlib.CrossSigningKey + UserSigningKeys map[string]gomatrixserverlib.CrossSigningKey // Set if there was a fatal error processing this query Error *KeyError } diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 5a0b3190..4466b635 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -282,6 +282,12 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques return // nothing to query } + // get cross-signing keys from the database + if err := a.crossSigningKeys(ctx, req, res); err != nil { + // TODO: handle this + util.GetLogger(ctx).WithError(err).Error("Failed to retrieve cross-signing keys") + } + // perform key queries for remote devices a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys) } @@ -417,6 +423,46 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer( } +func (a *KeyInternalAPI) crossSigningKeys( + ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse, +) error { + for userID := range req.UserToDevices { + keys, err := a.DB.CrossSigningKeysForUser(ctx, userID) + if err != nil { + return fmt.Errorf("a.DB.CrossSigningKeysForUser (%q): %w", userID, err) + } + + for keyType, keysByType := range keys { + for keyID, keyData := range keysByType { + key := gomatrixserverlib.CrossSigningKey{ + UserID: userID, + Usage: []gomatrixserverlib.CrossSigningKeyPurpose{ + keyType, + }, + Keys: map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{ + keyID: keyData, + }, + } + + // TODO: populate signatures + + switch keyType { + case gomatrixserverlib.CrossSigningKeyPurposeMaster: + res.MasterKeys[userID] = key + + case gomatrixserverlib.CrossSigningKeyPurposeSelfSigning: + res.SelfSigningKeys[userID] = key + + case gomatrixserverlib.CrossSigningKeyPurposeUserSigning: + res.UserSigningKeys[userID] = key + } + } + } + } + + return nil +} + func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase( ctx context.Context, res *api.QueryKeysResponse, userID string, deviceIDs []string, ) error { diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go index e536dfb4..ccb132f1 100644 --- a/keyserver/storage/interface.go +++ b/keyserver/storage/interface.go @@ -76,4 +76,6 @@ type Database interface { // MarkDeviceListStale sets the stale bit for this user to isStale. MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error + + CrossSigningKeysForUser(ctx context.Context, userID string) (api.CrossSigningKeyMap, error) } diff --git a/keyserver/storage/postgres/cross_signing_keys_table.go b/keyserver/storage/postgres/cross_signing_keys_table.go index 51f4f8a3..93c31475 100644 --- a/keyserver/storage/postgres/cross_signing_keys_table.go +++ b/keyserver/storage/postgres/cross_signing_keys_table.go @@ -15,15 +15,20 @@ package postgres import ( + "context" "database/sql" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/storage/tables" + "github.com/matrix-org/gomatrixserverlib" ) var crossSigningKeysSchema = ` CREATE TABLE IF NOT EXISTS keyserver_cross_signing_keys ( user_id TEXT NOT NULL, key_type TEXT NOT NULL, + key_id TEXT NOT NULL, key_data TEXT NOT NULL, stream_id BIGINT NOT NULL ); @@ -31,8 +36,14 @@ CREATE TABLE IF NOT EXISTS keyserver_cross_signing_keys ( CREATE UNIQUE INDEX IF NOT EXISTS keyserver_cross_signing_keys_idx ON keyserver_cross_signing_keys(user_id, key_type, stream_id); ` +const selectCrossSigningKeysForUserSQL = "" + + "SELECT DISTINCT ON (user_id, key_type) key_type, key_id, key_data FROM keyserver_cross_signing_keys" + + " WHERE user_id = $1" + + " ORDER BY user_id, key_type, stream_id DESC" + type crossSigningKeysStatements struct { - db *sql.DB + db *sql.DB + selectCrossSigningKeysForUserStmt *sql.Stmt } func NewPostgresCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error) { @@ -43,5 +54,28 @@ func NewPostgresCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, erro if err != nil { return nil, err } + if s.selectCrossSigningKeysForUserStmt, err = db.Prepare(selectCrossSigningKeysForUserSQL); err != nil { + return nil, err + } return s, nil } + +func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser( + ctx context.Context, userID string, +) (r api.CrossSigningKeyMap, err error) { + rows, err := s.selectCrossSigningKeysForUserStmt.QueryContext(ctx, userID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningKeysForUserStmt: rows.close() failed") + for rows.Next() { + var keyType gomatrixserverlib.CrossSigningKeyPurpose + var keyID gomatrixserverlib.KeyID + var keyData gomatrixserverlib.Base64Bytes + if err := rows.Scan(&keyType, &keyID, &keyData); err != nil { + return nil, err + } + r[keyType][keyID] = keyData + } + return +} diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go index 30b1e92e..a5e619f5 100644 --- a/keyserver/storage/shared/storage.go +++ b/keyserver/storage/shared/storage.go @@ -156,3 +156,8 @@ func (d *Database) MarkDeviceListStale(ctx context.Context, userID string, isSta return d.StaleDeviceListsTable.InsertStaleDeviceList(ctx, userID, isStale) }) } + +// CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any. +func (d *Database) CrossSigningKeysForUser(ctx context.Context, userID string) (api.CrossSigningKeyMap, error) { + return d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, userID) +} diff --git a/keyserver/storage/sqlite3/cross_signing_keys_table.go b/keyserver/storage/sqlite3/cross_signing_keys_table.go index bd8bbc60..0b047c6f 100644 --- a/keyserver/storage/sqlite3/cross_signing_keys_table.go +++ b/keyserver/storage/sqlite3/cross_signing_keys_table.go @@ -15,15 +15,20 @@ package sqlite3 import ( + "context" "database/sql" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/storage/tables" + "github.com/matrix-org/gomatrixserverlib" ) var crossSigningKeysSchema = ` CREATE TABLE IF NOT EXISTS keyserver_cross_signing_keys ( user_id TEXT NOT NULL, key_type TEXT NOT NULL, + key_id TEXT NOT NULL, key_data TEXT NOT NULL, stream_id BIGINT NOT NULL ); @@ -31,8 +36,14 @@ CREATE TABLE IF NOT EXISTS keyserver_cross_signing_keys ( CREATE UNIQUE INDEX IF NOT EXISTS keyserver_cross_signing_keys_idx ON keyserver_cross_signing_keys(user_id, key_type, stream_id); ` +const selectCrossSigningKeysForUserSQL = "" + + "SELECT key_type, key_id, key_data FROM " + + " (SELECT * FROM keyserver_cross_signing_keys WHERE user_id = $1 ORDER BY stream_id DESC)" + + " GROUP BY user_id, key_type" + type crossSigningKeysStatements struct { - db *sql.DB + db *sql.DB + selectCrossSigningKeysForUserStmt *sql.Stmt } func NewSqliteCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error) { @@ -43,5 +54,28 @@ func NewSqliteCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error) if err != nil { return nil, err } + if s.selectCrossSigningKeysForUserStmt, err = db.Prepare(selectCrossSigningKeysForUserSQL); err != nil { + return nil, err + } return s, nil } + +func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser( + ctx context.Context, userID string, +) (r api.CrossSigningKeyMap, err error) { + rows, err := s.selectCrossSigningKeysForUserStmt.QueryContext(ctx, userID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningKeysForUserStmt: rows.close() failed") + for rows.Next() { + var keyType gomatrixserverlib.CrossSigningKeyPurpose + var keyID gomatrixserverlib.KeyID + var keyData gomatrixserverlib.Base64Bytes + if err := rows.Scan(&keyType, &keyID, &keyData); err != nil { + return nil, err + } + r[keyType][keyID] = keyData + } + return +} diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go index 338d4a73..8d10313f 100644 --- a/keyserver/storage/tables/interface.go +++ b/keyserver/storage/tables/interface.go @@ -53,7 +53,9 @@ type StaleDeviceLists interface { SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) } -type CrossSigningKeys interface{} +type CrossSigningKeys interface { + SelectCrossSigningKeysForUser(ctx context.Context, userID string) (r api.CrossSigningKeyMap, err error) +} type CrossSigningSigs interface{}