Add context to the server key database (#248)

This commit is contained in:
Mark Haines 2017-09-21 16:16:02 +01:00 committed by GitHub
parent 7596c19f3a
commit fef290c47e
2 changed files with 16 additions and 7 deletions

View file

@ -49,7 +49,7 @@ func (d *Database) FetchKeys(
ctx context.Context, ctx context.Context,
requests map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.Timestamp, requests map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.Timestamp,
) (map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.ServerKeys, error) { ) (map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.ServerKeys, error) {
return d.statements.bulkSelectServerKeys(requests) return d.statements.bulkSelectServerKeys(ctx, requests)
} }
// StoreKeys implements gomatrixserverlib.KeyDatabase // StoreKeys implements gomatrixserverlib.KeyDatabase
@ -62,7 +62,7 @@ func (d *Database) StoreKeys(
// high for a single insert statement. // high for a single insert statement.
var lastErr error var lastErr error
for request, keys := range keyMap { for request, keys := range keyMap {
if err := d.statements.upsertServerKeys(request, keys); err != nil { if err := d.statements.upsertServerKeys(ctx, request, keys); err != nil {
// Rather than returning immediately on error we try to insert the // Rather than returning immediately on error we try to insert the
// remaining keys. // remaining keys.
// Since we are inserting the keys outside of a transaction it is // Since we are inserting the keys outside of a transaction it is

View file

@ -15,6 +15,7 @@
package keydb package keydb
import ( import (
"context"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
@ -73,13 +74,15 @@ func (s *serverKeyStatements) prepare(db *sql.DB) (err error) {
} }
func (s *serverKeyStatements) bulkSelectServerKeys( func (s *serverKeyStatements) bulkSelectServerKeys(
ctx context.Context,
requests map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.Timestamp, requests map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.Timestamp,
) (map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.ServerKeys, error) { ) (map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.ServerKeys, error) {
var nameAndKeyIDs []string var nameAndKeyIDs []string
for request := range requests { for request := range requests {
nameAndKeyIDs = append(nameAndKeyIDs, nameAndKeyID(request)) nameAndKeyIDs = append(nameAndKeyIDs, nameAndKeyID(request))
} }
rows, err := s.bulkSelectServerKeysStmt.Query(pq.StringArray(nameAndKeyIDs)) stmt := s.bulkSelectServerKeysStmt
rows, err := stmt.QueryContext(ctx, pq.StringArray(nameAndKeyIDs))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -106,15 +109,21 @@ func (s *serverKeyStatements) bulkSelectServerKeys(
} }
func (s *serverKeyStatements) upsertServerKeys( func (s *serverKeyStatements) upsertServerKeys(
request gomatrixserverlib.PublicKeyRequest, keys gomatrixserverlib.ServerKeys, ctx context.Context,
request gomatrixserverlib.PublicKeyRequest,
keys gomatrixserverlib.ServerKeys,
) error { ) error {
keyJSON, err := json.Marshal(keys) keyJSON, err := json.Marshal(keys)
if err != nil { if err != nil {
return err return err
} }
_, err = s.upsertServerKeysStmt.Exec( _, err = s.upsertServerKeysStmt.ExecContext(
string(request.ServerName), string(request.KeyID), nameAndKeyID(request), ctx,
int64(keys.ValidUntilTS), keyJSON, string(request.ServerName),
string(request.KeyID),
nameAndKeyID(request),
int64(keys.ValidUntilTS),
keyJSON,
) )
return err return err
} }