From ddf1c8adf1fd1441b76834df479d1ab5a132de88 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 17 Jun 2020 17:41:45 +0100 Subject: [PATCH 1/3] Hacks for supporting Riot iOS (#1148) * Join room body is optional * Support deprecated login by user/password * Implement dummy key upload endpoint * Make a very determinate end to /messages if we hit the create event in back-pagination * Linting --- clientapi/routing/joinroom.go | 4 +-- clientapi/routing/login.go | 67 +++++++++++++++++++++++------------ keyserver/routing/routing.go | 10 ++++++ syncapi/routing/messages.go | 45 ++++++++++++++--------- 4 files changed, 84 insertions(+), 42 deletions(-) diff --git a/clientapi/routing/joinroom.go b/clientapi/routing/joinroom.go index e190beef..3871e4d4 100644 --- a/clientapi/routing/joinroom.go +++ b/clientapi/routing/joinroom.go @@ -43,9 +43,7 @@ func JoinRoomByIDOrAlias( // If content was provided in the request then incude that // in the request. It'll get used as a part of the membership // event content. - if err := httputil.UnmarshalJSONRequest(req, &joinReq.Content); err != nil { - return *err - } + _ = httputil.UnmarshalJSONRequest(req, &joinReq.Content) // Work out our localpart for the client profile request. localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) diff --git a/clientapi/routing/login.go b/clientapi/routing/login.go index 1b894a56..dc0180da 100644 --- a/clientapi/routing/login.go +++ b/clientapi/routing/login.go @@ -47,6 +47,7 @@ type loginIdentifier struct { type passwordRequest struct { Identifier loginIdentifier `json:"identifier"` + User string `json:"user"` // deprecated in favour of identifier Password string `json:"password"` // Both DeviceID and InitialDisplayName can be omitted, or empty strings ("") // Thus a pointer is needed to differentiate between the two @@ -81,6 +82,7 @@ func Login( } else if req.Method == http.MethodPost { var r passwordRequest var acc *api.Account + var errJSON *util.JSONResponse resErr := httputil.UnmarshalJSONRequest(req, &r) if resErr != nil { return *resErr @@ -93,30 +95,22 @@ func Login( JSON: jsonerror.BadJSON("'user' must be supplied."), } } - - util.GetLogger(req.Context()).WithField("user", r.Identifier.User).Info("Processing login request") - - localpart, err := userutil.ParseUsernameParam(r.Identifier.User, &cfg.Matrix.ServerName) - if err != nil { - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.InvalidUsername(err.Error()), - } - } - - acc, err = accountDB.GetAccountByPassword(req.Context(), localpart, r.Password) - if err != nil { - // Technically we could tell them if the user does not exist by checking if err == sql.ErrNoRows - // but that would leak the existence of the user. - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("username or password was incorrect, or the account does not exist"), - } + acc, errJSON = r.processUsernamePasswordLoginRequest(req, accountDB, cfg, r.Identifier.User) + if errJSON != nil { + return *errJSON } default: - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("login identifier '" + r.Identifier.Type + "' not supported"), + // TODO: The below behaviour is deprecated but without it Riot iOS won't log in + if r.User != "" { + acc, errJSON = r.processUsernamePasswordLoginRequest(req, accountDB, cfg, r.User) + if errJSON != nil { + return *errJSON + } + } else { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON("login identifier '" + r.Identifier.Type + "' not supported"), + } } } @@ -163,3 +157,32 @@ func getDevice( ) return } + +func (r *passwordRequest) processUsernamePasswordLoginRequest( + req *http.Request, accountDB accounts.Database, + cfg *config.Dendrite, username string, +) (acc *api.Account, errJSON *util.JSONResponse) { + util.GetLogger(req.Context()).WithField("user", username).Info("Processing login request") + + localpart, err := userutil.ParseUsernameParam(username, &cfg.Matrix.ServerName) + if err != nil { + errJSON = &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername(err.Error()), + } + return + } + + acc, err = accountDB.GetAccountByPassword(req.Context(), localpart, r.Password) + if err != nil { + // Technically we could tell them if the user does not exist by checking if err == sql.ErrNoRows + // but that would leak the existence of the user. + errJSON = &util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("username or password was incorrect, or the account does not exist"), + } + return + } + + return +} diff --git a/keyserver/routing/routing.go b/keyserver/routing/routing.go index c09031d8..dba43528 100644 --- a/keyserver/routing/routing.go +++ b/keyserver/routing/routing.go @@ -36,9 +36,19 @@ func Setup( publicAPIMux *mux.Router, cfg *config.Dendrite, userAPI userapi.UserInternalAPI, ) { r0mux := publicAPIMux.PathPrefix(pathPrefixR0).Subrouter() + r0mux.Handle("/keys/query", httputil.MakeAuthAPI("queryKeys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return QueryKeys(req) }), ).Methods(http.MethodPost, http.MethodOptions) + + r0mux.Handle("/keys/upload/{keyID}", + httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return util.JSONResponse{ + Code: 200, + JSON: map[string]interface{}{}, + } + }), + ).Methods(http.MethodPost, http.MethodOptions) } diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index de5429db..15add1b4 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -158,6 +158,7 @@ func OnIncomingMessagesRequest( util.GetLogger(req.Context()).WithError(err).Error("mreq.retrieveEvents failed") return jsonerror.InternalServerError() } + util.GetLogger(req.Context()).WithFields(logrus.Fields{ "from": from.String(), "to": to.String(), @@ -246,6 +247,12 @@ func (r *messagesReq) retrieveEvents() ( // change the way topological positions are defined (as depth isn't the most // reliable way to define it), it would be easier and less troublesome to // only have to change it in one place, i.e. the database. + start, end, err = r.getStartEnd(events) + + return clientEvents, start, end, err +} + +func (r *messagesReq) getStartEnd(events []gomatrixserverlib.HeaderedEvent) (start, end types.TopologyToken, err error) { start, err = r.db.EventPositionInTopology( r.ctx, events[0].EventID(), ) @@ -253,24 +260,28 @@ func (r *messagesReq) retrieveEvents() ( err = fmt.Errorf("EventPositionInTopology: for start event %s: %w", events[0].EventID(), err) return } - end, err = r.db.EventPositionInTopology( - r.ctx, events[len(events)-1].EventID(), - ) - if err != nil { - err = fmt.Errorf("EventPositionInTopology: for end event %s: %w", events[len(events)-1].EventID(), err) - return + if r.backwardOrdering && events[len(events)-1].Type() == gomatrixserverlib.MRoomCreate { + // We've hit the beginning of the room so there's really nowhere else + // to go. This seems to fix Riot iOS from looping on /messages endlessly. + end = types.NewTopologyToken(0, 0) + } else { + end, err = r.db.EventPositionInTopology( + r.ctx, events[len(events)-1].EventID(), + ) + if err != nil { + err = fmt.Errorf("EventPositionInTopology: for end event %s: %w", events[len(events)-1].EventID(), err) + return + } + if r.backwardOrdering { + // A stream/topological position is a cursor located between two events. + // While they are identified in the code by the event on their right (if + // we consider a left to right chronological order), tokens need to refer + // to them by the event on their left, therefore we need to decrement the + // end position we send in the response if we're going backward. + end.Decrement() + } } - - if r.backwardOrdering { - // A stream/topological position is a cursor located between two events. - // While they are identified in the code by the event on their right (if - // we consider a left to right chronological order), tokens need to refer - // to them by the event on their left, therefore we need to decrement the - // end position we send in the response if we're going backward. - end.Decrement() - } - - return clientEvents, start, end, err + return } // handleEmptyEventsSlice handles the case where the initial request to the From 3547a1768c36626c672e5c7834f297496f568b2f Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 18 Jun 2020 13:48:47 +0100 Subject: [PATCH 2/3] Fix embed Riot Web into Yggdrasil demo --- cmd/dendrite-demo-yggdrasil/embed/embed_other.go | 4 +++- cmd/dendrite-demo-yggdrasil/embed/embed_riotweb.go | 7 ++++--- cmd/dendrite-demo-yggdrasil/main.go | 2 +- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/cmd/dendrite-demo-yggdrasil/embed/embed_other.go b/cmd/dendrite-demo-yggdrasil/embed/embed_other.go index a9108fad..59888114 100644 --- a/cmd/dendrite-demo-yggdrasil/embed/embed_other.go +++ b/cmd/dendrite-demo-yggdrasil/embed/embed_other.go @@ -2,6 +2,8 @@ package embed -func Embed(_ int, _ string) { +import "github.com/gorilla/mux" + +func Embed(_ *mux.Router, _ int, _ string) { } diff --git a/cmd/dendrite-demo-yggdrasil/embed/embed_riotweb.go b/cmd/dendrite-demo-yggdrasil/embed/embed_riotweb.go index 360d0bc5..a9e04a31 100644 --- a/cmd/dendrite-demo-yggdrasil/embed/embed_riotweb.go +++ b/cmd/dendrite-demo-yggdrasil/embed/embed_riotweb.go @@ -7,19 +7,20 @@ import ( "io" "net/http" + "github.com/gorilla/mux" "github.com/tidwall/sjson" ) // From within the Riot Web directory: // go run github.com/mjibson/esc -o /path/to/dendrite/internal/embed/fs_riotweb.go -private -pkg embed . -func Embed(listenPort int, serverName string) { +func Embed(rootMux *mux.Router, listenPort int, serverName string) { url := fmt.Sprintf("http://localhost:%d", listenPort) embeddedFS := _escFS(false) embeddedServ := http.FileServer(embeddedFS) - http.DefaultServeMux.Handle("/", embeddedServ) - http.DefaultServeMux.HandleFunc("/config.json", func(w http.ResponseWriter, _ *http.Request) { + rootMux.Handle("/", embeddedServ) + rootMux.HandleFunc("/config.json", func(w http.ResponseWriter, _ *http.Request) { configFile, err := embeddedFS.Open("/config.sample.json") if err != nil { w.WriteHeader(500) diff --git a/cmd/dendrite-demo-yggdrasil/main.go b/cmd/dendrite-demo-yggdrasil/main.go index be15da47..328aa062 100644 --- a/cmd/dendrite-demo-yggdrasil/main.go +++ b/cmd/dendrite-demo-yggdrasil/main.go @@ -154,7 +154,7 @@ func main() { logrus.WithError(err).Panicf("failed to connect to public rooms db") } - embed.Embed(*instancePort, "Yggdrasil Demo") + embed.Embed(base.BaseMux, *instancePort, "Yggdrasil Demo") monolith := setup.Monolith{ Config: base.Cfg, From dc0bac85d5bad933d32ee63f8bc1aef6348ca6e9 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 18 Jun 2020 18:36:03 +0100 Subject: [PATCH 3/3] Refactor account data (#1150) * Refactor account data * Tweak database fetching * Tweaks * Restore syncProducer notification * Various tweaks, update tag behaviour * Fix initial sync --- clientapi/routing/account_data.go | 55 ++++---- clientapi/routing/room_tagging.go | 127 +++++++----------- clientapi/routing/routing.go | 14 +- syncapi/sync/requestpool.go | 75 +++++++---- userapi/api/api.go | 27 ++-- userapi/internal/api.go | 31 ++++- userapi/inthttp/client.go | 10 ++ userapi/storage/accounts/interface.go | 7 +- .../accounts/postgres/account_data_table.go | 48 +++---- userapi/storage/accounts/postgres/storage.go | 13 +- .../accounts/sqlite3/account_data_table.go | 50 +++---- userapi/storage/accounts/sqlite3/storage.go | 13 +- 12 files changed, 248 insertions(+), 222 deletions(-) diff --git a/clientapi/routing/account_data.go b/clientapi/routing/account_data.go index 68e0dc5d..d5fafedb 100644 --- a/clientapi/routing/account_data.go +++ b/clientapi/routing/account_data.go @@ -16,21 +16,20 @@ package routing import ( "encoding/json" + "fmt" "io/ioutil" "net/http" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) // GetAccountData implements GET /user/{userId}/[rooms/{roomid}/]account_data/{type} func GetAccountData( - req *http.Request, accountDB accounts.Database, device *api.Device, + req *http.Request, userAPI api.UserInternalAPI, device *api.Device, userID string, roomID string, dataType string, ) util.JSONResponse { if userID != device.UserID { @@ -40,18 +39,28 @@ func GetAccountData( } } - localpart, _, err := gomatrixserverlib.SplitID('@', userID) - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return jsonerror.InternalServerError() + dataReq := api.QueryAccountDataRequest{ + UserID: userID, + DataType: dataType, + RoomID: roomID, + } + dataRes := api.QueryAccountDataResponse{} + if err := userAPI.QueryAccountData(req.Context(), &dataReq, &dataRes); err != nil { + util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryAccountData failed") + return util.ErrorResponse(fmt.Errorf("userAPI.QueryAccountData: %w", err)) } - if data, err := accountDB.GetAccountDataByType( - req.Context(), localpart, roomID, dataType, - ); err == nil { + var data json.RawMessage + var ok bool + if roomID != "" { + data, ok = dataRes.RoomAccountData[roomID][dataType] + } else { + data, ok = dataRes.GlobalAccountData[dataType] + } + if ok { return util.JSONResponse{ Code: http.StatusOK, - JSON: data.Content, + JSON: data, } } @@ -63,7 +72,7 @@ func GetAccountData( // SaveAccountData implements PUT /user/{userId}/[rooms/{roomId}/]account_data/{type} func SaveAccountData( - req *http.Request, accountDB accounts.Database, device *api.Device, + req *http.Request, userAPI api.UserInternalAPI, device *api.Device, userID string, roomID string, dataType string, syncProducer *producers.SyncAPIProducer, ) util.JSONResponse { if userID != device.UserID { @@ -73,12 +82,6 @@ func SaveAccountData( } } - localpart, _, err := gomatrixserverlib.SplitID('@', userID) - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return jsonerror.InternalServerError() - } - defer req.Body.Close() // nolint: errcheck if req.Body == http.NoBody { @@ -101,13 +104,19 @@ func SaveAccountData( } } - if err := accountDB.SaveAccountData( - req.Context(), localpart, roomID, dataType, string(body), - ); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("accountDB.SaveAccountData failed") - return jsonerror.InternalServerError() + dataReq := api.InputAccountDataRequest{ + UserID: userID, + DataType: dataType, + RoomID: roomID, + AccountData: json.RawMessage(body), + } + dataRes := api.InputAccountDataResponse{} + if err := userAPI.InputAccountData(req.Context(), &dataReq, &dataRes); err != nil { + util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryAccountData failed") + return util.ErrorResponse(err) } + // TODO: user API should do this since it's account data if err := syncProducer.SendData(userID, roomID, dataType); err != nil { util.GetLogger(req.Context()).WithError(err).Error("syncProducer.SendData failed") return jsonerror.InternalServerError() diff --git a/clientapi/routing/room_tagging.go b/clientapi/routing/room_tagging.go index b1cfcca8..c683cc94 100644 --- a/clientapi/routing/room_tagging.go +++ b/clientapi/routing/room_tagging.go @@ -24,23 +24,14 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/gomatrix" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) -// newTag creates and returns a new gomatrix.TagContent -func newTag() gomatrix.TagContent { - return gomatrix.TagContent{ - Tags: make(map[string]gomatrix.TagProperties), - } -} - // GetTags implements GET /_matrix/client/r0/user/{userID}/rooms/{roomID}/tags func GetTags( req *http.Request, - accountDB accounts.Database, + userAPI api.UserInternalAPI, device *api.Device, userID string, roomID string, @@ -54,22 +45,15 @@ func GetTags( } } - _, data, err := obtainSavedTags(req, userID, roomID, accountDB) + tagContent, err := obtainSavedTags(req, userID, roomID, userAPI) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("obtainSavedTags failed") return jsonerror.InternalServerError() } - if data == nil { - return util.JSONResponse{ - Code: http.StatusOK, - JSON: struct{}{}, - } - } - return util.JSONResponse{ Code: http.StatusOK, - JSON: data.Content, + JSON: tagContent, } } @@ -78,7 +62,7 @@ func GetTags( // the tag to the "map" and saving the new "map" to the DB func PutTag( req *http.Request, - accountDB accounts.Database, + userAPI api.UserInternalAPI, device *api.Device, userID string, roomID string, @@ -98,34 +82,25 @@ func PutTag( return *reqErr } - localpart, data, err := obtainSavedTags(req, userID, roomID, accountDB) + tagContent, err := obtainSavedTags(req, userID, roomID, userAPI) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("obtainSavedTags failed") return jsonerror.InternalServerError() } - var tagContent gomatrix.TagContent - if data != nil { - if err = json.Unmarshal(data.Content, &tagContent); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("json.Unmarshal failed") - return jsonerror.InternalServerError() - } - } else { - tagContent = newTag() + if tagContent.Tags == nil { + tagContent.Tags = make(map[string]gomatrix.TagProperties) } tagContent.Tags[tag] = properties - if err = saveTagData(req, localpart, roomID, accountDB, tagContent); err != nil { + + if err = saveTagData(req, userID, roomID, userAPI, tagContent); err != nil { util.GetLogger(req.Context()).WithError(err).Error("saveTagData failed") return jsonerror.InternalServerError() } - // Send data to syncProducer in order to inform clients of changes - // Run in a goroutine in order to prevent blocking the tag request response - go func() { - if err := syncProducer.SendData(userID, roomID, "m.tag"); err != nil { - logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi") - } - }() + if err = syncProducer.SendData(userID, roomID, "m.tag"); err != nil { + logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi") + } return util.JSONResponse{ Code: http.StatusOK, @@ -138,7 +113,7 @@ func PutTag( // the "map" and then saving the new "map" in the DB func DeleteTag( req *http.Request, - accountDB accounts.Database, + userAPI api.UserInternalAPI, device *api.Device, userID string, roomID string, @@ -153,28 +128,12 @@ func DeleteTag( } } - localpart, data, err := obtainSavedTags(req, userID, roomID, accountDB) + tagContent, err := obtainSavedTags(req, userID, roomID, userAPI) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("obtainSavedTags failed") return jsonerror.InternalServerError() } - // If there are no tags in the database, exit - if data == nil { - // Spec only defines 200 responses for this endpoint so we don't return anything else. - return util.JSONResponse{ - Code: http.StatusOK, - JSON: struct{}{}, - } - } - - var tagContent gomatrix.TagContent - err = json.Unmarshal(data.Content, &tagContent) - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("json.Unmarshal failed") - return jsonerror.InternalServerError() - } - // Check whether the tag to be deleted exists if _, ok := tagContent.Tags[tag]; ok { delete(tagContent.Tags, tag) @@ -185,18 +144,16 @@ func DeleteTag( JSON: struct{}{}, } } - if err = saveTagData(req, localpart, roomID, accountDB, tagContent); err != nil { + + if err = saveTagData(req, userID, roomID, userAPI, tagContent); err != nil { util.GetLogger(req.Context()).WithError(err).Error("saveTagData failed") return jsonerror.InternalServerError() } - // Send data to syncProducer in order to inform clients of changes - // Run in a goroutine in order to prevent blocking the tag request response - go func() { - if err := syncProducer.SendData(userID, roomID, "m.tag"); err != nil { - logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi") - } - }() + // TODO: user API should do this since it's account data + if err := syncProducer.SendData(userID, roomID, "m.tag"); err != nil { + logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi") + } return util.JSONResponse{ Code: http.StatusOK, @@ -210,32 +167,46 @@ func obtainSavedTags( req *http.Request, userID string, roomID string, - accountDB accounts.Database, -) (string, *gomatrixserverlib.ClientEvent, error) { - localpart, _, err := gomatrixserverlib.SplitID('@', userID) - if err != nil { - return "", nil, err + userAPI api.UserInternalAPI, +) (tags gomatrix.TagContent, err error) { + dataReq := api.QueryAccountDataRequest{ + UserID: userID, + RoomID: roomID, + DataType: "m.tag", } - - data, err := accountDB.GetAccountDataByType( - req.Context(), localpart, roomID, "m.tag", - ) - - return localpart, data, err + dataRes := api.QueryAccountDataResponse{} + err = userAPI.QueryAccountData(req.Context(), &dataReq, &dataRes) + if err != nil { + return + } + data, ok := dataRes.RoomAccountData[roomID]["m.tag"] + if !ok { + return + } + if err = json.Unmarshal(data, &tags); err != nil { + return + } + return tags, nil } // saveTagData saves the provided tag data into the database func saveTagData( req *http.Request, - localpart string, + userID string, roomID string, - accountDB accounts.Database, + userAPI api.UserInternalAPI, Tag gomatrix.TagContent, ) error { newTagData, err := json.Marshal(Tag) if err != nil { return err } - - return accountDB.SaveAccountData(req.Context(), localpart, roomID, "m.tag", string(newTagData)) + dataReq := api.InputAccountDataRequest{ + UserID: userID, + RoomID: roomID, + DataType: "m.tag", + AccountData: json.RawMessage(newTagData), + } + dataRes := api.InputAccountDataResponse{} + return userAPI.InputAccountData(req.Context(), &dataReq, &dataRes) } diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 41c7fb18..e91b07ac 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -476,7 +476,7 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return SaveAccountData(req, accountDB, device, vars["userID"], "", vars["type"], syncProducer) + return SaveAccountData(req, userAPI, device, vars["userID"], "", vars["type"], syncProducer) }), ).Methods(http.MethodPut, http.MethodOptions) @@ -486,7 +486,7 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return SaveAccountData(req, accountDB, device, vars["userID"], vars["roomID"], vars["type"], syncProducer) + return SaveAccountData(req, userAPI, device, vars["userID"], vars["roomID"], vars["type"], syncProducer) }), ).Methods(http.MethodPut, http.MethodOptions) @@ -496,7 +496,7 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return GetAccountData(req, accountDB, device, vars["userID"], "", vars["type"]) + return GetAccountData(req, userAPI, device, vars["userID"], "", vars["type"]) }), ).Methods(http.MethodGet) @@ -506,7 +506,7 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return GetAccountData(req, accountDB, device, vars["userID"], vars["roomID"], vars["type"]) + return GetAccountData(req, userAPI, device, vars["userID"], vars["roomID"], vars["type"]) }), ).Methods(http.MethodGet) @@ -604,7 +604,7 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return GetTags(req, accountDB, device, vars["userId"], vars["roomId"], syncProducer) + return GetTags(req, userAPI, device, vars["userId"], vars["roomId"], syncProducer) }), ).Methods(http.MethodGet, http.MethodOptions) @@ -614,7 +614,7 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return PutTag(req, accountDB, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer) + return PutTag(req, userAPI, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer) }), ).Methods(http.MethodPut, http.MethodOptions) @@ -624,7 +624,7 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return DeleteTag(req, accountDB, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer) + return DeleteTag(req, userAPI, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer) }), ).Methods(http.MethodDelete, http.MethodOptions) diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index 26b925ea..8d51689e 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -205,22 +205,34 @@ func (rp *RequestPool) appendAccountData( if req.since == nil { // If this is the initial sync, we don't need to check if a data has // already been sent. Instead, we send the whole batch. - var res userapi.QueryAccountDataResponse - err := rp.userAPI.QueryAccountData(req.ctx, &userapi.QueryAccountDataRequest{ + dataReq := &userapi.QueryAccountDataRequest{ UserID: userID, - }, &res) - if err != nil { + } + dataRes := &userapi.QueryAccountDataResponse{} + if err := rp.userAPI.QueryAccountData(req.ctx, dataReq, dataRes); err != nil { return nil, err } - data.AccountData.Events = res.GlobalAccountData - + for datatype, databody := range dataRes.GlobalAccountData { + data.AccountData.Events = append( + data.AccountData.Events, + gomatrixserverlib.ClientEvent{ + Type: datatype, + Content: gomatrixserverlib.RawJSON(databody), + }, + ) + } for r, j := range data.Rooms.Join { - if len(res.RoomAccountData[r]) > 0 { - j.AccountData.Events = res.RoomAccountData[r] + for datatype, databody := range dataRes.RoomAccountData[r] { + j.AccountData.Events = append( + j.AccountData.Events, + gomatrixserverlib.ClientEvent{ + Type: datatype, + Content: gomatrixserverlib.RawJSON(databody), + }, + ) data.Rooms.Join[r] = j } } - return data, nil } @@ -249,33 +261,42 @@ func (rp *RequestPool) appendAccountData( // Iterate over the rooms for roomID, dataTypes := range dataTypes { - events := []gomatrixserverlib.ClientEvent{} // Request the missing data from the database for _, dataType := range dataTypes { - var res userapi.QueryAccountDataResponse - err = rp.userAPI.QueryAccountData(req.ctx, &userapi.QueryAccountDataRequest{ + dataReq := userapi.QueryAccountDataRequest{ UserID: userID, RoomID: roomID, DataType: dataType, - }, &res) + } + dataRes := userapi.QueryAccountDataResponse{} + err = rp.userAPI.QueryAccountData(req.ctx, &dataReq, &dataRes) if err != nil { - return nil, err + continue } - if len(res.RoomAccountData[roomID]) > 0 { - events = append(events, res.RoomAccountData[roomID]...) - } else if len(res.GlobalAccountData) > 0 { - events = append(events, res.GlobalAccountData...) + if roomID == "" { + if globalData, ok := dataRes.GlobalAccountData[dataType]; ok { + data.AccountData.Events = append( + data.AccountData.Events, + gomatrixserverlib.ClientEvent{ + Type: dataType, + Content: gomatrixserverlib.RawJSON(globalData), + }, + ) + } + } else { + if roomData, ok := dataRes.RoomAccountData[roomID][dataType]; ok { + joinData := data.Rooms.Join[roomID] + joinData.AccountData.Events = append( + joinData.AccountData.Events, + gomatrixserverlib.ClientEvent{ + Type: dataType, + Content: gomatrixserverlib.RawJSON(roomData), + }, + ) + data.Rooms.Join[roomID] = joinData + } } } - - // Append the data to the response - if len(roomID) > 0 { - jr := data.Rooms.Join[roomID] - jr.AccountData.Events = events - data.Rooms.Join[roomID] = jr - } else { - data.AccountData.Events = events - } } return data, nil diff --git a/userapi/api/api.go b/userapi/api/api.go index c953a5ba..a80adf2d 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -16,12 +16,14 @@ package api import ( "context" + "encoding/json" "github.com/matrix-org/gomatrixserverlib" ) // UserInternalAPI is the internal API for information about users and devices. type UserInternalAPI interface { + InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error @@ -30,6 +32,18 @@ type UserInternalAPI interface { QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error } +// InputAccountDataRequest is the request for InputAccountData +type InputAccountDataRequest struct { + UserID string // required: the user to set account data for + RoomID string // optional: the room to associate the account data with + DataType string // optional: the data type of the data + AccountData json.RawMessage // required: the message content +} + +// InputAccountDataResponse is the response for InputAccountData +type InputAccountDataResponse struct { +} + // QueryAccessTokenRequest is the request for QueryAccessToken type QueryAccessTokenRequest struct { AccessToken string @@ -46,18 +60,15 @@ type QueryAccessTokenResponse struct { // QueryAccountDataRequest is the request for QueryAccountData type QueryAccountDataRequest struct { - UserID string // required: the user to get account data for. - // TODO: This is a terribly confusing API shape :/ - DataType string // optional: if specified returns only a single event matching this data type. - // optional: Only used if DataType is set. If blank returns global account data matching the data type. - // If set, returns only room account data matching this data type. - RoomID string + UserID string // required: the user to get account data for. + RoomID string // optional: the room ID, or global account data if not specified. + DataType string // optional: the data type, or all types if not specified. } // QueryAccountDataResponse is the response for QueryAccountData type QueryAccountDataResponse struct { - GlobalAccountData []gomatrixserverlib.ClientEvent - RoomAccountData map[string][]gomatrixserverlib.ClientEvent + GlobalAccountData map[string]json.RawMessage // type -> data + RoomAccountData map[string]map[string]json.RawMessage // room -> type -> data } // QueryDevicesRequest is the request for QueryDevices diff --git a/userapi/internal/api.go b/userapi/internal/api.go index ae021f57..b081eca4 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -17,6 +17,7 @@ package internal import ( "context" "database/sql" + "encoding/json" "errors" "fmt" @@ -38,6 +39,20 @@ type UserInternalAPI struct { AppServices []config.ApplicationService } +func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error { + local, domain, err := gomatrixserverlib.SplitID('@', req.UserID) + if err != nil { + return err + } + if domain != a.ServerName { + return fmt.Errorf("cannot query profile of remote users: got %s want %s", domain, a.ServerName) + } + if req.DataType == "" { + return fmt.Errorf("data type must not be empty") + } + return a.AccountDB.SaveAccountData(ctx, local, req.RoomID, req.DataType, req.AccountData) +} + func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error { if req.AccountType == api.AccountTypeGuest { acc, err := a.AccountDB.CreateGuestAccount(ctx) @@ -130,17 +145,21 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc return fmt.Errorf("cannot query account data of remote users: got %s want %s", domain, a.ServerName) } if req.DataType != "" { - var event *gomatrixserverlib.ClientEvent - event, err = a.AccountDB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType) + var data json.RawMessage + data, err = a.AccountDB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType) if err != nil { return err } - if event != nil { + res.RoomAccountData = make(map[string]map[string]json.RawMessage) + res.GlobalAccountData = make(map[string]json.RawMessage) + if data != nil { if req.RoomID != "" { - res.RoomAccountData = make(map[string][]gomatrixserverlib.ClientEvent) - res.RoomAccountData[req.RoomID] = []gomatrixserverlib.ClientEvent{*event} + if _, ok := res.RoomAccountData[req.RoomID]; !ok { + res.RoomAccountData[req.RoomID] = make(map[string]json.RawMessage) + } + res.RoomAccountData[req.RoomID][req.DataType] = data } else { - res.GlobalAccountData = append(res.GlobalAccountData, *event) + res.GlobalAccountData[req.DataType] = data } } return nil diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go index 0e9628c5..4ab0d690 100644 --- a/userapi/inthttp/client.go +++ b/userapi/inthttp/client.go @@ -26,6 +26,8 @@ import ( // HTTP paths for the internal HTTP APIs const ( + InputAccountDataPath = "/userapi/inputAccountData" + PerformDeviceCreationPath = "/userapi/performDeviceCreation" PerformAccountCreationPath = "/userapi/performAccountCreation" @@ -55,6 +57,14 @@ type httpUserInternalAPI struct { httpClient *http.Client } +func (h *httpUserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "InputAccountData") + defer span.Finish() + + apiURL := h.apiURL + InputAccountDataPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) +} + func (h *httpUserInternalAPI) PerformAccountCreation( ctx context.Context, request *api.PerformAccountCreationRequest, diff --git a/userapi/storage/accounts/interface.go b/userapi/storage/accounts/interface.go index 13e3e289..c6692879 100644 --- a/userapi/storage/accounts/interface.go +++ b/userapi/storage/accounts/interface.go @@ -16,6 +16,7 @@ package accounts import ( "context" + "encoding/json" "errors" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" @@ -39,13 +40,13 @@ type Database interface { GetMembershipInRoomByLocalpart(ctx context.Context, localpart, roomID string) (authtypes.Membership, error) GetRoomIDsByLocalPart(ctx context.Context, localpart string) ([]string, error) GetMembershipsByLocalpart(ctx context.Context, localpart string) (memberships []authtypes.Membership, err error) - SaveAccountData(ctx context.Context, localpart, roomID, dataType, content string) error - GetAccountData(ctx context.Context, localpart string) (global []gomatrixserverlib.ClientEvent, rooms map[string][]gomatrixserverlib.ClientEvent, err error) + SaveAccountData(ctx context.Context, localpart, roomID, dataType string, content json.RawMessage) error + GetAccountData(ctx context.Context, localpart string) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error) // GetAccountDataByType returns account data matching a given // localpart, room ID and type. // If no account data could be found, returns nil // Returns an error if there was an issue with the retrieval - GetAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data *gomatrixserverlib.ClientEvent, err error) + GetAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data json.RawMessage, err error) GetNewNumericLocalpart(ctx context.Context) (int64, error) SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error) RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error) diff --git a/userapi/storage/accounts/postgres/account_data_table.go b/userapi/storage/accounts/postgres/account_data_table.go index 2f16c5c0..90c79e87 100644 --- a/userapi/storage/accounts/postgres/account_data_table.go +++ b/userapi/storage/accounts/postgres/account_data_table.go @@ -17,9 +17,9 @@ package postgres import ( "context" "database/sql" + "encoding/json" "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/gomatrixserverlib" ) const accountDataSchema = ` @@ -73,7 +73,7 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) { } func (s *accountDataStatements) insertAccountData( - ctx context.Context, txn *sql.Tx, localpart, roomID, dataType, content string, + ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage, ) (err error) { stmt := txn.Stmt(s.insertAccountDataStmt) _, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content) @@ -83,18 +83,18 @@ func (s *accountDataStatements) insertAccountData( func (s *accountDataStatements) selectAccountData( ctx context.Context, localpart string, ) ( - global []gomatrixserverlib.ClientEvent, - rooms map[string][]gomatrixserverlib.ClientEvent, - err error, + /* global */ map[string]json.RawMessage, + /* rooms */ map[string]map[string]json.RawMessage, + error, ) { rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart) if err != nil { - return + return nil, nil, err } defer internal.CloseAndLogIfError(ctx, rows, "selectAccountData: rows.close() failed") - global = []gomatrixserverlib.ClientEvent{} - rooms = make(map[string][]gomatrixserverlib.ClientEvent) + global := map[string]json.RawMessage{} + rooms := map[string]map[string]json.RawMessage{} for rows.Next() { var roomID string @@ -102,41 +102,33 @@ func (s *accountDataStatements) selectAccountData( var content []byte if err = rows.Scan(&roomID, &dataType, &content); err != nil { - return + return nil, nil, err } - ac := gomatrixserverlib.ClientEvent{ - Type: dataType, - Content: content, - } - - if len(roomID) > 0 { - rooms[roomID] = append(rooms[roomID], ac) + if roomID != "" { + if _, ok := rooms[roomID]; !ok { + rooms[roomID] = map[string]json.RawMessage{} + } + rooms[roomID][dataType] = content } else { - global = append(global, ac) + global[dataType] = content } } + return global, rooms, rows.Err() } func (s *accountDataStatements) selectAccountDataByType( ctx context.Context, localpart, roomID, dataType string, -) (data *gomatrixserverlib.ClientEvent, err error) { +) (data json.RawMessage, err error) { + var bytes []byte stmt := s.selectAccountDataByTypeStmt - var content []byte - - if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&content); err != nil { + if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&bytes); err != nil { if err == sql.ErrNoRows { return nil, nil } - return } - - data = &gomatrixserverlib.ClientEvent{ - Type: dataType, - Content: content, - } - + data = json.RawMessage(bytes) return } diff --git a/userapi/storage/accounts/postgres/storage.go b/userapi/storage/accounts/postgres/storage.go index 2b88cb70..e5509980 100644 --- a/userapi/storage/accounts/postgres/storage.go +++ b/userapi/storage/accounts/postgres/storage.go @@ -17,6 +17,7 @@ package postgres import ( "context" "database/sql" + "encoding/json" "errors" "strconv" @@ -169,7 +170,7 @@ func (d *Database) createAccount( return nil, err } - if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", `{ + if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{ "global": { "content": [], "override": [], @@ -177,7 +178,7 @@ func (d *Database) createAccount( "sender": [], "underride": [] } - }`); err != nil { + }`)); err != nil { return nil, err } return d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID) @@ -295,7 +296,7 @@ func (d *Database) newMembership( // update the corresponding row with the new content // Returns a SQL error if there was an issue with the insertion/update func (d *Database) SaveAccountData( - ctx context.Context, localpart, roomID, dataType, content string, + ctx context.Context, localpart, roomID, dataType string, content json.RawMessage, ) error { return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content) @@ -306,8 +307,8 @@ func (d *Database) SaveAccountData( // If no account data could be found, returns an empty arrays // Returns an error if there was an issue with the retrieval func (d *Database) GetAccountData(ctx context.Context, localpart string) ( - global []gomatrixserverlib.ClientEvent, - rooms map[string][]gomatrixserverlib.ClientEvent, + global map[string]json.RawMessage, + rooms map[string]map[string]json.RawMessage, err error, ) { return d.accountDatas.selectAccountData(ctx, localpart) @@ -319,7 +320,7 @@ func (d *Database) GetAccountData(ctx context.Context, localpart string) ( // Returns an error if there was an issue with the retrieval func (d *Database) GetAccountDataByType( ctx context.Context, localpart, roomID, dataType string, -) (data *gomatrixserverlib.ClientEvent, err error) { +) (data json.RawMessage, err error) { return d.accountDatas.selectAccountDataByType( ctx, localpart, roomID, dataType, ) diff --git a/userapi/storage/accounts/sqlite3/account_data_table.go b/userapi/storage/accounts/sqlite3/account_data_table.go index b6bb6361..d048dbd1 100644 --- a/userapi/storage/accounts/sqlite3/account_data_table.go +++ b/userapi/storage/accounts/sqlite3/account_data_table.go @@ -17,8 +17,7 @@ package sqlite3 import ( "context" "database/sql" - - "github.com/matrix-org/gomatrixserverlib" + "encoding/json" ) const accountDataSchema = ` @@ -72,7 +71,7 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) { } func (s *accountDataStatements) insertAccountData( - ctx context.Context, txn *sql.Tx, localpart, roomID, dataType, content string, + ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage, ) (err error) { _, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content) return @@ -81,17 +80,17 @@ func (s *accountDataStatements) insertAccountData( func (s *accountDataStatements) selectAccountData( ctx context.Context, localpart string, ) ( - global []gomatrixserverlib.ClientEvent, - rooms map[string][]gomatrixserverlib.ClientEvent, - err error, + /* global */ map[string]json.RawMessage, + /* rooms */ map[string]map[string]json.RawMessage, + error, ) { rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart) if err != nil { - return + return nil, nil, err } - global = []gomatrixserverlib.ClientEvent{} - rooms = make(map[string][]gomatrixserverlib.ClientEvent) + global := map[string]json.RawMessage{} + rooms := map[string]map[string]json.RawMessage{} for rows.Next() { var roomID string @@ -99,42 +98,33 @@ func (s *accountDataStatements) selectAccountData( var content []byte if err = rows.Scan(&roomID, &dataType, &content); err != nil { - return + return nil, nil, err } - ac := gomatrixserverlib.ClientEvent{ - Type: dataType, - Content: content, - } - - if len(roomID) > 0 { - rooms[roomID] = append(rooms[roomID], ac) + if roomID != "" { + if _, ok := rooms[roomID]; !ok { + rooms[roomID] = map[string]json.RawMessage{} + } + rooms[roomID][dataType] = content } else { - global = append(global, ac) + global[dataType] = content } } - return + return global, rooms, nil } func (s *accountDataStatements) selectAccountDataByType( ctx context.Context, localpart, roomID, dataType string, -) (data *gomatrixserverlib.ClientEvent, err error) { +) (data json.RawMessage, err error) { + var bytes []byte stmt := s.selectAccountDataByTypeStmt - var content []byte - - if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&content); err != nil { + if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&bytes); err != nil { if err == sql.ErrNoRows { return nil, nil } - return } - - data = &gomatrixserverlib.ClientEvent{ - Type: dataType, - Content: content, - } - + data = json.RawMessage(bytes) return } diff --git a/userapi/storage/accounts/sqlite3/storage.go b/userapi/storage/accounts/sqlite3/storage.go index 4dd755a7..dbf6606c 100644 --- a/userapi/storage/accounts/sqlite3/storage.go +++ b/userapi/storage/accounts/sqlite3/storage.go @@ -17,6 +17,7 @@ package sqlite3 import ( "context" "database/sql" + "encoding/json" "errors" "strconv" "sync" @@ -180,7 +181,7 @@ func (d *Database) createAccount( return nil, err } - if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", `{ + if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{ "global": { "content": [], "override": [], @@ -188,7 +189,7 @@ func (d *Database) createAccount( "sender": [], "underride": [] } - }`); err != nil { + }`)); err != nil { return nil, err } return d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID) @@ -306,7 +307,7 @@ func (d *Database) newMembership( // update the corresponding row with the new content // Returns a SQL error if there was an issue with the insertion/update func (d *Database) SaveAccountData( - ctx context.Context, localpart, roomID, dataType, content string, + ctx context.Context, localpart, roomID, dataType string, content json.RawMessage, ) error { return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content) @@ -317,8 +318,8 @@ func (d *Database) SaveAccountData( // If no account data could be found, returns an empty arrays // Returns an error if there was an issue with the retrieval func (d *Database) GetAccountData(ctx context.Context, localpart string) ( - global []gomatrixserverlib.ClientEvent, - rooms map[string][]gomatrixserverlib.ClientEvent, + global map[string]json.RawMessage, + rooms map[string]map[string]json.RawMessage, err error, ) { return d.accountDatas.selectAccountData(ctx, localpart) @@ -330,7 +331,7 @@ func (d *Database) GetAccountData(ctx context.Context, localpart string) ( // Returns an error if there was an issue with the retrieval func (d *Database) GetAccountDataByType( ctx context.Context, localpart, roomID, dataType string, -) (data *gomatrixserverlib.ClientEvent, err error) { +) (data json.RawMessage, err error) { return d.accountDatas.selectAccountDataByType( ctx, localpart, roomID, dataType, )