Add context.Context to the federation client (#225)

* Add context.Context to the federation client

* gb vendor update github.com/matrix-org/gomatrixserverlib
This commit is contained in:
Mark Haines 2017-09-13 11:03:41 +01:00 committed by GitHub
parent 086683459f
commit 029e71828a
17 changed files with 139 additions and 72 deletions

2
vendor/manifest vendored
View file

@ -116,7 +116,7 @@
{
"importpath": "github.com/matrix-org/gomatrixserverlib",
"repository": "https://github.com/matrix-org/gomatrixserverlib",
"revision": "790f02e8f465552dab4317ffe7ca047ccb594cbf",
"revision": "ec5a0d21b03ed4d3bd955ecc9f7a69936f64391e",
"branch": "master"
},
{

View file

@ -17,6 +17,7 @@ package gomatrixserverlib
import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"fmt"
@ -103,7 +104,9 @@ func (f *federationTripper) RoundTrip(r *http.Request) (*http.Response, error) {
// LookupUserInfo gets information about a user from a given matrix homeserver
// using a bearer access token.
func (fc *Client) LookupUserInfo(matrixServer ServerName, token string) (u UserInfo, err error) {
func (fc *Client) LookupUserInfo(
ctx context.Context, matrixServer ServerName, token string,
) (u UserInfo, err error) {
url := url.URL{
Scheme: "matrix",
Host: string(matrixServer),
@ -111,8 +114,13 @@ func (fc *Client) LookupUserInfo(matrixServer ServerName, token string) (u UserI
RawQuery: url.Values{"access_token": []string{token}}.Encode(),
}
req, err := http.NewRequest("GET", url.String(), nil)
if err != nil {
return
}
var response *http.Response
response, err = fc.client.Get(url.String())
response, err = fc.client.Do(req.WithContext(ctx))
if response != nil {
defer response.Body.Close() // nolint: errcheck
}
@ -153,7 +161,7 @@ func (fc *Client) LookupUserInfo(matrixServer ServerName, token string) (u UserI
// copy of the keys.
// Returns the keys or an error if there was a problem talking to the server.
func (fc *Client) LookupServerKeys( // nolint: gocyclo
matrixServer ServerName, keyRequests map[PublicKeyRequest]Timestamp,
ctx context.Context, matrixServer ServerName, keyRequests map[PublicKeyRequest]Timestamp,
) (map[PublicKeyRequest]ServerKeys, error) {
url := url.URL{
Scheme: "matrix",
@ -183,7 +191,13 @@ func (fc *Client) LookupServerKeys( // nolint: gocyclo
return nil, err
}
response, err := fc.client.Post(url.String(), "application/json", bytes.NewBuffer(requestBytes))
req, err := http.NewRequest("POST", url.String(), bytes.NewBuffer(requestBytes))
if err != nil {
return nil, err
}
req.Header.Add("Content-Type", "application/json")
response, err := fc.client.Do(req.WithContext(ctx))
if response != nil {
defer response.Body.Close() // nolint: errcheck
}

View file

@ -17,6 +17,7 @@ package gomatrixserverlib
import (
"bytes"
"context"
"crypto/sha256"
"encoding/json"
"fmt"
@ -188,7 +189,7 @@ func verifyEventSignature(signingName string, keyID KeyID, publicKey ed25519.Pub
// VerifyEventSignatures checks that each event in a list of events has valid
// signatures from the server that sent it.
func VerifyEventSignatures(events []Event, keyRing KeyRing) error { // nolint: gocyclo
func VerifyEventSignatures(ctx context.Context, events []Event, keyRing KeyRing) error { // nolint: gocyclo
var toVerify []VerifyJSONRequest
for _, event := range events {
redactedJSON, err := redactEvent(event.eventJSON)
@ -222,7 +223,7 @@ func VerifyEventSignatures(events []Event, keyRing KeyRing) error { // nolint: g
}
}
results, err := keyRing.VerifyJSONs(toVerify)
results, err := keyRing.VerifyJSONs(ctx, toVerify)
if err != nil {
return err
}

View file

@ -1,6 +1,7 @@
package gomatrixserverlib
import (
"context"
"encoding/json"
"io/ioutil"
"net/http"
@ -31,7 +32,7 @@ func NewFederationClient(
}
}
func (ac *FederationClient) doRequest(r FederationRequest, resBody interface{}) error {
func (ac *FederationClient) doRequest(ctx context.Context, r FederationRequest, resBody interface{}) error {
if err := r.Sign(ac.serverName, ac.serverKeyID, ac.serverPrivateKey); err != nil {
return err
}
@ -41,7 +42,7 @@ func (ac *FederationClient) doRequest(r FederationRequest, resBody interface{})
return err
}
res, err := ac.client.Do(req)
res, err := ac.client.Do(req.WithContext(ctx))
if res != nil {
defer res.Body.Close() // nolint: errcheck
}
@ -87,13 +88,15 @@ func (ac *FederationClient) doRequest(r FederationRequest, resBody interface{})
var federationPathPrefix = "/_matrix/federation/v1"
// SendTransaction sends a transaction
func (ac *FederationClient) SendTransaction(t Transaction) (res RespSend, err error) {
func (ac *FederationClient) SendTransaction(
ctx context.Context, t Transaction,
) (res RespSend, err error) {
path := federationPathPrefix + "/send/" + string(t.TransactionID) + "/"
req := NewFederationRequest("PUT", t.Destination, path)
if err = req.SetContent(t); err != nil {
return
}
err = ac.doRequest(req, &res)
err = ac.doRequest(ctx, req, &res)
return
}
@ -106,12 +109,14 @@ func (ac *FederationClient) SendTransaction(t Transaction) (res RespSend, err er
// If this successfully returns an acceptable event we will sign it with our
// server's key and pass it to SendJoin.
// See https://matrix.org/docs/spec/server_server/unstable.html#joining-rooms
func (ac *FederationClient) MakeJoin(s ServerName, roomID, userID string) (res RespMakeJoin, err error) {
func (ac *FederationClient) MakeJoin(
ctx context.Context, s ServerName, roomID, userID string,
) (res RespMakeJoin, err error) {
path := federationPathPrefix + "/make_join/" +
url.PathEscape(roomID) + "/" +
url.PathEscape(userID)
req := NewFederationRequest("GET", s, path)
err = ac.doRequest(req, &res)
err = ac.doRequest(ctx, req, &res)
return
}
@ -119,7 +124,9 @@ func (ac *FederationClient) MakeJoin(s ServerName, roomID, userID string) (res R
// remote matrix server.
// This is used to join a room the local server isn't a member of.
// See https://matrix.org/docs/spec/server_server/unstable.html#joining-rooms
func (ac *FederationClient) SendJoin(s ServerName, event Event) (res RespSendJoin, err error) {
func (ac *FederationClient) SendJoin(
ctx context.Context, s ServerName, event Event,
) (res RespSendJoin, err error) {
path := federationPathPrefix + "/send_join/" +
url.PathEscape(event.RoomID()) + "/" +
url.PathEscape(event.EventID())
@ -127,13 +134,15 @@ func (ac *FederationClient) SendJoin(s ServerName, event Event) (res RespSendJoi
if err = req.SetContent(event); err != nil {
return
}
err = ac.doRequest(req, &res)
err = ac.doRequest(ctx, req, &res)
return
}
// SendInvite sends an invite m.room.member event to an invited server to be
// signed by it. This is used to invite a user that is not on the local server.
func (ac *FederationClient) SendInvite(s ServerName, event Event) (res RespInvite, err error) {
func (ac *FederationClient) SendInvite(
ctx context.Context, s ServerName, event Event,
) (res RespInvite, err error) {
path := federationPathPrefix + "/invite/" +
url.PathEscape(event.RoomID()) + "/" +
url.PathEscape(event.EventID())
@ -141,7 +150,7 @@ func (ac *FederationClient) SendInvite(s ServerName, event Event) (res RespInvit
if err = req.SetContent(event); err != nil {
return
}
err = ac.doRequest(req, &res)
err = ac.doRequest(ctx, req, &res)
return
}
@ -150,38 +159,44 @@ func (ac *FederationClient) SendInvite(s ServerName, event Event) (res RespInvit
// server.
// This is used to exchange a m.room.third_party_invite event for a m.room.member
// one in a room the local server isn't a member of.
func (ac *FederationClient) ExchangeThirdPartyInvite(s ServerName, builder EventBuilder) (err error) {
func (ac *FederationClient) ExchangeThirdPartyInvite(
ctx context.Context, s ServerName, builder EventBuilder,
) (err error) {
path := federationPathPrefix + "/exchange_third_party_invite/" +
url.PathEscape(builder.RoomID)
req := NewFederationRequest("PUT", s, path)
if err = req.SetContent(builder); err != nil {
return
}
err = ac.doRequest(req, nil)
err = ac.doRequest(ctx, req, nil)
return
}
// LookupState retrieves the room state for a room at an event from a
// remote matrix server as full matrix events.
func (ac *FederationClient) LookupState(s ServerName, roomID, eventID string) (res RespState, err error) {
func (ac *FederationClient) LookupState(
ctx context.Context, s ServerName, roomID, eventID string,
) (res RespState, err error) {
path := federationPathPrefix + "/state/" +
url.PathEscape(roomID) +
"/?event_id=" +
url.QueryEscape(eventID)
req := NewFederationRequest("GET", s, path)
err = ac.doRequest(req, &res)
err = ac.doRequest(ctx, req, &res)
return
}
// LookupStateIDs retrieves the room state for a room at an event from a
// remote matrix server as lists of matrix event IDs.
func (ac *FederationClient) LookupStateIDs(s ServerName, roomID, eventID string) (res RespStateIDs, err error) {
func (ac *FederationClient) LookupStateIDs(
ctx context.Context, s ServerName, roomID, eventID string,
) (res RespStateIDs, err error) {
path := federationPathPrefix + "/state_ids/" +
url.PathEscape(roomID) +
"/?event_id=" +
url.QueryEscape(eventID)
req := NewFederationRequest("GET", s, path)
err = ac.doRequest(req, &res)
err = ac.doRequest(ctx, req, &res)
return
}
@ -190,10 +205,12 @@ func (ac *FederationClient) LookupStateIDs(s ServerName, roomID, eventID string)
// being looked up on.
// If the room alias doesn't exist on the remote server then a 404 gomatrix.HTTPError
// is returned.
func (ac *FederationClient) LookupRoomAlias(s ServerName, roomAlias string) (res RespDirectory, err error) {
func (ac *FederationClient) LookupRoomAlias(
ctx context.Context, s ServerName, roomAlias string,
) (res RespDirectory, err error) {
path := federationPathPrefix + "/query/directory?room_alias=" +
url.QueryEscape(roomAlias)
req := NewFederationRequest("GET", s, path)
err = ac.doRequest(req, &res)
err = ac.doRequest(ctx, req, &res)
return
}

View file

@ -1,6 +1,7 @@
package gomatrixserverlib
import (
"context"
"encoding/json"
"fmt"
)
@ -107,7 +108,7 @@ func (r RespState) Events() ([]Event, error) {
}
// Check that a response to /state is valid.
func (r RespState) Check(keyRing KeyRing) error {
func (r RespState) Check(ctx context.Context, keyRing KeyRing) error {
var allEvents []Event
for _, event := range r.AuthEvents {
if event.StateKey() == nil {
@ -133,7 +134,7 @@ func (r RespState) Check(keyRing KeyRing) error {
}
// Check if the events pass signature checks.
if err := VerifyEventSignatures(allEvents, keyRing); err != nil {
if err := VerifyEventSignatures(ctx, allEvents, keyRing); err != nil {
return nil
}
@ -213,11 +214,11 @@ type respSendJoinFields struct {
// Check that a response to /send_join is valid.
// This checks that it would be valid as a response to /state
// This also checks that the join event is allowed by the state.
func (r RespSendJoin) Check(keyRing KeyRing, joinEvent Event) error {
func (r RespSendJoin) Check(ctx context.Context, keyRing KeyRing, joinEvent Event) error {
// First check that the state is valid.
// The response to /send_join has the same data as a response to /state
// and the checks for a response to /state also apply.
if err := RespState(r).Check(keyRing); err != nil {
if err := RespState(r).Check(ctx, keyRing); err != nil {
return err
}

View file

@ -6,13 +6,13 @@ echo "Installing lint search engine..."
go get github.com/alecthomas/gometalinter/
gometalinter --config=linter.json --install --update
echo "Testing..."
go test
echo "Looking for lint..."
gometalinter --config=linter.json
echo "Double checking spelling..."
misspell -error src *.md
echo "Testing..."
go test
echo "Done!"

View file

@ -1,6 +1,7 @@
package gomatrixserverlib
import (
"context"
"fmt"
"strings"
"time"
@ -26,7 +27,7 @@ type KeyFetcher interface {
// The result may have fewer (server name, key ID) pairs than were in the request.
// The result may have more (server name, key ID) pairs than were in the request.
// Returns an error if there was a problem fetching the keys.
FetchKeys(requests map[PublicKeyRequest]Timestamp) (map[PublicKeyRequest]ServerKeys, error)
FetchKeys(ctx context.Context, requests map[PublicKeyRequest]Timestamp) (map[PublicKeyRequest]ServerKeys, error)
}
// A KeyDatabase is a store for caching public keys.
@ -39,7 +40,7 @@ type KeyDatabase interface {
// to a concurrent FetchKeys(). This is acceptable since the database is
// only used as a cache for the keys, so if a FetchKeys() races with a
// StoreKeys() and some of the keys are missing they will be just be refetched.
StoreKeys(map[PublicKeyRequest]ServerKeys) error
StoreKeys(ctx context.Context, results map[PublicKeyRequest]ServerKeys) error
}
// A KeyRing stores keys for matrix servers and provides methods for verifying JSON messages.
@ -73,7 +74,7 @@ type VerifyJSONResult struct {
// The caller should check the Result field for each entry to see if it was valid.
// Returns an error if there was a problem talking to the database or one of the other methods
// of fetching the public keys.
func (k *KeyRing) VerifyJSONs(requests []VerifyJSONRequest) ([]VerifyJSONResult, error) { // nolint: gocyclo
func (k *KeyRing) VerifyJSONs(ctx context.Context, requests []VerifyJSONRequest) ([]VerifyJSONResult, error) { // nolint: gocyclo
results := make([]VerifyJSONResult, len(requests))
keyIDs := make([][]KeyID, len(requests))
@ -109,7 +110,7 @@ func (k *KeyRing) VerifyJSONs(requests []VerifyJSONRequest) ([]VerifyJSONResult,
// This will happen if all the objects are missing supported signatures.
return results, nil
}
keysFromDatabase, err := k.KeyDatabase.FetchKeys(keyRequests)
keysFromDatabase, err := k.KeyDatabase.FetchKeys(ctx, keyRequests)
if err != nil {
return nil, err
}
@ -124,14 +125,14 @@ func (k *KeyRing) VerifyJSONs(requests []VerifyJSONRequest) ([]VerifyJSONResult,
}
// TODO: Coalesce in-flight requests for the same keys.
// Otherwise we risk spamming the servers we query the keys from.
keysFetched, err := k.KeyFetchers[i].FetchKeys(keyRequests)
keysFetched, err := k.KeyFetchers[i].FetchKeys(ctx, keyRequests)
if err != nil {
return nil, err
}
k.checkUsingKeys(requests, results, keyIDs, keysFetched)
// Add the keys to the database so that we won't need to fetch them again.
if err := k.KeyDatabase.StoreKeys(keysFetched); err != nil {
if err := k.KeyDatabase.StoreKeys(ctx, keysFetched); err != nil {
return nil, err
}
}
@ -143,7 +144,9 @@ func (k *KeyRing) isAlgorithmSupported(keyID KeyID) bool {
return strings.HasPrefix(string(keyID), "ed25519:")
}
func (k *KeyRing) publicKeyRequests(requests []VerifyJSONRequest, results []VerifyJSONResult, keyIDs [][]KeyID) map[PublicKeyRequest]Timestamp {
func (k *KeyRing) publicKeyRequests(
requests []VerifyJSONRequest, results []VerifyJSONResult, keyIDs [][]KeyID,
) map[PublicKeyRequest]Timestamp {
keyRequests := map[PublicKeyRequest]Timestamp{}
for i := range requests {
if results[i].Error == nil {
@ -218,8 +221,10 @@ type PerspectiveKeyFetcher struct {
}
// FetchKeys implements KeyFetcher
func (p *PerspectiveKeyFetcher) FetchKeys(requests map[PublicKeyRequest]Timestamp) (map[PublicKeyRequest]ServerKeys, error) {
results, err := p.Client.LookupServerKeys(p.PerspectiveServerName, requests)
func (p *PerspectiveKeyFetcher) FetchKeys(
ctx context.Context, requests map[PublicKeyRequest]Timestamp,
) (map[PublicKeyRequest]ServerKeys, error) {
results, err := p.Client.LookupServerKeys(ctx, p.PerspectiveServerName, requests)
if err != nil {
return nil, err
}
@ -269,7 +274,9 @@ type DirectKeyFetcher struct {
}
// FetchKeys implements KeyFetcher
func (d *DirectKeyFetcher) FetchKeys(requests map[PublicKeyRequest]Timestamp) (map[PublicKeyRequest]ServerKeys, error) {
func (d *DirectKeyFetcher) FetchKeys(
ctx context.Context, requests map[PublicKeyRequest]Timestamp,
) (map[PublicKeyRequest]ServerKeys, error) {
byServer := map[ServerName]map[PublicKeyRequest]Timestamp{}
for req, ts := range requests {
server := byServer[req.ServerName]
@ -283,7 +290,7 @@ func (d *DirectKeyFetcher) FetchKeys(requests map[PublicKeyRequest]Timestamp) (m
results := map[PublicKeyRequest]ServerKeys{}
for server, reqs := range byServer {
// TODO: make these requests in parallel
serverResults, err := d.fetchKeysForServer(server, reqs)
serverResults, err := d.fetchKeysForServer(ctx, server, reqs)
if err != nil {
// TODO: Should we actually be erroring here? or should we just drop those keys from the result map?
return nil, err
@ -296,9 +303,9 @@ func (d *DirectKeyFetcher) FetchKeys(requests map[PublicKeyRequest]Timestamp) (m
}
func (d *DirectKeyFetcher) fetchKeysForServer(
serverName ServerName, requests map[PublicKeyRequest]Timestamp,
ctx context.Context, serverName ServerName, requests map[PublicKeyRequest]Timestamp,
) (map[PublicKeyRequest]ServerKeys, error) {
results, err := d.Client.LookupServerKeys(serverName, requests)
results, err := d.Client.LookupServerKeys(ctx, serverName, requests)
if err != nil {
return nil, err
}

View file

@ -1,6 +1,7 @@
package gomatrixserverlib
import (
"context"
"encoding/json"
"testing"
)
@ -36,7 +37,9 @@ var testKeys = `{
type testKeyDatabase struct{}
func (db *testKeyDatabase) FetchKeys(requests map[PublicKeyRequest]Timestamp) (map[PublicKeyRequest]ServerKeys, error) {
func (db *testKeyDatabase) FetchKeys(
ctx context.Context, requests map[PublicKeyRequest]Timestamp,
) (map[PublicKeyRequest]ServerKeys, error) {
results := map[PublicKeyRequest]ServerKeys{}
var keys ServerKeys
if err := json.Unmarshal([]byte(testKeys), &keys); err != nil {
@ -54,14 +57,16 @@ func (db *testKeyDatabase) FetchKeys(requests map[PublicKeyRequest]Timestamp) (m
return results, nil
}
func (db *testKeyDatabase) StoreKeys(requests map[PublicKeyRequest]ServerKeys) error {
func (db *testKeyDatabase) StoreKeys(
ctx context.Context, requests map[PublicKeyRequest]ServerKeys,
) error {
return nil
}
func TestVerifyJSONsSuccess(t *testing.T) {
// Check that trying to verify the server key JSON works.
k := KeyRing{nil, &testKeyDatabase{}}
results, err := k.VerifyJSONs([]VerifyJSONRequest{{
results, err := k.VerifyJSONs(context.Background(), []VerifyJSONRequest{{
ServerName: "localhost:8800",
Message: []byte(testKeys),
AtTS: 1493142432964,
@ -77,7 +82,7 @@ func TestVerifyJSONsSuccess(t *testing.T) {
func TestVerifyJSONsUnknownServerFails(t *testing.T) {
// Check that trying to verify JSON for an unknown server fails.
k := KeyRing{nil, &testKeyDatabase{}}
results, err := k.VerifyJSONs([]VerifyJSONRequest{{
results, err := k.VerifyJSONs(context.Background(), []VerifyJSONRequest{{
ServerName: "unknown:8800",
Message: []byte(testKeys),
AtTS: 1493142432964,
@ -94,7 +99,7 @@ func TestVerifyJSONsDistantFutureFails(t *testing.T) {
// Check that trying to verify JSON from the distant future fails.
distantFuture := Timestamp(2000000000000)
k := KeyRing{nil, &testKeyDatabase{}}
results, err := k.VerifyJSONs([]VerifyJSONRequest{{
results, err := k.VerifyJSONs(context.Background(), []VerifyJSONRequest{{
ServerName: "unknown:8800",
Message: []byte(testKeys),
AtTS: distantFuture,
@ -110,7 +115,7 @@ func TestVerifyJSONsDistantFutureFails(t *testing.T) {
func TestVerifyJSONsFetcherError(t *testing.T) {
// Check that if the database errors then the attempt to verify JSON fails.
k := KeyRing{nil, &erroringKeyDatabase{}}
results, err := k.VerifyJSONs([]VerifyJSONRequest{{
results, err := k.VerifyJSONs(context.Background(), []VerifyJSONRequest{{
ServerName: "localhost:8800",
Message: []byte(testKeys),
AtTS: 1493142432964,
@ -129,10 +134,14 @@ func (e *erroringKeyDatabaseError) Error() string { return "An error with the ke
var testErrorFetch = erroringKeyDatabaseError(1)
var testErrorStore = erroringKeyDatabaseError(2)
func (e *erroringKeyDatabase) FetchKeys(requests map[PublicKeyRequest]Timestamp) (map[PublicKeyRequest]ServerKeys, error) {
func (e *erroringKeyDatabase) FetchKeys(
ctx context.Context, requests map[PublicKeyRequest]Timestamp,
) (map[PublicKeyRequest]ServerKeys, error) {
return nil, &testErrorFetch
}
func (e *erroringKeyDatabase) StoreKeys(keys map[PublicKeyRequest]ServerKeys) error {
func (e *erroringKeyDatabase) StoreKeys(
ctx context.Context, keys map[PublicKeyRequest]ServerKeys,
) error {
return &testErrorStore
}

View file

@ -1,4 +1,5 @@
{
"Deadline": "5m",
"Enable": [
"vet",
"vetshadow",

View file

@ -215,7 +215,7 @@ func VerifyHTTPRequest(
return nil, util.MessageResponse(401, message)
}
results, err := keys.VerifyJSONs([]VerifyJSONRequest{{
results, err := keys.VerifyJSONs(req.Context(), []VerifyJSONRequest{{
ServerName: request.Origin(),
AtTS: AsTimestamp(now),
Message: toVerify,