Send all account data on complete sync by default

Squashed commit of the following:

commit 0ec8de57261d573a5f88577aa9d7a1174d3999b9
Author: Neil Alexander <neilalexander@users.noreply.github.com>
Date:   Tue Apr 26 16:56:30 2022 +0100

    Select filter onto provided target filter

commit da40b6fffbf5737864b223f49900048f557941f9
Author: Neil Alexander <neilalexander@users.noreply.github.com>
Date:   Tue Apr 26 16:48:00 2022 +0100

    Specify other field too

commit ffc0b0801f63bb4d3061b6813e3ce5f3b4c8fbcb
Author: Neil Alexander <neilalexander@users.noreply.github.com>
Date:   Tue Apr 26 16:45:44 2022 +0100

    Send as much account data as possible during complete sync
This commit is contained in:
Neil Alexander 2022-04-26 16:58:20 +01:00
parent f6d07768a8
commit b527e33c16
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
7 changed files with 30 additions and 26 deletions

View file

@ -44,8 +44,8 @@ func GetFilter(
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
filter, err := syncDB.GetFilter(req.Context(), localpart, filterID) filter := gomatrixserverlib.DefaultFilter()
if err != nil { if err := syncDB.GetFilter(req.Context(), &filter, localpart, filterID); err != nil {
//TODO better error handling. This error message is *probably* right, //TODO better error handling. This error message is *probably* right,
// but if there are obscure db errors, this will also be returned, // but if there are obscure db errors, this will also be returned,
// even though it is not correct. // even though it is not correct.

View file

@ -125,10 +125,10 @@ type Database interface {
// CleanSendToDeviceUpdates removes all send-to-device messages BEFORE the specified // CleanSendToDeviceUpdates removes all send-to-device messages BEFORE the specified
// from position, preventing the send-to-device table from growing indefinitely. // from position, preventing the send-to-device table from growing indefinitely.
CleanSendToDeviceUpdates(ctx context.Context, userID, deviceID string, before types.StreamPosition) (err error) CleanSendToDeviceUpdates(ctx context.Context, userID, deviceID string, before types.StreamPosition) (err error)
// GetFilter looks up the filter associated with a given local user and filter ID. // GetFilter looks up the filter associated with a given local user and filter ID
// Returns a filter structure. Otherwise returns an error if no such filter exists // and populates the target filter. Otherwise returns an error if no such filter exists
// or if there was an error talking to the database. // or if there was an error talking to the database.
GetFilter(ctx context.Context, localpart string, filterID string) (*gomatrixserverlib.Filter, error) GetFilter(ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string) error
// PutFilter puts the passed filter into the database. // PutFilter puts the passed filter into the database.
// Returns the filterID as a string. Otherwise returns an error if something // Returns the filterID as a string. Otherwise returns an error if something
// goes wrong. // goes wrong.

View file

@ -73,21 +73,20 @@ func NewPostgresFilterTable(db *sql.DB) (tables.Filter, error) {
} }
func (s *filterStatements) SelectFilter( func (s *filterStatements) SelectFilter(
ctx context.Context, localpart string, filterID string, ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string,
) (*gomatrixserverlib.Filter, error) { ) error {
// Retrieve filter from database (stored as canonical JSON) // Retrieve filter from database (stored as canonical JSON)
var filterData []byte var filterData []byte
err := s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filterData) err := s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filterData)
if err != nil { if err != nil {
return nil, err return err
} }
// Unmarshal JSON into Filter struct // Unmarshal JSON into Filter struct
filter := gomatrixserverlib.DefaultFilter() if err = json.Unmarshal(filterData, &target); err != nil {
if err = json.Unmarshal(filterData, &filter); err != nil { return err
return nil, err
} }
return &filter, nil return nil
} }
func (s *filterStatements) InsertFilter( func (s *filterStatements) InsertFilter(

View file

@ -513,9 +513,9 @@ func (d *Database) StreamToTopologicalPosition(
} }
func (d *Database) GetFilter( func (d *Database) GetFilter(
ctx context.Context, localpart string, filterID string, ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string,
) (*gomatrixserverlib.Filter, error) { ) error {
return d.Filter.SelectFilter(ctx, localpart, filterID) return d.Filter.SelectFilter(ctx, target, localpart, filterID)
} }
func (d *Database) PutFilter( func (d *Database) PutFilter(

View file

@ -77,21 +77,20 @@ func NewSqliteFilterTable(db *sql.DB) (tables.Filter, error) {
} }
func (s *filterStatements) SelectFilter( func (s *filterStatements) SelectFilter(
ctx context.Context, localpart string, filterID string, ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string,
) (*gomatrixserverlib.Filter, error) { ) error {
// Retrieve filter from database (stored as canonical JSON) // Retrieve filter from database (stored as canonical JSON)
var filterData []byte var filterData []byte
err := s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filterData) err := s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filterData)
if err != nil { if err != nil {
return nil, err return err
} }
// Unmarshal JSON into Filter struct // Unmarshal JSON into Filter struct
filter := gomatrixserverlib.DefaultFilter() if err = json.Unmarshal(filterData, &target); err != nil {
if err = json.Unmarshal(filterData, &filter); err != nil { return err
return nil, err
} }
return &filter, nil return nil
} }
func (s *filterStatements) InsertFilter( func (s *filterStatements) InsertFilter(

View file

@ -157,7 +157,7 @@ type SendToDevice interface {
} }
type Filter interface { type Filter interface {
SelectFilter(ctx context.Context, localpart string, filterID string) (*gomatrixserverlib.Filter, error) SelectFilter(ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string) error
InsertFilter(ctx context.Context, filter *gomatrixserverlib.Filter, localpart string) (filterID string, err error) InsertFilter(ctx context.Context, filter *gomatrixserverlib.Filter, localpart string) (filterID string, err error)
} }

View file

@ -18,6 +18,7 @@ import (
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"fmt" "fmt"
"math"
"net/http" "net/http"
"strconv" "strconv"
"time" "time"
@ -47,6 +48,13 @@ func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Dat
} }
// TODO: read from stored filters too // TODO: read from stored filters too
filter := gomatrixserverlib.DefaultFilter() filter := gomatrixserverlib.DefaultFilter()
if since.IsEmpty() {
// Send as much account data down for complete syncs as possible
// by default, otherwise clients do weird things while waiting
// for the rest of the data to trickle down.
filter.AccountData.Limit = math.MaxInt
filter.Room.AccountData.Limit = math.MaxInt
}
filterQuery := req.URL.Query().Get("filter") filterQuery := req.URL.Query().Get("filter")
if filterQuery != "" { if filterQuery != "" {
if filterQuery[0] == '{' { if filterQuery[0] == '{' {
@ -61,11 +69,9 @@ func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Dat
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return nil, fmt.Errorf("gomatrixserverlib.SplitID: %w", err) return nil, fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
} }
if f, err := syncDB.GetFilter(req.Context(), localpart, filterQuery); err != nil && err != sql.ErrNoRows { if err := syncDB.GetFilter(req.Context(), &filter, localpart, filterQuery); err != nil && err != sql.ErrNoRows {
util.GetLogger(req.Context()).WithError(err).Error("syncDB.GetFilter failed") util.GetLogger(req.Context()).WithError(err).Error("syncDB.GetFilter failed")
return nil, fmt.Errorf("syncDB.GetFilter: %w", err) return nil, fmt.Errorf("syncDB.GetFilter: %w", err)
} else if f != nil {
filter = *f
} }
} }
} }