mirror of
https://github.com/hoernschen/dendrite.git
synced 2024-12-28 16:08:27 +00:00
Add context to the server key database (#248)
This commit is contained in:
parent
7596c19f3a
commit
fef290c47e
2 changed files with 16 additions and 7 deletions
|
@ -49,7 +49,7 @@ func (d *Database) FetchKeys(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
requests map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.Timestamp,
|
requests map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.Timestamp,
|
||||||
) (map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.ServerKeys, error) {
|
) (map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.ServerKeys, error) {
|
||||||
return d.statements.bulkSelectServerKeys(requests)
|
return d.statements.bulkSelectServerKeys(ctx, requests)
|
||||||
}
|
}
|
||||||
|
|
||||||
// StoreKeys implements gomatrixserverlib.KeyDatabase
|
// StoreKeys implements gomatrixserverlib.KeyDatabase
|
||||||
|
@ -62,7 +62,7 @@ func (d *Database) StoreKeys(
|
||||||
// high for a single insert statement.
|
// high for a single insert statement.
|
||||||
var lastErr error
|
var lastErr error
|
||||||
for request, keys := range keyMap {
|
for request, keys := range keyMap {
|
||||||
if err := d.statements.upsertServerKeys(request, keys); err != nil {
|
if err := d.statements.upsertServerKeys(ctx, request, keys); err != nil {
|
||||||
// Rather than returning immediately on error we try to insert the
|
// Rather than returning immediately on error we try to insert the
|
||||||
// remaining keys.
|
// remaining keys.
|
||||||
// Since we are inserting the keys outside of a transaction it is
|
// Since we are inserting the keys outside of a transaction it is
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
package keydb
|
package keydb
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
|
||||||
|
@ -73,13 +74,15 @@ func (s *serverKeyStatements) prepare(db *sql.DB) (err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *serverKeyStatements) bulkSelectServerKeys(
|
func (s *serverKeyStatements) bulkSelectServerKeys(
|
||||||
|
ctx context.Context,
|
||||||
requests map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.Timestamp,
|
requests map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.Timestamp,
|
||||||
) (map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.ServerKeys, error) {
|
) (map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.ServerKeys, error) {
|
||||||
var nameAndKeyIDs []string
|
var nameAndKeyIDs []string
|
||||||
for request := range requests {
|
for request := range requests {
|
||||||
nameAndKeyIDs = append(nameAndKeyIDs, nameAndKeyID(request))
|
nameAndKeyIDs = append(nameAndKeyIDs, nameAndKeyID(request))
|
||||||
}
|
}
|
||||||
rows, err := s.bulkSelectServerKeysStmt.Query(pq.StringArray(nameAndKeyIDs))
|
stmt := s.bulkSelectServerKeysStmt
|
||||||
|
rows, err := stmt.QueryContext(ctx, pq.StringArray(nameAndKeyIDs))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -106,15 +109,21 @@ func (s *serverKeyStatements) bulkSelectServerKeys(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *serverKeyStatements) upsertServerKeys(
|
func (s *serverKeyStatements) upsertServerKeys(
|
||||||
request gomatrixserverlib.PublicKeyRequest, keys gomatrixserverlib.ServerKeys,
|
ctx context.Context,
|
||||||
|
request gomatrixserverlib.PublicKeyRequest,
|
||||||
|
keys gomatrixserverlib.ServerKeys,
|
||||||
) error {
|
) error {
|
||||||
keyJSON, err := json.Marshal(keys)
|
keyJSON, err := json.Marshal(keys)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = s.upsertServerKeysStmt.Exec(
|
_, err = s.upsertServerKeysStmt.ExecContext(
|
||||||
string(request.ServerName), string(request.KeyID), nameAndKeyID(request),
|
ctx,
|
||||||
int64(keys.ValidUntilTS), keyJSON,
|
string(request.ServerName),
|
||||||
|
string(request.KeyID),
|
||||||
|
nameAndKeyID(request),
|
||||||
|
int64(keys.ValidUntilTS),
|
||||||
|
keyJSON,
|
||||||
)
|
)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue