mirror of
https://github.com/hoernschen/dendrite.git
synced 2025-07-30 04:52:46 +00:00
Refactor account data (#1150)
* Refactor account data * Tweak database fetching * Tweaks * Restore syncProducer notification * Various tweaks, update tag behaviour * Fix initial sync
This commit is contained in:
parent
3547a1768c
commit
dc0bac85d5
12 changed files with 248 additions and 222 deletions
|
@ -16,12 +16,14 @@ package api
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
// UserInternalAPI is the internal API for information about users and devices.
|
||||
type UserInternalAPI interface {
|
||||
InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error
|
||||
PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error
|
||||
PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error
|
||||
QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error
|
||||
|
@ -30,6 +32,18 @@ type UserInternalAPI interface {
|
|||
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
|
||||
}
|
||||
|
||||
// InputAccountDataRequest is the request for InputAccountData
|
||||
type InputAccountDataRequest struct {
|
||||
UserID string // required: the user to set account data for
|
||||
RoomID string // optional: the room to associate the account data with
|
||||
DataType string // optional: the data type of the data
|
||||
AccountData json.RawMessage // required: the message content
|
||||
}
|
||||
|
||||
// InputAccountDataResponse is the response for InputAccountData
|
||||
type InputAccountDataResponse struct {
|
||||
}
|
||||
|
||||
// QueryAccessTokenRequest is the request for QueryAccessToken
|
||||
type QueryAccessTokenRequest struct {
|
||||
AccessToken string
|
||||
|
@ -46,18 +60,15 @@ type QueryAccessTokenResponse struct {
|
|||
|
||||
// QueryAccountDataRequest is the request for QueryAccountData
|
||||
type QueryAccountDataRequest struct {
|
||||
UserID string // required: the user to get account data for.
|
||||
// TODO: This is a terribly confusing API shape :/
|
||||
DataType string // optional: if specified returns only a single event matching this data type.
|
||||
// optional: Only used if DataType is set. If blank returns global account data matching the data type.
|
||||
// If set, returns only room account data matching this data type.
|
||||
RoomID string
|
||||
UserID string // required: the user to get account data for.
|
||||
RoomID string // optional: the room ID, or global account data if not specified.
|
||||
DataType string // optional: the data type, or all types if not specified.
|
||||
}
|
||||
|
||||
// QueryAccountDataResponse is the response for QueryAccountData
|
||||
type QueryAccountDataResponse struct {
|
||||
GlobalAccountData []gomatrixserverlib.ClientEvent
|
||||
RoomAccountData map[string][]gomatrixserverlib.ClientEvent
|
||||
GlobalAccountData map[string]json.RawMessage // type -> data
|
||||
RoomAccountData map[string]map[string]json.RawMessage // room -> type -> data
|
||||
}
|
||||
|
||||
// QueryDevicesRequest is the request for QueryDevices
|
||||
|
|
|
@ -17,6 +17,7 @@ package internal
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
|
@ -38,6 +39,20 @@ type UserInternalAPI struct {
|
|||
AppServices []config.ApplicationService
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error {
|
||||
local, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if domain != a.ServerName {
|
||||
return fmt.Errorf("cannot query profile of remote users: got %s want %s", domain, a.ServerName)
|
||||
}
|
||||
if req.DataType == "" {
|
||||
return fmt.Errorf("data type must not be empty")
|
||||
}
|
||||
return a.AccountDB.SaveAccountData(ctx, local, req.RoomID, req.DataType, req.AccountData)
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error {
|
||||
if req.AccountType == api.AccountTypeGuest {
|
||||
acc, err := a.AccountDB.CreateGuestAccount(ctx)
|
||||
|
@ -130,17 +145,21 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc
|
|||
return fmt.Errorf("cannot query account data of remote users: got %s want %s", domain, a.ServerName)
|
||||
}
|
||||
if req.DataType != "" {
|
||||
var event *gomatrixserverlib.ClientEvent
|
||||
event, err = a.AccountDB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType)
|
||||
var data json.RawMessage
|
||||
data, err = a.AccountDB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if event != nil {
|
||||
res.RoomAccountData = make(map[string]map[string]json.RawMessage)
|
||||
res.GlobalAccountData = make(map[string]json.RawMessage)
|
||||
if data != nil {
|
||||
if req.RoomID != "" {
|
||||
res.RoomAccountData = make(map[string][]gomatrixserverlib.ClientEvent)
|
||||
res.RoomAccountData[req.RoomID] = []gomatrixserverlib.ClientEvent{*event}
|
||||
if _, ok := res.RoomAccountData[req.RoomID]; !ok {
|
||||
res.RoomAccountData[req.RoomID] = make(map[string]json.RawMessage)
|
||||
}
|
||||
res.RoomAccountData[req.RoomID][req.DataType] = data
|
||||
} else {
|
||||
res.GlobalAccountData = append(res.GlobalAccountData, *event)
|
||||
res.GlobalAccountData[req.DataType] = data
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
|
|
@ -26,6 +26,8 @@ import (
|
|||
|
||||
// HTTP paths for the internal HTTP APIs
|
||||
const (
|
||||
InputAccountDataPath = "/userapi/inputAccountData"
|
||||
|
||||
PerformDeviceCreationPath = "/userapi/performDeviceCreation"
|
||||
PerformAccountCreationPath = "/userapi/performAccountCreation"
|
||||
|
||||
|
@ -55,6 +57,14 @@ type httpUserInternalAPI struct {
|
|||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func (h *httpUserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "InputAccountData")
|
||||
defer span.Finish()
|
||||
|
||||
apiURL := h.apiURL + InputAccountDataPath
|
||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||
}
|
||||
|
||||
func (h *httpUserInternalAPI) PerformAccountCreation(
|
||||
ctx context.Context,
|
||||
request *api.PerformAccountCreationRequest,
|
||||
|
|
|
@ -16,6 +16,7 @@ package accounts
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
|
@ -39,13 +40,13 @@ type Database interface {
|
|||
GetMembershipInRoomByLocalpart(ctx context.Context, localpart, roomID string) (authtypes.Membership, error)
|
||||
GetRoomIDsByLocalPart(ctx context.Context, localpart string) ([]string, error)
|
||||
GetMembershipsByLocalpart(ctx context.Context, localpart string) (memberships []authtypes.Membership, err error)
|
||||
SaveAccountData(ctx context.Context, localpart, roomID, dataType, content string) error
|
||||
GetAccountData(ctx context.Context, localpart string) (global []gomatrixserverlib.ClientEvent, rooms map[string][]gomatrixserverlib.ClientEvent, err error)
|
||||
SaveAccountData(ctx context.Context, localpart, roomID, dataType string, content json.RawMessage) error
|
||||
GetAccountData(ctx context.Context, localpart string) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error)
|
||||
// GetAccountDataByType returns account data matching a given
|
||||
// localpart, room ID and type.
|
||||
// If no account data could be found, returns nil
|
||||
// Returns an error if there was an issue with the retrieval
|
||||
GetAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data *gomatrixserverlib.ClientEvent, err error)
|
||||
GetAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data json.RawMessage, err error)
|
||||
GetNewNumericLocalpart(ctx context.Context) (int64, error)
|
||||
SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error)
|
||||
RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error)
|
||||
|
|
|
@ -17,9 +17,9 @@ package postgres
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const accountDataSchema = `
|
||||
|
@ -73,7 +73,7 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
|
|||
}
|
||||
|
||||
func (s *accountDataStatements) insertAccountData(
|
||||
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType, content string,
|
||||
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
|
||||
) (err error) {
|
||||
stmt := txn.Stmt(s.insertAccountDataStmt)
|
||||
_, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content)
|
||||
|
@ -83,18 +83,18 @@ func (s *accountDataStatements) insertAccountData(
|
|||
func (s *accountDataStatements) selectAccountData(
|
||||
ctx context.Context, localpart string,
|
||||
) (
|
||||
global []gomatrixserverlib.ClientEvent,
|
||||
rooms map[string][]gomatrixserverlib.ClientEvent,
|
||||
err error,
|
||||
/* global */ map[string]json.RawMessage,
|
||||
/* rooms */ map[string]map[string]json.RawMessage,
|
||||
error,
|
||||
) {
|
||||
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart)
|
||||
if err != nil {
|
||||
return
|
||||
return nil, nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectAccountData: rows.close() failed")
|
||||
|
||||
global = []gomatrixserverlib.ClientEvent{}
|
||||
rooms = make(map[string][]gomatrixserverlib.ClientEvent)
|
||||
global := map[string]json.RawMessage{}
|
||||
rooms := map[string]map[string]json.RawMessage{}
|
||||
|
||||
for rows.Next() {
|
||||
var roomID string
|
||||
|
@ -102,41 +102,33 @@ func (s *accountDataStatements) selectAccountData(
|
|||
var content []byte
|
||||
|
||||
if err = rows.Scan(&roomID, &dataType, &content); err != nil {
|
||||
return
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
ac := gomatrixserverlib.ClientEvent{
|
||||
Type: dataType,
|
||||
Content: content,
|
||||
}
|
||||
|
||||
if len(roomID) > 0 {
|
||||
rooms[roomID] = append(rooms[roomID], ac)
|
||||
if roomID != "" {
|
||||
if _, ok := rooms[roomID]; !ok {
|
||||
rooms[roomID] = map[string]json.RawMessage{}
|
||||
}
|
||||
rooms[roomID][dataType] = content
|
||||
} else {
|
||||
global = append(global, ac)
|
||||
global[dataType] = content
|
||||
}
|
||||
}
|
||||
|
||||
return global, rooms, rows.Err()
|
||||
}
|
||||
|
||||
func (s *accountDataStatements) selectAccountDataByType(
|
||||
ctx context.Context, localpart, roomID, dataType string,
|
||||
) (data *gomatrixserverlib.ClientEvent, err error) {
|
||||
) (data json.RawMessage, err error) {
|
||||
var bytes []byte
|
||||
stmt := s.selectAccountDataByTypeStmt
|
||||
var content []byte
|
||||
|
||||
if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&content); err != nil {
|
||||
if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&bytes); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
data = &gomatrixserverlib.ClientEvent{
|
||||
Type: dataType,
|
||||
Content: content,
|
||||
}
|
||||
|
||||
data = json.RawMessage(bytes)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -17,6 +17,7 @@ package postgres
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strconv"
|
||||
|
||||
|
@ -169,7 +170,7 @@ func (d *Database) createAccount(
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", `{
|
||||
if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{
|
||||
"global": {
|
||||
"content": [],
|
||||
"override": [],
|
||||
|
@ -177,7 +178,7 @@ func (d *Database) createAccount(
|
|||
"sender": [],
|
||||
"underride": []
|
||||
}
|
||||
}`); err != nil {
|
||||
}`)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID)
|
||||
|
@ -295,7 +296,7 @@ func (d *Database) newMembership(
|
|||
// update the corresponding row with the new content
|
||||
// Returns a SQL error if there was an issue with the insertion/update
|
||||
func (d *Database) SaveAccountData(
|
||||
ctx context.Context, localpart, roomID, dataType, content string,
|
||||
ctx context.Context, localpart, roomID, dataType string, content json.RawMessage,
|
||||
) error {
|
||||
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||
return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content)
|
||||
|
@ -306,8 +307,8 @@ func (d *Database) SaveAccountData(
|
|||
// If no account data could be found, returns an empty arrays
|
||||
// Returns an error if there was an issue with the retrieval
|
||||
func (d *Database) GetAccountData(ctx context.Context, localpart string) (
|
||||
global []gomatrixserverlib.ClientEvent,
|
||||
rooms map[string][]gomatrixserverlib.ClientEvent,
|
||||
global map[string]json.RawMessage,
|
||||
rooms map[string]map[string]json.RawMessage,
|
||||
err error,
|
||||
) {
|
||||
return d.accountDatas.selectAccountData(ctx, localpart)
|
||||
|
@ -319,7 +320,7 @@ func (d *Database) GetAccountData(ctx context.Context, localpart string) (
|
|||
// Returns an error if there was an issue with the retrieval
|
||||
func (d *Database) GetAccountDataByType(
|
||||
ctx context.Context, localpart, roomID, dataType string,
|
||||
) (data *gomatrixserverlib.ClientEvent, err error) {
|
||||
) (data json.RawMessage, err error) {
|
||||
return d.accountDatas.selectAccountDataByType(
|
||||
ctx, localpart, roomID, dataType,
|
||||
)
|
||||
|
|
|
@ -17,8 +17,7 @@ package sqlite3
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
const accountDataSchema = `
|
||||
|
@ -72,7 +71,7 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
|
|||
}
|
||||
|
||||
func (s *accountDataStatements) insertAccountData(
|
||||
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType, content string,
|
||||
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
|
||||
) (err error) {
|
||||
_, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content)
|
||||
return
|
||||
|
@ -81,17 +80,17 @@ func (s *accountDataStatements) insertAccountData(
|
|||
func (s *accountDataStatements) selectAccountData(
|
||||
ctx context.Context, localpart string,
|
||||
) (
|
||||
global []gomatrixserverlib.ClientEvent,
|
||||
rooms map[string][]gomatrixserverlib.ClientEvent,
|
||||
err error,
|
||||
/* global */ map[string]json.RawMessage,
|
||||
/* rooms */ map[string]map[string]json.RawMessage,
|
||||
error,
|
||||
) {
|
||||
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart)
|
||||
if err != nil {
|
||||
return
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
global = []gomatrixserverlib.ClientEvent{}
|
||||
rooms = make(map[string][]gomatrixserverlib.ClientEvent)
|
||||
global := map[string]json.RawMessage{}
|
||||
rooms := map[string]map[string]json.RawMessage{}
|
||||
|
||||
for rows.Next() {
|
||||
var roomID string
|
||||
|
@ -99,42 +98,33 @@ func (s *accountDataStatements) selectAccountData(
|
|||
var content []byte
|
||||
|
||||
if err = rows.Scan(&roomID, &dataType, &content); err != nil {
|
||||
return
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
ac := gomatrixserverlib.ClientEvent{
|
||||
Type: dataType,
|
||||
Content: content,
|
||||
}
|
||||
|
||||
if len(roomID) > 0 {
|
||||
rooms[roomID] = append(rooms[roomID], ac)
|
||||
if roomID != "" {
|
||||
if _, ok := rooms[roomID]; !ok {
|
||||
rooms[roomID] = map[string]json.RawMessage{}
|
||||
}
|
||||
rooms[roomID][dataType] = content
|
||||
} else {
|
||||
global = append(global, ac)
|
||||
global[dataType] = content
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
return global, rooms, nil
|
||||
}
|
||||
|
||||
func (s *accountDataStatements) selectAccountDataByType(
|
||||
ctx context.Context, localpart, roomID, dataType string,
|
||||
) (data *gomatrixserverlib.ClientEvent, err error) {
|
||||
) (data json.RawMessage, err error) {
|
||||
var bytes []byte
|
||||
stmt := s.selectAccountDataByTypeStmt
|
||||
var content []byte
|
||||
|
||||
if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&content); err != nil {
|
||||
if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&bytes); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
data = &gomatrixserverlib.ClientEvent{
|
||||
Type: dataType,
|
||||
Content: content,
|
||||
}
|
||||
|
||||
data = json.RawMessage(bytes)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -17,6 +17,7 @@ package sqlite3
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
@ -180,7 +181,7 @@ func (d *Database) createAccount(
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", `{
|
||||
if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{
|
||||
"global": {
|
||||
"content": [],
|
||||
"override": [],
|
||||
|
@ -188,7 +189,7 @@ func (d *Database) createAccount(
|
|||
"sender": [],
|
||||
"underride": []
|
||||
}
|
||||
}`); err != nil {
|
||||
}`)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID)
|
||||
|
@ -306,7 +307,7 @@ func (d *Database) newMembership(
|
|||
// update the corresponding row with the new content
|
||||
// Returns a SQL error if there was an issue with the insertion/update
|
||||
func (d *Database) SaveAccountData(
|
||||
ctx context.Context, localpart, roomID, dataType, content string,
|
||||
ctx context.Context, localpart, roomID, dataType string, content json.RawMessage,
|
||||
) error {
|
||||
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||
return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content)
|
||||
|
@ -317,8 +318,8 @@ func (d *Database) SaveAccountData(
|
|||
// If no account data could be found, returns an empty arrays
|
||||
// Returns an error if there was an issue with the retrieval
|
||||
func (d *Database) GetAccountData(ctx context.Context, localpart string) (
|
||||
global []gomatrixserverlib.ClientEvent,
|
||||
rooms map[string][]gomatrixserverlib.ClientEvent,
|
||||
global map[string]json.RawMessage,
|
||||
rooms map[string]map[string]json.RawMessage,
|
||||
err error,
|
||||
) {
|
||||
return d.accountDatas.selectAccountData(ctx, localpart)
|
||||
|
@ -330,7 +331,7 @@ func (d *Database) GetAccountData(ctx context.Context, localpart string) (
|
|||
// Returns an error if there was an issue with the retrieval
|
||||
func (d *Database) GetAccountDataByType(
|
||||
ctx context.Context, localpart, roomID, dataType string,
|
||||
) (data *gomatrixserverlib.ClientEvent, err error) {
|
||||
) (data json.RawMessage, err error) {
|
||||
return d.accountDatas.selectAccountDataByType(
|
||||
ctx, localpart, roomID, dataType,
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue