Refactor account data (#1150)

* Refactor account data

* Tweak database fetching

* Tweaks

* Restore syncProducer notification

* Various tweaks, update tag behaviour

* Fix initial sync
This commit is contained in:
Neil Alexander 2020-06-18 18:36:03 +01:00 committed by GitHub
parent 3547a1768c
commit dc0bac85d5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 248 additions and 222 deletions

View file

@ -17,8 +17,7 @@ package sqlite3
import (
"context"
"database/sql"
"github.com/matrix-org/gomatrixserverlib"
"encoding/json"
)
const accountDataSchema = `
@ -72,7 +71,7 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
}
func (s *accountDataStatements) insertAccountData(
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType, content string,
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
) (err error) {
_, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content)
return
@ -81,17 +80,17 @@ func (s *accountDataStatements) insertAccountData(
func (s *accountDataStatements) selectAccountData(
ctx context.Context, localpart string,
) (
global []gomatrixserverlib.ClientEvent,
rooms map[string][]gomatrixserverlib.ClientEvent,
err error,
/* global */ map[string]json.RawMessage,
/* rooms */ map[string]map[string]json.RawMessage,
error,
) {
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart)
if err != nil {
return
return nil, nil, err
}
global = []gomatrixserverlib.ClientEvent{}
rooms = make(map[string][]gomatrixserverlib.ClientEvent)
global := map[string]json.RawMessage{}
rooms := map[string]map[string]json.RawMessage{}
for rows.Next() {
var roomID string
@ -99,42 +98,33 @@ func (s *accountDataStatements) selectAccountData(
var content []byte
if err = rows.Scan(&roomID, &dataType, &content); err != nil {
return
return nil, nil, err
}
ac := gomatrixserverlib.ClientEvent{
Type: dataType,
Content: content,
}
if len(roomID) > 0 {
rooms[roomID] = append(rooms[roomID], ac)
if roomID != "" {
if _, ok := rooms[roomID]; !ok {
rooms[roomID] = map[string]json.RawMessage{}
}
rooms[roomID][dataType] = content
} else {
global = append(global, ac)
global[dataType] = content
}
}
return
return global, rooms, nil
}
func (s *accountDataStatements) selectAccountDataByType(
ctx context.Context, localpart, roomID, dataType string,
) (data *gomatrixserverlib.ClientEvent, err error) {
) (data json.RawMessage, err error) {
var bytes []byte
stmt := s.selectAccountDataByTypeStmt
var content []byte
if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&content); err != nil {
if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&bytes); err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return
}
data = &gomatrixserverlib.ClientEvent{
Type: dataType,
Content: content,
}
data = json.RawMessage(bytes)
return
}

View file

@ -17,6 +17,7 @@ package sqlite3
import (
"context"
"database/sql"
"encoding/json"
"errors"
"strconv"
"sync"
@ -180,7 +181,7 @@ func (d *Database) createAccount(
return nil, err
}
if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", `{
if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{
"global": {
"content": [],
"override": [],
@ -188,7 +189,7 @@ func (d *Database) createAccount(
"sender": [],
"underride": []
}
}`); err != nil {
}`)); err != nil {
return nil, err
}
return d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID)
@ -306,7 +307,7 @@ func (d *Database) newMembership(
// update the corresponding row with the new content
// Returns a SQL error if there was an issue with the insertion/update
func (d *Database) SaveAccountData(
ctx context.Context, localpart, roomID, dataType, content string,
ctx context.Context, localpart, roomID, dataType string, content json.RawMessage,
) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content)
@ -317,8 +318,8 @@ func (d *Database) SaveAccountData(
// If no account data could be found, returns an empty arrays
// Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountData(ctx context.Context, localpart string) (
global []gomatrixserverlib.ClientEvent,
rooms map[string][]gomatrixserverlib.ClientEvent,
global map[string]json.RawMessage,
rooms map[string]map[string]json.RawMessage,
err error,
) {
return d.accountDatas.selectAccountData(ctx, localpart)
@ -330,7 +331,7 @@ func (d *Database) GetAccountData(ctx context.Context, localpart string) (
// Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountDataByType(
ctx context.Context, localpart, roomID, dataType string,
) (data *gomatrixserverlib.ClientEvent, err error) {
) (data json.RawMessage, err error) {
return d.accountDatas.selectAccountDataByType(
ctx, localpart, roomID, dataType,
)