diff --git a/keyserver/internal/cross_signing.go b/keyserver/internal/cross_signing.go index 40715b67..b08a6e45 100644 --- a/keyserver/internal/cross_signing.go +++ b/keyserver/internal/cross_signing.go @@ -208,8 +208,11 @@ func (a *KeyInternalAPI) crossSigningKeys( for originKeyID, signature := range forOrigin { switch { case req.UserID != "" && originUserID == req.UserID: + // Include signatures that we created appendSignature(originUserID, originKeyID, signature) case originUserID == userID: + // Include signatures that were created by the person whose key + // we are processing appendSignature(originUserID, originKeyID, signature) } } diff --git a/keyserver/storage/postgres/cross_signing_sigs_table.go b/keyserver/storage/postgres/cross_signing_sigs_table.go index 40d6d96f..a7a60c55 100644 --- a/keyserver/storage/postgres/cross_signing_sigs_table.go +++ b/keyserver/storage/postgres/cross_signing_sigs_table.go @@ -17,6 +17,7 @@ package postgres import ( "context" "database/sql" + "fmt" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -41,9 +42,14 @@ const selectCrossSigningSigsForTargetSQL = "" + "SELECT origin_user_id, origin_key_id, signature FROM keyserver_cross_signing_sigs" + " WHERE target_user_id = $1 AND target_key_id = $2" +const insertCrossSigningSigsForTargetSQL = "" + + "INSERT INTO keyserver_cross_signing_keys (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" + + " VALUES($1, $2, $3, $4, $5)" + type crossSigningSigsStatements struct { db *sql.DB selectCrossSigningSigsForTargetStmt *sql.Stmt + insertCrossSigningSigsForTargetStmt *sql.Stmt } func NewPostgresCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, error) { @@ -57,6 +63,9 @@ func NewPostgresCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, erro if s.selectCrossSigningSigsForTargetStmt, err = db.Prepare(selectCrossSigningSigsForTargetSQL); err != nil { return nil, err } + if s.insertCrossSigningSigsForTargetStmt, err = db.Prepare(insertCrossSigningSigsForTargetSQL); err != nil { + return nil, err + } return s, nil } @@ -83,3 +92,15 @@ func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget( } return } + +func (s *crossSigningSigsStatements) InsertCrossSigningSigsForTarget( + ctx context.Context, txn *sql.Tx, + originUserID string, originKeyID gomatrixserverlib.KeyID, + targetUserID string, targetKeyID gomatrixserverlib.KeyID, + signature gomatrixserverlib.Base64Bytes, +) error { + if _, err := sqlutil.TxStmt(txn, s.insertCrossSigningSigsForTargetStmt).ExecContext(ctx, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil { + return fmt.Errorf("s.insertCrossSigningSigsForTargetStmt: %w", err) + } + return nil +} diff --git a/keyserver/storage/sqlite3/cross_signing_sigs_table.go b/keyserver/storage/sqlite3/cross_signing_sigs_table.go index b6ffc3c9..ea077de7 100644 --- a/keyserver/storage/sqlite3/cross_signing_sigs_table.go +++ b/keyserver/storage/sqlite3/cross_signing_sigs_table.go @@ -17,6 +17,7 @@ package sqlite3 import ( "context" "database/sql" + "fmt" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -41,9 +42,14 @@ const selectCrossSigningSigsForTargetSQL = "" + "SELECT origin_user_id, origin_key_id, signature FROM keyserver_cross_signing_sigs" + " WHERE target_user_id = $1 AND target_key_id = $2" +const insertCrossSigningSigsForTargetSQL = "" + + "INSERT INTO keyserver_cross_signing_keys (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" + + " VALUES($1, $2, $3, $4, $5)" + type crossSigningSigsStatements struct { db *sql.DB selectCrossSigningSigsForTargetStmt *sql.Stmt + insertCrossSigningSigsForTargetStmt *sql.Stmt } func NewSqliteCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, error) { @@ -57,6 +63,9 @@ func NewSqliteCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, error) if s.selectCrossSigningSigsForTargetStmt, err = db.Prepare(selectCrossSigningSigsForTargetSQL); err != nil { return nil, err } + if s.insertCrossSigningSigsForTargetStmt, err = db.Prepare(insertCrossSigningSigsForTargetSQL); err != nil { + return nil, err + } return s, nil } @@ -83,3 +92,15 @@ func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget( } return } + +func (s *crossSigningSigsStatements) InsertCrossSigningSigsForTarget( + ctx context.Context, txn *sql.Tx, + originUserID string, originKeyID gomatrixserverlib.KeyID, + targetUserID string, targetKeyID gomatrixserverlib.KeyID, + signature gomatrixserverlib.Base64Bytes, +) error { + if _, err := sqlutil.TxStmt(txn, s.insertCrossSigningSigsForTargetStmt).ExecContext(ctx, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil { + return fmt.Errorf("s.insertCrossSigningSigsForTargetStmt: %w", err) + } + return nil +} diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go index 516de652..0b05aeb0 100644 --- a/keyserver/storage/tables/interface.go +++ b/keyserver/storage/tables/interface.go @@ -60,6 +60,7 @@ type CrossSigningKeys interface { type CrossSigningSigs interface { SelectCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (r api.CrossSigningSigMap, err error) + InsertCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error } type CrossSigningStreams interface{}