Refactor account data (#1150)

* Refactor account data

* Tweak database fetching

* Tweaks

* Restore syncProducer notification

* Various tweaks, update tag behaviour

* Fix initial sync
This commit is contained in:
Neil Alexander 2020-06-18 18:36:03 +01:00 committed by GitHub
parent 3547a1768c
commit dc0bac85d5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 248 additions and 222 deletions

View file

@ -16,6 +16,7 @@ package accounts
import (
"context"
"encoding/json"
"errors"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
@ -39,13 +40,13 @@ type Database interface {
GetMembershipInRoomByLocalpart(ctx context.Context, localpart, roomID string) (authtypes.Membership, error)
GetRoomIDsByLocalPart(ctx context.Context, localpart string) ([]string, error)
GetMembershipsByLocalpart(ctx context.Context, localpart string) (memberships []authtypes.Membership, err error)
SaveAccountData(ctx context.Context, localpart, roomID, dataType, content string) error
GetAccountData(ctx context.Context, localpart string) (global []gomatrixserverlib.ClientEvent, rooms map[string][]gomatrixserverlib.ClientEvent, err error)
SaveAccountData(ctx context.Context, localpart, roomID, dataType string, content json.RawMessage) error
GetAccountData(ctx context.Context, localpart string) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error)
// GetAccountDataByType returns account data matching a given
// localpart, room ID and type.
// If no account data could be found, returns nil
// Returns an error if there was an issue with the retrieval
GetAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data *gomatrixserverlib.ClientEvent, err error)
GetAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data json.RawMessage, err error)
GetNewNumericLocalpart(ctx context.Context) (int64, error)
SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error)
RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error)

View file

@ -17,9 +17,9 @@ package postgres
import (
"context"
"database/sql"
"encoding/json"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/gomatrixserverlib"
)
const accountDataSchema = `
@ -73,7 +73,7 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
}
func (s *accountDataStatements) insertAccountData(
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType, content string,
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
) (err error) {
stmt := txn.Stmt(s.insertAccountDataStmt)
_, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content)
@ -83,18 +83,18 @@ func (s *accountDataStatements) insertAccountData(
func (s *accountDataStatements) selectAccountData(
ctx context.Context, localpart string,
) (
global []gomatrixserverlib.ClientEvent,
rooms map[string][]gomatrixserverlib.ClientEvent,
err error,
/* global */ map[string]json.RawMessage,
/* rooms */ map[string]map[string]json.RawMessage,
error,
) {
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart)
if err != nil {
return
return nil, nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectAccountData: rows.close() failed")
global = []gomatrixserverlib.ClientEvent{}
rooms = make(map[string][]gomatrixserverlib.ClientEvent)
global := map[string]json.RawMessage{}
rooms := map[string]map[string]json.RawMessage{}
for rows.Next() {
var roomID string
@ -102,41 +102,33 @@ func (s *accountDataStatements) selectAccountData(
var content []byte
if err = rows.Scan(&roomID, &dataType, &content); err != nil {
return
return nil, nil, err
}
ac := gomatrixserverlib.ClientEvent{
Type: dataType,
Content: content,
}
if len(roomID) > 0 {
rooms[roomID] = append(rooms[roomID], ac)
if roomID != "" {
if _, ok := rooms[roomID]; !ok {
rooms[roomID] = map[string]json.RawMessage{}
}
rooms[roomID][dataType] = content
} else {
global = append(global, ac)
global[dataType] = content
}
}
return global, rooms, rows.Err()
}
func (s *accountDataStatements) selectAccountDataByType(
ctx context.Context, localpart, roomID, dataType string,
) (data *gomatrixserverlib.ClientEvent, err error) {
) (data json.RawMessage, err error) {
var bytes []byte
stmt := s.selectAccountDataByTypeStmt
var content []byte
if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&content); err != nil {
if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&bytes); err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return
}
data = &gomatrixserverlib.ClientEvent{
Type: dataType,
Content: content,
}
data = json.RawMessage(bytes)
return
}

View file

@ -17,6 +17,7 @@ package postgres
import (
"context"
"database/sql"
"encoding/json"
"errors"
"strconv"
@ -169,7 +170,7 @@ func (d *Database) createAccount(
return nil, err
}
if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", `{
if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{
"global": {
"content": [],
"override": [],
@ -177,7 +178,7 @@ func (d *Database) createAccount(
"sender": [],
"underride": []
}
}`); err != nil {
}`)); err != nil {
return nil, err
}
return d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID)
@ -295,7 +296,7 @@ func (d *Database) newMembership(
// update the corresponding row with the new content
// Returns a SQL error if there was an issue with the insertion/update
func (d *Database) SaveAccountData(
ctx context.Context, localpart, roomID, dataType, content string,
ctx context.Context, localpart, roomID, dataType string, content json.RawMessage,
) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content)
@ -306,8 +307,8 @@ func (d *Database) SaveAccountData(
// If no account data could be found, returns an empty arrays
// Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountData(ctx context.Context, localpart string) (
global []gomatrixserverlib.ClientEvent,
rooms map[string][]gomatrixserverlib.ClientEvent,
global map[string]json.RawMessage,
rooms map[string]map[string]json.RawMessage,
err error,
) {
return d.accountDatas.selectAccountData(ctx, localpart)
@ -319,7 +320,7 @@ func (d *Database) GetAccountData(ctx context.Context, localpart string) (
// Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountDataByType(
ctx context.Context, localpart, roomID, dataType string,
) (data *gomatrixserverlib.ClientEvent, err error) {
) (data json.RawMessage, err error) {
return d.accountDatas.selectAccountDataByType(
ctx, localpart, roomID, dataType,
)

View file

@ -17,8 +17,7 @@ package sqlite3
import (
"context"
"database/sql"
"github.com/matrix-org/gomatrixserverlib"
"encoding/json"
)
const accountDataSchema = `
@ -72,7 +71,7 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
}
func (s *accountDataStatements) insertAccountData(
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType, content string,
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
) (err error) {
_, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content)
return
@ -81,17 +80,17 @@ func (s *accountDataStatements) insertAccountData(
func (s *accountDataStatements) selectAccountData(
ctx context.Context, localpart string,
) (
global []gomatrixserverlib.ClientEvent,
rooms map[string][]gomatrixserverlib.ClientEvent,
err error,
/* global */ map[string]json.RawMessage,
/* rooms */ map[string]map[string]json.RawMessage,
error,
) {
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart)
if err != nil {
return
return nil, nil, err
}
global = []gomatrixserverlib.ClientEvent{}
rooms = make(map[string][]gomatrixserverlib.ClientEvent)
global := map[string]json.RawMessage{}
rooms := map[string]map[string]json.RawMessage{}
for rows.Next() {
var roomID string
@ -99,42 +98,33 @@ func (s *accountDataStatements) selectAccountData(
var content []byte
if err = rows.Scan(&roomID, &dataType, &content); err != nil {
return
return nil, nil, err
}
ac := gomatrixserverlib.ClientEvent{
Type: dataType,
Content: content,
}
if len(roomID) > 0 {
rooms[roomID] = append(rooms[roomID], ac)
if roomID != "" {
if _, ok := rooms[roomID]; !ok {
rooms[roomID] = map[string]json.RawMessage{}
}
rooms[roomID][dataType] = content
} else {
global = append(global, ac)
global[dataType] = content
}
}
return
return global, rooms, nil
}
func (s *accountDataStatements) selectAccountDataByType(
ctx context.Context, localpart, roomID, dataType string,
) (data *gomatrixserverlib.ClientEvent, err error) {
) (data json.RawMessage, err error) {
var bytes []byte
stmt := s.selectAccountDataByTypeStmt
var content []byte
if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&content); err != nil {
if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&bytes); err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return
}
data = &gomatrixserverlib.ClientEvent{
Type: dataType,
Content: content,
}
data = json.RawMessage(bytes)
return
}

View file

@ -17,6 +17,7 @@ package sqlite3
import (
"context"
"database/sql"
"encoding/json"
"errors"
"strconv"
"sync"
@ -180,7 +181,7 @@ func (d *Database) createAccount(
return nil, err
}
if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", `{
if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{
"global": {
"content": [],
"override": [],
@ -188,7 +189,7 @@ func (d *Database) createAccount(
"sender": [],
"underride": []
}
}`); err != nil {
}`)); err != nil {
return nil, err
}
return d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID)
@ -306,7 +307,7 @@ func (d *Database) newMembership(
// update the corresponding row with the new content
// Returns a SQL error if there was an issue with the insertion/update
func (d *Database) SaveAccountData(
ctx context.Context, localpart, roomID, dataType, content string,
ctx context.Context, localpart, roomID, dataType string, content json.RawMessage,
) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content)
@ -317,8 +318,8 @@ func (d *Database) SaveAccountData(
// If no account data could be found, returns an empty arrays
// Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountData(ctx context.Context, localpart string) (
global []gomatrixserverlib.ClientEvent,
rooms map[string][]gomatrixserverlib.ClientEvent,
global map[string]json.RawMessage,
rooms map[string]map[string]json.RawMessage,
err error,
) {
return d.accountDatas.selectAccountData(ctx, localpart)
@ -330,7 +331,7 @@ func (d *Database) GetAccountData(ctx context.Context, localpart string) (
// Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountDataByType(
ctx context.Context, localpart, roomID, dataType string,
) (data *gomatrixserverlib.ClientEvent, err error) {
) (data json.RawMessage, err error) {
return d.accountDatas.selectAccountDataByType(
ctx, localpart, roomID, dataType,
)