Storing device keys part 1

This commit is contained in:
Neil Alexander 2021-07-29 09:48:09 +01:00
parent ad05e3de6e
commit 93bf1ffc10
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
7 changed files with 78 additions and 8 deletions

View file

@ -161,6 +161,8 @@ type PerformUploadDeviceKeysRequest struct {
gomatrixserverlib.CrossSigningKeys gomatrixserverlib.CrossSigningKeys
// The user that uploaded the key, should be populated by the clientapi. // The user that uploaded the key, should be populated by the clientapi.
UserID string `json:"user_id"` UserID string `json:"user_id"`
// The stream ID that the keys will be uploaded at
StreamID int64 `json:"stream_id"`
} }
type PerformUploadDeviceKeysResponse struct { type PerformUploadDeviceKeysResponse struct {

View file

@ -75,8 +75,23 @@ func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.P
} }
} }
// TODO: check signatures
keysToStore := api.CrossSigningKeyMap{}
for _, keyData := range req.MasterKey.Keys { // iterates once, see sanityCheckKey
keysToStore[gomatrixserverlib.CrossSigningKeyPurposeMaster] = keyData
}
for _, keyData := range req.SelfSigningKey.Keys { // iterates once, see sanityCheckKey
keysToStore[gomatrixserverlib.CrossSigningKeyPurposeSelfSigning] = keyData
}
for _, keyData := range req.UserSigningKey.Keys { // iterates once, see sanityCheckKey
keysToStore[gomatrixserverlib.CrossSigningKeyPurposeUserSigning] = keyData
}
if err := a.DB.StoreCrossSigningKeysForUser(ctx, req.UserID, keysToStore, req.StreamID); err != nil {
res.Error = &api.KeyError{ res.Error = &api.KeyError{
Err: "Not supported yet", Err: fmt.Sprintf("a.DB.StoreCrossSigningKeysForUser: %s", err),
}
} }
} }

View file

@ -78,4 +78,5 @@ type Database interface {
MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error
CrossSigningKeysForUser(ctx context.Context, userID string) (api.CrossSigningKeyMap, error) CrossSigningKeysForUser(ctx context.Context, userID string) (api.CrossSigningKeyMap, error)
StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap api.CrossSigningKeyMap, streamID int64) error
} }

View file

@ -17,8 +17,10 @@ package postgres
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage/tables" "github.com/matrix-org/dendrite/keyserver/storage/tables"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -40,9 +42,14 @@ const selectCrossSigningKeysForUserSQL = "" +
" WHERE user_id = $1" + " WHERE user_id = $1" +
" ORDER BY user_id, key_type, stream_id DESC" " ORDER BY user_id, key_type, stream_id DESC"
const insertCrossSigningKeysForUserSQL = "" +
"INSERT INTO keyserver_cross_signing_keys (user_id, key_type, key_data, stream_id)" +
" VALUES($1, $2, $3, $4)"
type crossSigningKeysStatements struct { type crossSigningKeysStatements struct {
db *sql.DB db *sql.DB
selectCrossSigningKeysForUserStmt *sql.Stmt selectCrossSigningKeysForUserStmt *sql.Stmt
insertCrossSigningKeysForUserStmt *sql.Stmt
} }
func NewPostgresCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error) { func NewPostgresCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error) {
@ -56,13 +63,16 @@ func NewPostgresCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, erro
if s.selectCrossSigningKeysForUserStmt, err = db.Prepare(selectCrossSigningKeysForUserSQL); err != nil { if s.selectCrossSigningKeysForUserStmt, err = db.Prepare(selectCrossSigningKeysForUserSQL); err != nil {
return nil, err return nil, err
} }
if s.insertCrossSigningKeysForUserStmt, err = db.Prepare(insertCrossSigningKeysForUserSQL); err != nil {
return nil, err
}
return s, nil return s, nil
} }
func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser( func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser(
ctx context.Context, userID string, ctx context.Context, txn *sql.Tx, userID string,
) (r api.CrossSigningKeyMap, err error) { ) (r api.CrossSigningKeyMap, err error) {
rows, err := s.selectCrossSigningKeysForUserStmt.QueryContext(ctx, userID) rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningKeysForUserStmt).QueryContext(ctx, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -78,3 +88,12 @@ func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser(
} }
return return
} }
func (s *crossSigningKeysStatements) InsertCrossSigningKeysForUser(
ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes, streamID int64,
) error {
if _, err := sqlutil.TxStmt(txn, s.insertCrossSigningKeysForUserStmt).ExecContext(ctx, userID, keyType, keyData, streamID); err != nil {
return fmt.Errorf("s.insertCrossSigningKeysForUserStmt: %w", err)
}
return nil
}

View file

@ -18,6 +18,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/api"
@ -159,5 +160,17 @@ func (d *Database) MarkDeviceListStale(ctx context.Context, userID string, isSta
// CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any. // CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any.
func (d *Database) CrossSigningKeysForUser(ctx context.Context, userID string) (api.CrossSigningKeyMap, error) { func (d *Database) CrossSigningKeysForUser(ctx context.Context, userID string) (api.CrossSigningKeyMap, error) {
return d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, userID) return d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID)
}
// StoreCrossSigningKeysForUser stores the latest known cross-signing keys for a user.
func (d *Database) StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap api.CrossSigningKeyMap, streamID int64) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
for keyType, keyData := range keyMap {
if err := d.CrossSigningKeysTable.InsertCrossSigningKeysForUser(ctx, txn, userID, keyType, keyData, streamID); err != nil {
return fmt.Errorf("d.CrossSigningKeysTable.InsertCrossSigningKeysForUser: %w", err)
}
}
return nil
})
} }

View file

@ -17,8 +17,10 @@ package sqlite3
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage/tables" "github.com/matrix-org/dendrite/keyserver/storage/tables"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -40,9 +42,14 @@ const selectCrossSigningKeysForUserSQL = "" +
" (SELECT * FROM keyserver_cross_signing_keys WHERE user_id = $1 ORDER BY stream_id DESC)" + " (SELECT * FROM keyserver_cross_signing_keys WHERE user_id = $1 ORDER BY stream_id DESC)" +
" GROUP BY user_id, key_type" " GROUP BY user_id, key_type"
const insertCrossSigningKeysForUserSQL = "" +
"INSERT INTO keyserver_cross_signing_keys (user_id, key_type, key_data, stream_id)" +
" VALUES($1, $2, $3, $4)"
type crossSigningKeysStatements struct { type crossSigningKeysStatements struct {
db *sql.DB db *sql.DB
selectCrossSigningKeysForUserStmt *sql.Stmt selectCrossSigningKeysForUserStmt *sql.Stmt
insertCrossSigningKeysForUserStmt *sql.Stmt
} }
func NewSqliteCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error) { func NewSqliteCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error) {
@ -56,13 +63,16 @@ func NewSqliteCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error)
if s.selectCrossSigningKeysForUserStmt, err = db.Prepare(selectCrossSigningKeysForUserSQL); err != nil { if s.selectCrossSigningKeysForUserStmt, err = db.Prepare(selectCrossSigningKeysForUserSQL); err != nil {
return nil, err return nil, err
} }
if s.insertCrossSigningKeysForUserStmt, err = db.Prepare(insertCrossSigningKeysForUserSQL); err != nil {
return nil, err
}
return s, nil return s, nil
} }
func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser( func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser(
ctx context.Context, userID string, ctx context.Context, txn *sql.Tx, userID string,
) (r api.CrossSigningKeyMap, err error) { ) (r api.CrossSigningKeyMap, err error) {
rows, err := s.selectCrossSigningKeysForUserStmt.QueryContext(ctx, userID) rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningKeysForUserStmt).QueryContext(ctx, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -78,3 +88,12 @@ func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser(
} }
return return
} }
func (s *crossSigningKeysStatements) InsertCrossSigningKeysForUser(
ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes, streamID int64,
) error {
if _, err := sqlutil.TxStmt(txn, s.insertCrossSigningKeysForUserStmt).ExecContext(ctx, userID, keyType, keyData, streamID); err != nil {
return fmt.Errorf("s.insertCrossSigningKeysForUserStmt: %w", err)
}
return nil
}

View file

@ -54,7 +54,8 @@ type StaleDeviceLists interface {
} }
type CrossSigningKeys interface { type CrossSigningKeys interface {
SelectCrossSigningKeysForUser(ctx context.Context, userID string) (r api.CrossSigningKeyMap, err error) SelectCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string) (r api.CrossSigningKeyMap, err error)
InsertCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes, streamID int64) error
} }
type CrossSigningSigs interface{} type CrossSigningSigs interface{}