mirror of
https://github.com/hoernschen/dendrite.git
synced 2025-07-31 13:22:46 +00:00
Implement claiming one-time keys locally (#1210)
* Add API shape for claiming keys * Implement claiming one-time keys locally Fairly boring, nothing too special going on.
This commit is contained in:
parent
d76eb1b994
commit
1d72ce8b7a
10 changed files with 202 additions and 8 deletions
|
@ -39,4 +39,8 @@ type Database interface {
|
|||
// DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected.
|
||||
// If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice.
|
||||
DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error)
|
||||
|
||||
// ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key
|
||||
// cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice.
|
||||
ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error)
|
||||
}
|
||||
|
|
|
@ -52,11 +52,19 @@ const selectKeysSQL = "" +
|
|||
const selectKeysCountSQL = "" +
|
||||
"SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 GROUP BY algorithm"
|
||||
|
||||
const deleteOneTimeKeySQL = "" +
|
||||
"DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 AND key_id = $4"
|
||||
|
||||
const selectKeyByAlgorithmSQL = "" +
|
||||
"SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1"
|
||||
|
||||
type oneTimeKeysStatements struct {
|
||||
db *sql.DB
|
||||
upsertKeysStmt *sql.Stmt
|
||||
selectKeysStmt *sql.Stmt
|
||||
selectKeysCountStmt *sql.Stmt
|
||||
db *sql.DB
|
||||
upsertKeysStmt *sql.Stmt
|
||||
selectKeysStmt *sql.Stmt
|
||||
selectKeysCountStmt *sql.Stmt
|
||||
selectKeyByAlgorithmStmt *sql.Stmt
|
||||
deleteOneTimeKeyStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewPostgresOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
|
||||
|
@ -76,6 +84,12 @@ func NewPostgresOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
|
|||
if s.selectKeysCountStmt, err = db.Prepare(selectKeysCountSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectKeyByAlgorithmStmt, err = db.Prepare(selectKeyByAlgorithmSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.deleteOneTimeKeyStmt, err = db.Prepare(deleteOneTimeKeySQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
|
@ -141,3 +155,21 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api.
|
|||
return rows.Err()
|
||||
})
|
||||
}
|
||||
|
||||
func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey(
|
||||
ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string,
|
||||
) (map[string]json.RawMessage, error) {
|
||||
var keyID string
|
||||
var keyJSON string
|
||||
err := txn.StmtContext(ctx, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
_, err = txn.StmtContext(ctx, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
|
||||
return map[string]json.RawMessage{
|
||||
algorithm + ":" + keyID: json.RawMessage(keyJSON),
|
||||
}, err
|
||||
}
|
||||
|
|
|
@ -19,6 +19,7 @@ import (
|
|||
"database/sql"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/keyserver/api"
|
||||
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
||||
)
|
||||
|
@ -48,3 +49,26 @@ func (d *Database) StoreDeviceKeys(ctx context.Context, keys []api.DeviceKeys) e
|
|||
func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) {
|
||||
return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs)
|
||||
}
|
||||
|
||||
func (d *Database) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) {
|
||||
var result []api.OneTimeKeys
|
||||
err := sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error {
|
||||
for userID, deviceToAlgo := range userToDeviceToAlgorithm {
|
||||
for deviceID, algo := range deviceToAlgo {
|
||||
keyJSON, err := d.OneTimeKeysTable.SelectAndDeleteOneTimeKey(ctx, txn, userID, deviceID, algo)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if keyJSON != nil {
|
||||
result = append(result, api.OneTimeKeys{
|
||||
UserID: userID,
|
||||
DeviceID: deviceID,
|
||||
KeyJSON: keyJSON,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return result, err
|
||||
}
|
||||
|
|
|
@ -52,11 +52,19 @@ const selectKeysSQL = "" +
|
|||
const selectKeysCountSQL = "" +
|
||||
"SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 GROUP BY algorithm"
|
||||
|
||||
const deleteOneTimeKeySQL = "" +
|
||||
"DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 AND key_id = $4"
|
||||
|
||||
const selectKeyByAlgorithmSQL = "" +
|
||||
"SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1"
|
||||
|
||||
type oneTimeKeysStatements struct {
|
||||
db *sql.DB
|
||||
upsertKeysStmt *sql.Stmt
|
||||
selectKeysStmt *sql.Stmt
|
||||
selectKeysCountStmt *sql.Stmt
|
||||
db *sql.DB
|
||||
upsertKeysStmt *sql.Stmt
|
||||
selectKeysStmt *sql.Stmt
|
||||
selectKeysCountStmt *sql.Stmt
|
||||
selectKeyByAlgorithmStmt *sql.Stmt
|
||||
deleteOneTimeKeyStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
|
||||
|
@ -76,6 +84,12 @@ func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
|
|||
if s.selectKeysCountStmt, err = db.Prepare(selectKeysCountSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectKeyByAlgorithmStmt, err = db.Prepare(selectKeyByAlgorithmSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.deleteOneTimeKeyStmt, err = db.Prepare(deleteOneTimeKeySQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
|
@ -141,3 +155,21 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api.
|
|||
return rows.Err()
|
||||
})
|
||||
}
|
||||
|
||||
func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey(
|
||||
ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string,
|
||||
) (map[string]json.RawMessage, error) {
|
||||
var keyID string
|
||||
var keyJSON string
|
||||
err := txn.StmtContext(ctx, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
_, err = txn.StmtContext(ctx, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
|
||||
return map[string]json.RawMessage{
|
||||
algorithm + ":" + keyID: json.RawMessage(keyJSON),
|
||||
}, err
|
||||
}
|
||||
|
|
|
@ -16,6 +16,7 @@ package tables
|
|||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/matrix-org/dendrite/keyserver/api"
|
||||
|
@ -24,6 +25,9 @@ import (
|
|||
type OneTimeKeys interface {
|
||||
SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
|
||||
InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error)
|
||||
// SelectAndDeleteOneTimeKey selects a single one time key matching the user/device/algorithm specified and returns the algo:key_id => JSON.
|
||||
// Returns an empty map if the key does not exist.
|
||||
SelectAndDeleteOneTimeKey(ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string) (map[string]json.RawMessage, error)
|
||||
}
|
||||
|
||||
type DeviceKeys interface {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue