Produce OTK counts in /sync response (#1235)

* Add QueryOneTimeKeys for /sync extensions

* Unbreak tests

* Produce OTK counts in /sync response

* Linting
This commit is contained in:
Kegsay 2020-08-03 12:29:58 +01:00 committed by GitHub
parent b5cb1d1534
commit ffcb6d2ea1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 138 additions and 7 deletions

View file

@ -29,6 +29,9 @@ type Database interface {
// StoreOneTimeKeys persists the given one-time keys.
StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error)
// OneTimeKeysCount returns a count of all OTKs for this device.
OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error)
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` already then it will be replaced.
DeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error

View file

@ -121,6 +121,28 @@ func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, d
return result, rows.Err()
}
func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) {
counts := &api.OneTimeKeysCount{
DeviceID: deviceID,
UserID: userID,
KeyCount: make(map[string]int),
}
rows, err := s.selectKeysCountStmt.QueryContext(ctx, userID, deviceID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
for rows.Next() {
var algorithm string
var count int
if err = rows.Scan(&algorithm, &count); err != nil {
return nil, err
}
counts.KeyCount[algorithm] = count
}
return counts, nil
}
func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) {
now := time.Now().Unix()
counts := &api.OneTimeKeysCount{

View file

@ -39,6 +39,10 @@ func (d *Database) StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (
return d.OneTimeKeysTable.InsertOneTimeKeys(ctx, keys)
}
func (d *Database) OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) {
return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID)
}
func (d *Database) DeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error {
return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys)
}

View file

@ -121,6 +121,28 @@ func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, d
return result, rows.Err()
}
func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) {
counts := &api.OneTimeKeysCount{
DeviceID: deviceID,
UserID: userID,
KeyCount: make(map[string]int),
}
rows, err := s.selectKeysCountStmt.QueryContext(ctx, userID, deviceID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
for rows.Next() {
var algorithm string
var count int
if err = rows.Scan(&algorithm, &count); err != nil {
return nil, err
}
counts.KeyCount[algorithm] = count
}
return counts, nil
}
func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) {
now := time.Now().Unix()
counts := &api.OneTimeKeysCount{

View file

@ -24,6 +24,7 @@ import (
type OneTimeKeys interface {
SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error)
InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error)
// SelectAndDeleteOneTimeKey selects a single one time key matching the user/device/algorithm specified and returns the algo:key_id => JSON.
// Returns an empty map if the key does not exist.