From ce82158abb565b31d70628cc2271e3e4bb2fa4be Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Fri, 2 Jun 2017 11:19:34 +0100 Subject: [PATCH] Add API for querying events by ID. (#127) * Add API for querying events by ID. * Fix tense * Start implementing federation ingress * More stuff * Hook up federation event receiving * Fix comments * Comment on the order of the arrays --- .../dendrite-federation-api-server/main.go | 53 ++++- .../dendrite/cmd/federation-api-proxy/main.go | 124 ++++++++++++ .../dendrite/federationapi/routing/routing.go | 28 ++- .../dendrite/federationapi/writers/send.go | 182 ++++++++++++++++++ .../dendrite/roomserver/api/query.go | 40 ++++ .../dendrite/roomserver/query/query.go | 50 ++++- .../roomserver/storage/events_table.go | 25 +++ .../dendrite/roomserver/storage/storage.go | 5 + 8 files changed, 501 insertions(+), 6 deletions(-) create mode 100644 src/github.com/matrix-org/dendrite/cmd/federation-api-proxy/main.go create mode 100644 src/github.com/matrix-org/dendrite/federationapi/writers/send.go diff --git a/src/github.com/matrix-org/dendrite/cmd/dendrite-federation-api-server/main.go b/src/github.com/matrix-org/dendrite/cmd/dendrite-federation-api-server/main.go index 15bf5bc4..642388df 100644 --- a/src/github.com/matrix-org/dendrite/cmd/dendrite-federation-api-server/main.go +++ b/src/github.com/matrix-org/dendrite/cmd/dendrite-federation-api-server/main.go @@ -18,11 +18,14 @@ import ( "encoding/base64" "net/http" "os" + "strings" "time" + "github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/federationapi/config" "github.com/matrix-org/dendrite/federationapi/routing" + "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" log "github.com/Sirupsen/logrus" @@ -40,7 +43,10 @@ var ( // openssl x509 -noout -fingerprint -sha256 -inform pem -in server.crt |\ // python -c 'print raw_input()[19:].replace(":","").decode("hex").encode("base64").rstrip("=\n")' // - tlsFingerprint = os.Getenv("TLS_FINGERPRINT") + tlsFingerprint = os.Getenv("TLS_FINGERPRINT") + kafkaURIs = strings.Split(os.Getenv("KAFKA_URIS"), ",") + roomserverURL = os.Getenv("ROOMSERVER_URL") + roomserverInputTopic = os.Getenv("TOPIC_INPUT_ROOM_EVENT") ) func main() { @@ -57,6 +63,18 @@ func main() { log.Panic("No TLS_FINGERPRINT environment variable found.") } + if len(kafkaURIs) == 0 { + // the kafka default is :9092 + kafkaURIs = []string{"localhost:9092"} + } + + if roomserverURL == "" { + log.Panic("No ROOMSERVER_URL environment variable found.") + } + + if roomserverInputTopic == "" { + log.Panic("No TOPIC_INPUT_ROOM_EVENT environment variable found. This should match the roomserver input topic.") + } cfg := config.FederationAPI{ ServerName: serverName, // TODO: make the validity period configurable. @@ -75,6 +93,37 @@ func main() { } cfg.TLSFingerPrints = []gomatrixserverlib.TLSFingerprint{{fingerprintSHA256}} - routing.Setup(http.DefaultServeMux, cfg) + federation := gomatrixserverlib.NewFederationClient(cfg.ServerName, cfg.KeyID, cfg.PrivateKey) + + keyRing := gomatrixserverlib.KeyRing{ + KeyFetchers: []gomatrixserverlib.KeyFetcher{ + // TODO: Use perspective key fetchers for production. + &gomatrixserverlib.DirectKeyFetcher{federation.Client}, + }, + KeyDatabase: &dummyKeyDatabase{}, + } + queryAPI := api.NewRoomserverQueryAPIHTTP(roomserverURL, nil) + + roomserverProducer, err := producers.NewRoomserverProducer(kafkaURIs, roomserverInputTopic) + if err != nil { + log.Panicf("Failed to setup kafka producers(%s): %s", kafkaURIs, err) + } + + routing.Setup(http.DefaultServeMux, cfg, queryAPI, roomserverProducer, keyRing) log.Fatal(http.ListenAndServe(bindAddr, nil)) } + +// TODO: Implement a proper key database. +type dummyKeyDatabase struct{} + +func (d *dummyKeyDatabase) FetchKeys( + requests map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.Timestamp, +) (map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.ServerKeys, error) { + return nil, nil +} + +func (d *dummyKeyDatabase) StoreKeys( + map[gomatrixserverlib.PublicKeyRequest]gomatrixserverlib.ServerKeys, +) error { + return nil +} diff --git a/src/github.com/matrix-org/dendrite/cmd/federation-api-proxy/main.go b/src/github.com/matrix-org/dendrite/cmd/federation-api-proxy/main.go new file mode 100644 index 00000000..d1cba69e --- /dev/null +++ b/src/github.com/matrix-org/dendrite/cmd/federation-api-proxy/main.go @@ -0,0 +1,124 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "flag" + "fmt" + log "github.com/Sirupsen/logrus" + "net/http" + "net/http/httputil" + "net/url" + "os" + "strings" + "time" +) + +const usage = `Usage: %s + +Create a single endpoint URL which remote matrix servers can be pointed at. + +The server-server API in Dendrite is split across multiple processes +which listen on multiple ports. You cannot point a Matrix server at +any of those ports, as there will be unimplemented functionality. +In addition, all server-server API processes start with the additional +path prefix '/api', which Matrix servers will be unaware of. + +This tool will proxy requests for all server-server URLs and forward +them to their respective process. It will also add the '/api' path +prefix to incoming requests. + +THIS TOOL IS FOR TESTING AND NOT INTENDED FOR PRODUCTION USE. + +Arguments: + +` + +var ( + federationAPIURL = flag.String("federation-api-url", "", "The base URL of the listening 'dendrite-federation-api-server' process. E.g. 'http://localhost:4200'") + bindAddress = flag.String("bind-address", ":8448", "The listening port for the proxy.") + certFile = flag.String("tls-cert", "server.crt", "The X509 certificate to use for TLS") + keyFile = flag.String("tls-key", "server.key", "The PEM private key to use for TLS") +) + +func makeProxy(targetURL string) (*httputil.ReverseProxy, error) { + if !strings.HasSuffix(targetURL, "/") { + targetURL += "/" + } + // Check that we can parse the URL. + _, err := url.Parse(targetURL) + if err != nil { + return nil, err + } + return &httputil.ReverseProxy{ + Director: func(req *http.Request) { + // URL.Path() removes the % escaping from the path. + // The % encoding will be added back when the url is encoded + // when the request is forwarded. + // This means that we will lose any unessecary escaping from the URL. + // Pratically this means that any distinction between '%2F' and '/' + // in the URL will be lost by the time it reaches the target. + path := req.URL.Path + path = "api" + path + log.WithFields(log.Fields{ + "path": path, + "url": targetURL, + "method": req.Method, + }).Print("proxying request") + newURL, err := url.Parse(targetURL + path) + if err != nil { + // We already checked that we can parse the URL + // So this shouldn't ever get hit. + panic(err) + } + // Copy the query parameters from the request. + newURL.RawQuery = req.URL.RawQuery + req.URL = newURL + }, + }, nil +} + +func main() { + flag.Usage = func() { + fmt.Fprintf(os.Stderr, usage, os.Args[0]) + flag.PrintDefaults() + } + + flag.Parse() + + if *federationAPIURL == "" { + flag.Usage() + fmt.Fprintln(os.Stderr, "no --federation-api-url specified.") + os.Exit(1) + } + + federationProxy, err := makeProxy(*federationAPIURL) + if err != nil { + panic(err) + } + + http.Handle("/", federationProxy) + + srv := &http.Server{ + Addr: *bindAddress, + ReadTimeout: 1 * time.Minute, // how long we wait for the client to send the entire request (after connection accept) + WriteTimeout: 5 * time.Minute, // how long the proxy has to write the full response + } + + fmt.Println("Proxying requests to:") + fmt.Println(" /* => ", *federationAPIURL+"/api/*") + fmt.Println("Listening on ", *bindAddress) + panic(srv.ListenAndServeTLS(*certFile, *keyFile)) +} diff --git a/src/github.com/matrix-org/dendrite/federationapi/routing/routing.go b/src/github.com/matrix-org/dendrite/federationapi/routing/routing.go index b5af66dc..e9461135 100644 --- a/src/github.com/matrix-org/dendrite/federationapi/routing/routing.go +++ b/src/github.com/matrix-org/dendrite/federationapi/routing/routing.go @@ -16,21 +16,34 @@ package routing import ( "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/federationapi/config" "github.com/matrix-org/dendrite/federationapi/readers" + "github.com/matrix-org/dendrite/federationapi/writers" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/prometheus/client_golang/prometheus" "net/http" + "time" ) const ( - pathPrefixV2Keys = "/_matrix/key/v2" + pathPrefixV2Keys = "/_matrix/key/v2" + pathPrefixV1Federation = "/_matrix/federation/v1" ) // Setup registers HTTP handlers with the given ServeMux. -func Setup(servMux *http.ServeMux, cfg config.FederationAPI) { +func Setup( + servMux *http.ServeMux, + cfg config.FederationAPI, + query api.RoomserverQueryAPI, + producer *producers.RoomserverProducer, + keys gomatrixserverlib.KeyRing, +) { apiMux := mux.NewRouter() v2keysmux := apiMux.PathPrefix(pathPrefixV2Keys).Subrouter() + v1fedmux := apiMux.PathPrefix(pathPrefixV1Federation).Subrouter() localKeys := makeAPI("localkeys", func(req *http.Request) util.JSONResponse { return readers.LocalKeys(req, cfg) @@ -43,6 +56,17 @@ func Setup(servMux *http.ServeMux, cfg config.FederationAPI) { v2keysmux.Handle("/server/{keyID}", localKeys) v2keysmux.Handle("/server/", localKeys) + v1fedmux.Handle("/send/{txnID}/", makeAPI("send", + func(req *http.Request) util.JSONResponse { + vars := mux.Vars(req) + return writers.Send( + req, gomatrixserverlib.TransactionID(vars["txnID"]), + time.Now(), + cfg, query, producer, keys, + ) + }, + )) + servMux.Handle("/metrics", prometheus.Handler()) servMux.Handle("/api/", http.StripPrefix("/api", apiMux)) } diff --git a/src/github.com/matrix-org/dendrite/federationapi/writers/send.go b/src/github.com/matrix-org/dendrite/federationapi/writers/send.go new file mode 100644 index 00000000..da6ed34d --- /dev/null +++ b/src/github.com/matrix-org/dendrite/federationapi/writers/send.go @@ -0,0 +1,182 @@ +package writers + +import ( + "encoding/json" + "fmt" + "github.com/matrix-org/dendrite/clientapi/httputil" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/clientapi/producers" + "github.com/matrix-org/dendrite/federationapi/config" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "net/http" + "time" +) + +// Send implements /_matrix/federation/v1/send/{txnID} +func Send( + req *http.Request, + txnID gomatrixserverlib.TransactionID, + now time.Time, + cfg config.FederationAPI, + query api.RoomserverQueryAPI, + producer *producers.RoomserverProducer, + keys gomatrixserverlib.KeyRing, +) util.JSONResponse { + request, errResp := gomatrixserverlib.VerifyHTTPRequest(req, now, cfg.ServerName, keys) + if request == nil { + return errResp + } + + var content gomatrixserverlib.Transaction + if err := json.Unmarshal(request.Content(), &content); err != nil { + return util.JSONResponse{ + Code: 400, + JSON: jsonerror.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()), + } + } + + content.Origin = request.Origin() + content.TransactionID = txnID + content.Destination = cfg.ServerName + + resp, err := processTransaction(content, query, producer, keys) + if err != nil { + return httputil.LogThenError(req, err) + } + + return util.JSONResponse{ + Code: 200, + JSON: resp, + } +} + +func processTransaction( + t gomatrixserverlib.Transaction, + query api.RoomserverQueryAPI, + producer *producers.RoomserverProducer, + keys gomatrixserverlib.KeyRing, +) (*gomatrixserverlib.RespSend, error) { + // Check the event signatures + if err := gomatrixserverlib.VerifyEventSignatures(t.PDUs, keys); err != nil { + return nil, err + } + + // Process the events. + results := map[string]gomatrixserverlib.PDUResult{} + for _, e := range t.PDUs { + err := processEvent(e, query, producer) + if err != nil { + // If the error is due to the event itself being bad then we skip + // it and move onto the next event. We report an error so that the + // sender knows that we have skipped processing it. + // + // However if the event is due to a temporary failure in our server + // such as a database being unavailable then we should bail, and + // hope that the sender will retry when we are feeling better. + // + // It is uncertain what we should do if an event fails because + // we failed to fetch more information from the sending server. + // For example if a request to /state fails. + // If we skip the event then we risk missing the event until we + // receive another event referencing it. + // If we bail and stop processing then we risk wedging incoming + // transactions from that server forever. + switch err.(type) { + case unknownRoomError: + case *gomatrixserverlib.NotAllowed: + default: + // Any other error should be the result of a temporary error in + // our server so we should bail processing the transaction entirely. + return nil, err + } + results[e.EventID()] = gomatrixserverlib.PDUResult{err.Error()} + } else { + results[e.EventID()] = gomatrixserverlib.PDUResult{} + } + } + + // TODO: Process the EDUs. + + return &gomatrixserverlib.RespSend{PDUs: results}, nil +} + +type unknownRoomError string + +func (e unknownRoomError) Error() string { return fmt.Sprintf("unknown room %q", e) } + +func processEvent( + e gomatrixserverlib.Event, + query api.RoomserverQueryAPI, + producer *producers.RoomserverProducer, +) error { + refs := e.PrevEvents() + prevEventIDs := make([]string, len(refs)) + for i := range refs { + prevEventIDs[i] = refs[i].EventID + } + + // Fetch the state needed to authenticate the event. + needed := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{e}) + stateReq := api.QueryStateAfterEventsRequest{ + RoomID: e.RoomID(), + PrevEventIDs: prevEventIDs, + StateToFetch: needed.Tuples(), + } + var stateResp api.QueryStateAfterEventsResponse + if err := query.QueryStateAfterEvents(&stateReq, &stateResp); err != nil { + return err + } + + if !stateResp.RoomExists { + // TODO: When synapse receives a message for a room it is not in it + // asked the remote server for the state of the room so that it can + // check if the remote server knows of a join "m.room.member" event + // that this server is unaware of. + // However generally speaking we should reject events for rooms we + // aren't a member of. + return unknownRoomError(e.RoomID()) + } + + if !stateResp.PrevEventsExist { + // We are missing the previous events for this events. + // This means that there is a gap in our view of the history of the + // room. There two ways that we can handle such a gap: + // 1) We can fill in the gap using /get_missing_events + // 2) We can leave the gap and request the state of the room at + // this event from the remote server using either /state_ids + // or /state. + // Synapse will attempt to do 1 and if that fails or if the gap is + // too large then it will attempt 2. + // Synapse will use /state_ids if possible since ususally the state + // is largely unchanged and it is more efficient to fetch a list of + // event ids and then use /event to fetch the individual events. + // However not all version of synapse support /state_ids so you may + // need to fallback to /state. + // TODO: Attempt to fill in the gap using /get_missing_events + // TODO: Attempt to fetch the state using /state_ids and /events + // TODO: Attempt to fetch the state using /state + panic(fmt.Errorf("Receiving events with missing prev_events is no implemented")) + } + + // Check that the event is allowed by the state at the event. + authUsingState := gomatrixserverlib.NewAuthEvents(nil) + for i := range stateResp.StateEvents { + authUsingState.AddEvent(&stateResp.StateEvents[i]) + } + err := gomatrixserverlib.Allowed(e, &authUsingState) + if err != nil { + return err + } + + // TODO: Check that the roomserver has a copy of all of the auth_events. + // TODO: Check that the event is allowed by its auth_events. + + // pass the event to the roomserver + if err := producer.SendEvents([]gomatrixserverlib.Event{e}); err != nil { + return err + } + + return nil +} diff --git a/src/github.com/matrix-org/dendrite/roomserver/api/query.go b/src/github.com/matrix-org/dendrite/roomserver/api/query.go index 14afb3e8..6f26212d 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/api/query.go +++ b/src/github.com/matrix-org/dendrite/roomserver/api/query.go @@ -41,6 +41,7 @@ type QueryLatestEventsAndStateResponse struct { // The latest events in the room. LatestEvents []gomatrixserverlib.EventReference // The state events requested. + // This list will be in an arbitrary order. StateEvents []gomatrixserverlib.Event } @@ -65,9 +66,30 @@ type QueryStateAfterEventsResponse struct { // If some of previous events do not exist this will be false and StateEvents will be empty. PrevEventsExist bool // The state events requested. + // This list will be in an arbitrary order. StateEvents []gomatrixserverlib.Event } +// QueryEventsByIDRequest is a request to QueryEventsByID +type QueryEventsByIDRequest struct { + // The event IDs to look up. + EventIDs []string +} + +// QueryEventsByIDResponse is a response to QueryEventsByID +type QueryEventsByIDResponse struct { + // Copy of the request for debugging. + QueryEventsByIDRequest + // A list of events with the requested IDs. + // If the roomserver does not have a copy of a requested event + // then it will omit that event from the list. + // If the roomserver thinks it has a copy of the event, but + // fails to read it from the database then it will fail + // the entire request. + // This list will be in an arbitrary order. + Events []gomatrixserverlib.Event +} + // RoomserverQueryAPI is used to query information from the room server. type RoomserverQueryAPI interface { // Query the latest events and state for a room from the room server. @@ -81,6 +103,12 @@ type RoomserverQueryAPI interface { request *QueryStateAfterEventsRequest, response *QueryStateAfterEventsResponse, ) error + + // Query a list of events by event ID. + QueryEventsByID( + request *QueryEventsByIDRequest, + response *QueryEventsByIDResponse, + ) error } // RoomserverQueryLatestEventsAndStatePath is the HTTP path for the QueryLatestEventsAndState API. @@ -89,6 +117,9 @@ const RoomserverQueryLatestEventsAndStatePath = "/api/roomserver/QueryLatestEven // RoomserverQueryStateAfterEventsPath is the HTTP path for the QueryStateAfterEvents API. const RoomserverQueryStateAfterEventsPath = "/api/roomserver/QueryStateAfterEvents" +// RoomserverQueryEventsByIDPath is the HTTP path for the QueryEventsByID API. +const RoomserverQueryEventsByIDPath = "/api/roomserver/QueryEventsByID" + // NewRoomserverQueryAPIHTTP creates a RoomserverQueryAPI implemented by talking to a HTTP POST API. // If httpClient is nil then it uses the http.DefaultClient func NewRoomserverQueryAPIHTTP(roomserverURL string, httpClient *http.Client) RoomserverQueryAPI { @@ -121,6 +152,15 @@ func (h *httpRoomserverQueryAPI) QueryStateAfterEvents( return postJSON(h.httpClient, apiURL, request, response) } +// QueryEventsByID implements RoomserverQueryAPI +func (h *httpRoomserverQueryAPI) QueryEventsByID( + request *QueryEventsByIDRequest, + response *QueryEventsByIDResponse, +) error { + apiURL := h.roomserverURL + RoomserverQueryEventsByIDPath + return postJSON(h.httpClient, apiURL, request, response) +} + func postJSON(httpClient http.Client, apiURL string, request, response interface{}) error { jsonBytes, err := json.Marshal(request) if err != nil { diff --git a/src/github.com/matrix-org/dendrite/roomserver/query/query.go b/src/github.com/matrix-org/dendrite/roomserver/query/query.go index fedc9870..6f236e93 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/query/query.go +++ b/src/github.com/matrix-org/dendrite/roomserver/query/query.go @@ -35,6 +35,9 @@ type RoomserverQueryAPIDatabase interface { // Lookup event references for the latest events in the room and the current state snapshot. // Returns an error if there was a problem talking to the database. LatestEventIDs(roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, error) + // Lookup the numeric IDs for a list of events. + // Returns an error if there was a problem talking to the database. + EventNIDs(eventIDs []string) (map[string]types.EventNID, error) } // RoomserverQueryAPI is an implementation of RoomserverQueryAPI @@ -46,7 +49,7 @@ type RoomserverQueryAPI struct { func (r *RoomserverQueryAPI) QueryLatestEventsAndState( request *api.QueryLatestEventsAndStateRequest, response *api.QueryLatestEventsAndStateResponse, -) (err error) { +) error { response.QueryLatestEventsAndStateRequest = *request roomNID, err := r.DB.RoomNID(request.RoomID) if err != nil { @@ -81,7 +84,7 @@ func (r *RoomserverQueryAPI) QueryLatestEventsAndState( func (r *RoomserverQueryAPI) QueryStateAfterEvents( request *api.QueryStateAfterEventsRequest, response *api.QueryStateAfterEventsResponse, -) (err error) { +) error { response.QueryStateAfterEventsRequest = *request roomNID, err := r.DB.RoomNID(request.RoomID) if err != nil { @@ -115,12 +118,41 @@ func (r *RoomserverQueryAPI) QueryStateAfterEvents( return nil } +// QueryEventsByID implements api.RoomserverQueryAPI +func (r *RoomserverQueryAPI) QueryEventsByID( + request *api.QueryEventsByIDRequest, + response *api.QueryEventsByIDResponse, +) error { + response.QueryEventsByIDRequest = *request + + eventNIDMap, err := r.DB.EventNIDs(request.EventIDs) + if err != nil { + return err + } + + var eventNIDs []types.EventNID + for _, nid := range eventNIDMap { + eventNIDs = append(eventNIDs, nid) + } + + events, err := r.loadEvents(eventNIDs) + if err != nil { + return err + } + + response.Events = events + return nil +} + func (r *RoomserverQueryAPI) loadStateEvents(stateEntries []types.StateEntry) ([]gomatrixserverlib.Event, error) { eventNIDs := make([]types.EventNID, len(stateEntries)) for i := range stateEntries { eventNIDs[i] = stateEntries[i].EventNID } + return r.loadEvents(eventNIDs) +} +func (r *RoomserverQueryAPI) loadEvents(eventNIDs []types.EventNID) ([]gomatrixserverlib.Event, error) { stateEvents, err := r.DB.Events(eventNIDs) if err != nil { return nil, err @@ -163,4 +195,18 @@ func (r *RoomserverQueryAPI) SetupHTTP(servMux *http.ServeMux) { return util.JSONResponse{Code: 200, JSON: &response} }), ) + servMux.Handle( + api.RoomserverQueryEventsByIDPath, + common.MakeAPI("query_events_by_id", func(req *http.Request) util.JSONResponse { + var request api.QueryEventsByIDRequest + var response api.QueryEventsByIDResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.ErrorResponse(err) + } + if err := r.QueryEventsByID(&request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: 200, JSON: &response} + }), + ) } diff --git a/src/github.com/matrix-org/dendrite/roomserver/storage/events_table.go b/src/github.com/matrix-org/dendrite/roomserver/storage/events_table.go index 696d35be..acaf43b6 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/storage/events_table.go +++ b/src/github.com/matrix-org/dendrite/roomserver/storage/events_table.go @@ -104,6 +104,9 @@ const bulkSelectEventReferenceSQL = "" + const bulkSelectEventIDSQL = "" + "SELECT event_nid, event_id FROM events WHERE event_nid = ANY($1)" +const bulkSelectEventNIDSQL = "" + + "SELECT event_id, event_nid FROM events WHERE event_id = ANY($1)" + type eventStatements struct { insertEventStmt *sql.Stmt selectEventStmt *sql.Stmt @@ -116,6 +119,7 @@ type eventStatements struct { bulkSelectStateAtEventAndReferenceStmt *sql.Stmt bulkSelectEventReferenceStmt *sql.Stmt bulkSelectEventIDStmt *sql.Stmt + bulkSelectEventNIDStmt *sql.Stmt } func (s *eventStatements) prepare(db *sql.DB) (err error) { @@ -136,6 +140,7 @@ func (s *eventStatements) prepare(db *sql.DB) (err error) { {&s.bulkSelectStateAtEventAndReferenceStmt, bulkSelectStateAtEventAndReferenceSQL}, {&s.bulkSelectEventReferenceStmt, bulkSelectEventReferenceSQL}, {&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL}, + {&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL}, }.prepare(db) } @@ -321,6 +326,26 @@ func (s *eventStatements) bulkSelectEventID(eventNIDs []types.EventNID) (map[typ return results, nil } +// bulkSelectEventNIDs returns a map from string event ID to numeric event ID. +// If an event ID is not in the database then it is omitted from the map. +func (s *eventStatements) bulkSelectEventNID(eventIDs []string) (map[string]types.EventNID, error) { + rows, err := s.bulkSelectEventNIDStmt.Query(pq.StringArray(eventIDs)) + if err != nil { + return nil, err + } + defer rows.Close() + results := make(map[string]types.EventNID, len(eventIDs)) + for rows.Next() { + var eventID string + var eventNID int64 + if err = rows.Scan(&eventID, &eventNID); err != nil { + return nil, err + } + results[eventID] = types.EventNID(eventNID) + } + return results, nil +} + func eventNIDsAsArray(eventNIDs []types.EventNID) pq.Int64Array { nids := make([]int64, len(eventNIDs)) for i := range eventNIDs { diff --git a/src/github.com/matrix-org/dendrite/roomserver/storage/storage.go b/src/github.com/matrix-org/dendrite/roomserver/storage/storage.go index 8e527d63..b9b5eb1c 100644 --- a/src/github.com/matrix-org/dendrite/roomserver/storage/storage.go +++ b/src/github.com/matrix-org/dendrite/roomserver/storage/storage.go @@ -170,6 +170,11 @@ func (d *Database) EventStateKeyNIDs(eventStateKeys []string) (map[string]types. return d.statements.bulkSelectEventStateKeyNID(eventStateKeys) } +// EventNIDs implements query.RoomQueryDatabase +func (d *Database) EventNIDs(eventIDs []string) (map[string]types.EventNID, error) { + return d.statements.bulkSelectEventNID(eventIDs) +} + // Events implements input.EventDatabase func (d *Database) Events(eventNIDs []types.EventNID) ([]types.Event, error) { eventJSONs, err := d.statements.bulkSelectEventJSON(eventNIDs)