Refactor account data (#1150)

* Refactor account data

* Tweak database fetching

* Tweaks

* Restore syncProducer notification

* Various tweaks, update tag behaviour

* Fix initial sync
This commit is contained in:
Neil Alexander 2020-06-18 18:36:03 +01:00 committed by GitHub
parent 3547a1768c
commit dc0bac85d5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 248 additions and 222 deletions

View file

@ -16,21 +16,20 @@ package routing
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/producers"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
// GetAccountData implements GET /user/{userId}/[rooms/{roomid}/]account_data/{type}
func GetAccountData(
req *http.Request, accountDB accounts.Database, device *api.Device,
req *http.Request, userAPI api.UserInternalAPI, device *api.Device,
userID string, roomID string, dataType string,
) util.JSONResponse {
if userID != device.UserID {
@ -40,18 +39,28 @@ func GetAccountData(
}
}
localpart, _, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
dataReq := api.QueryAccountDataRequest{
UserID: userID,
DataType: dataType,
RoomID: roomID,
}
dataRes := api.QueryAccountDataResponse{}
if err := userAPI.QueryAccountData(req.Context(), &dataReq, &dataRes); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryAccountData failed")
return util.ErrorResponse(fmt.Errorf("userAPI.QueryAccountData: %w", err))
}
if data, err := accountDB.GetAccountDataByType(
req.Context(), localpart, roomID, dataType,
); err == nil {
var data json.RawMessage
var ok bool
if roomID != "" {
data, ok = dataRes.RoomAccountData[roomID][dataType]
} else {
data, ok = dataRes.GlobalAccountData[dataType]
}
if ok {
return util.JSONResponse{
Code: http.StatusOK,
JSON: data.Content,
JSON: data,
}
}
@ -63,7 +72,7 @@ func GetAccountData(
// SaveAccountData implements PUT /user/{userId}/[rooms/{roomId}/]account_data/{type}
func SaveAccountData(
req *http.Request, accountDB accounts.Database, device *api.Device,
req *http.Request, userAPI api.UserInternalAPI, device *api.Device,
userID string, roomID string, dataType string, syncProducer *producers.SyncAPIProducer,
) util.JSONResponse {
if userID != device.UserID {
@ -73,12 +82,6 @@ func SaveAccountData(
}
}
localpart, _, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
}
defer req.Body.Close() // nolint: errcheck
if req.Body == http.NoBody {
@ -101,13 +104,19 @@ func SaveAccountData(
}
}
if err := accountDB.SaveAccountData(
req.Context(), localpart, roomID, dataType, string(body),
); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("accountDB.SaveAccountData failed")
return jsonerror.InternalServerError()
dataReq := api.InputAccountDataRequest{
UserID: userID,
DataType: dataType,
RoomID: roomID,
AccountData: json.RawMessage(body),
}
dataRes := api.InputAccountDataResponse{}
if err := userAPI.InputAccountData(req.Context(), &dataReq, &dataRes); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryAccountData failed")
return util.ErrorResponse(err)
}
// TODO: user API should do this since it's account data
if err := syncProducer.SendData(userID, roomID, dataType); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("syncProducer.SendData failed")
return jsonerror.InternalServerError()

View file

@ -24,23 +24,14 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/producers"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
"github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
// newTag creates and returns a new gomatrix.TagContent
func newTag() gomatrix.TagContent {
return gomatrix.TagContent{
Tags: make(map[string]gomatrix.TagProperties),
}
}
// GetTags implements GET /_matrix/client/r0/user/{userID}/rooms/{roomID}/tags
func GetTags(
req *http.Request,
accountDB accounts.Database,
userAPI api.UserInternalAPI,
device *api.Device,
userID string,
roomID string,
@ -54,22 +45,15 @@ func GetTags(
}
}
_, data, err := obtainSavedTags(req, userID, roomID, accountDB)
tagContent, err := obtainSavedTags(req, userID, roomID, userAPI)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("obtainSavedTags failed")
return jsonerror.InternalServerError()
}
if data == nil {
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
}
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: data.Content,
JSON: tagContent,
}
}
@ -78,7 +62,7 @@ func GetTags(
// the tag to the "map" and saving the new "map" to the DB
func PutTag(
req *http.Request,
accountDB accounts.Database,
userAPI api.UserInternalAPI,
device *api.Device,
userID string,
roomID string,
@ -98,34 +82,25 @@ func PutTag(
return *reqErr
}
localpart, data, err := obtainSavedTags(req, userID, roomID, accountDB)
tagContent, err := obtainSavedTags(req, userID, roomID, userAPI)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("obtainSavedTags failed")
return jsonerror.InternalServerError()
}
var tagContent gomatrix.TagContent
if data != nil {
if err = json.Unmarshal(data.Content, &tagContent); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("json.Unmarshal failed")
return jsonerror.InternalServerError()
}
} else {
tagContent = newTag()
if tagContent.Tags == nil {
tagContent.Tags = make(map[string]gomatrix.TagProperties)
}
tagContent.Tags[tag] = properties
if err = saveTagData(req, localpart, roomID, accountDB, tagContent); err != nil {
if err = saveTagData(req, userID, roomID, userAPI, tagContent); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("saveTagData failed")
return jsonerror.InternalServerError()
}
// Send data to syncProducer in order to inform clients of changes
// Run in a goroutine in order to prevent blocking the tag request response
go func() {
if err := syncProducer.SendData(userID, roomID, "m.tag"); err != nil {
logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi")
}
}()
if err = syncProducer.SendData(userID, roomID, "m.tag"); err != nil {
logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi")
}
return util.JSONResponse{
Code: http.StatusOK,
@ -138,7 +113,7 @@ func PutTag(
// the "map" and then saving the new "map" in the DB
func DeleteTag(
req *http.Request,
accountDB accounts.Database,
userAPI api.UserInternalAPI,
device *api.Device,
userID string,
roomID string,
@ -153,28 +128,12 @@ func DeleteTag(
}
}
localpart, data, err := obtainSavedTags(req, userID, roomID, accountDB)
tagContent, err := obtainSavedTags(req, userID, roomID, userAPI)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("obtainSavedTags failed")
return jsonerror.InternalServerError()
}
// If there are no tags in the database, exit
if data == nil {
// Spec only defines 200 responses for this endpoint so we don't return anything else.
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
}
}
var tagContent gomatrix.TagContent
err = json.Unmarshal(data.Content, &tagContent)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("json.Unmarshal failed")
return jsonerror.InternalServerError()
}
// Check whether the tag to be deleted exists
if _, ok := tagContent.Tags[tag]; ok {
delete(tagContent.Tags, tag)
@ -185,18 +144,16 @@ func DeleteTag(
JSON: struct{}{},
}
}
if err = saveTagData(req, localpart, roomID, accountDB, tagContent); err != nil {
if err = saveTagData(req, userID, roomID, userAPI, tagContent); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("saveTagData failed")
return jsonerror.InternalServerError()
}
// Send data to syncProducer in order to inform clients of changes
// Run in a goroutine in order to prevent blocking the tag request response
go func() {
if err := syncProducer.SendData(userID, roomID, "m.tag"); err != nil {
logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi")
}
}()
// TODO: user API should do this since it's account data
if err := syncProducer.SendData(userID, roomID, "m.tag"); err != nil {
logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi")
}
return util.JSONResponse{
Code: http.StatusOK,
@ -210,32 +167,46 @@ func obtainSavedTags(
req *http.Request,
userID string,
roomID string,
accountDB accounts.Database,
) (string, *gomatrixserverlib.ClientEvent, error) {
localpart, _, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
return "", nil, err
userAPI api.UserInternalAPI,
) (tags gomatrix.TagContent, err error) {
dataReq := api.QueryAccountDataRequest{
UserID: userID,
RoomID: roomID,
DataType: "m.tag",
}
data, err := accountDB.GetAccountDataByType(
req.Context(), localpart, roomID, "m.tag",
)
return localpart, data, err
dataRes := api.QueryAccountDataResponse{}
err = userAPI.QueryAccountData(req.Context(), &dataReq, &dataRes)
if err != nil {
return
}
data, ok := dataRes.RoomAccountData[roomID]["m.tag"]
if !ok {
return
}
if err = json.Unmarshal(data, &tags); err != nil {
return
}
return tags, nil
}
// saveTagData saves the provided tag data into the database
func saveTagData(
req *http.Request,
localpart string,
userID string,
roomID string,
accountDB accounts.Database,
userAPI api.UserInternalAPI,
Tag gomatrix.TagContent,
) error {
newTagData, err := json.Marshal(Tag)
if err != nil {
return err
}
return accountDB.SaveAccountData(req.Context(), localpart, roomID, "m.tag", string(newTagData))
dataReq := api.InputAccountDataRequest{
UserID: userID,
RoomID: roomID,
DataType: "m.tag",
AccountData: json.RawMessage(newTagData),
}
dataRes := api.InputAccountDataResponse{}
return userAPI.InputAccountData(req.Context(), &dataReq, &dataRes)
}

View file

@ -476,7 +476,7 @@ func Setup(
if err != nil {
return util.ErrorResponse(err)
}
return SaveAccountData(req, accountDB, device, vars["userID"], "", vars["type"], syncProducer)
return SaveAccountData(req, userAPI, device, vars["userID"], "", vars["type"], syncProducer)
}),
).Methods(http.MethodPut, http.MethodOptions)
@ -486,7 +486,7 @@ func Setup(
if err != nil {
return util.ErrorResponse(err)
}
return SaveAccountData(req, accountDB, device, vars["userID"], vars["roomID"], vars["type"], syncProducer)
return SaveAccountData(req, userAPI, device, vars["userID"], vars["roomID"], vars["type"], syncProducer)
}),
).Methods(http.MethodPut, http.MethodOptions)
@ -496,7 +496,7 @@ func Setup(
if err != nil {
return util.ErrorResponse(err)
}
return GetAccountData(req, accountDB, device, vars["userID"], "", vars["type"])
return GetAccountData(req, userAPI, device, vars["userID"], "", vars["type"])
}),
).Methods(http.MethodGet)
@ -506,7 +506,7 @@ func Setup(
if err != nil {
return util.ErrorResponse(err)
}
return GetAccountData(req, accountDB, device, vars["userID"], vars["roomID"], vars["type"])
return GetAccountData(req, userAPI, device, vars["userID"], vars["roomID"], vars["type"])
}),
).Methods(http.MethodGet)
@ -604,7 +604,7 @@ func Setup(
if err != nil {
return util.ErrorResponse(err)
}
return GetTags(req, accountDB, device, vars["userId"], vars["roomId"], syncProducer)
return GetTags(req, userAPI, device, vars["userId"], vars["roomId"], syncProducer)
}),
).Methods(http.MethodGet, http.MethodOptions)
@ -614,7 +614,7 @@ func Setup(
if err != nil {
return util.ErrorResponse(err)
}
return PutTag(req, accountDB, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer)
return PutTag(req, userAPI, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer)
}),
).Methods(http.MethodPut, http.MethodOptions)
@ -624,7 +624,7 @@ func Setup(
if err != nil {
return util.ErrorResponse(err)
}
return DeleteTag(req, accountDB, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer)
return DeleteTag(req, userAPI, device, vars["userId"], vars["roomId"], vars["tag"], syncProducer)
}),
).Methods(http.MethodDelete, http.MethodOptions)