From 8dc95062101b3906ffb83604e2abca02d9a3dd03 Mon Sep 17 00:00:00 2001 From: Kegsay Date: Mon, 14 Sep 2020 16:39:38 +0100 Subject: [PATCH 01/12] Don't use more than 999 variables in SQLite querys. (#1425) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Don't use more than 999 variables in SQLite querys. Solve this problem in a more general and reusable way. Also fix #1369 Add some unit tests. Signed-off-by: Henrik Sölver * Don't rely on testify for basic assertions * Readability improvements and linting Co-authored-by: Henrik Sölver --- go.mod | 1 + go.sum | 2 + internal/sqlutil/sql.go | 45 +++++ internal/sqlutil/sqlutil_test.go | 173 ++++++++++++++++++ .../storage/sqlite3/server_key_table.go | 67 +++---- 5 files changed, 255 insertions(+), 33 deletions(-) create mode 100644 internal/sqlutil/sqlutil_test.go diff --git a/go.mod b/go.mod index f1cb3c9b..6b1c03b5 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,7 @@ module github.com/matrix-org/dendrite require ( + github.com/DATA-DOG/go-sqlmock v1.5.0 github.com/Shopify/sarama v1.27.0 github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd // indirect github.com/gologme/log v1.2.0 diff --git a/go.sum b/go.sum index ac7827d9..5c4f27a5 100644 --- a/go.sum +++ b/go.sum @@ -13,6 +13,8 @@ github.com/Arceliar/phony v0.0.0-20191006174943-d0c68492aca0 h1:p3puK8Sl2xK+2Fnn github.com/Arceliar/phony v0.0.0-20191006174943-d0c68492aca0/go.mod h1:6Lkn+/zJilRMsKmbmG1RPoamiArC6HS73xbwRyp3UyI= github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60= +github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= github.com/Kubuxu/go-os-helper v0.0.1/go.mod h1:N8B+I7vPCT80IcP58r50u4+gEEcsZETFUpAzWW2ep1Y= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/PuerkitoBio/goquery v1.5.1/go.mod h1:GsLWisAFVj4WgDibEWF4pvYnkVQBpKBKeU+7zCJoLcc= diff --git a/internal/sqlutil/sql.go b/internal/sqlutil/sql.go index 1d2825d5..90562ded 100644 --- a/internal/sqlutil/sql.go +++ b/internal/sqlutil/sql.go @@ -15,10 +15,14 @@ package sqlutil import ( + "context" "database/sql" "errors" "fmt" "runtime" + "strings" + + "github.com/matrix-org/util" ) // ErrUserExists is returned if a username already exists in the database. @@ -107,3 +111,44 @@ func SQLiteDriverName() string { } return "sqlite3" } + +func minOfInts(a, b int) int { + if a <= b { + return a + } + return b +} + +// QueryProvider defines the interface for querys used by RunLimitedVariablesQuery. +type QueryProvider interface { + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) +} + +// SQLite3MaxVariables is the default maximum number of host parameters in a single SQL statement +// SQLlite can handle. See https://www.sqlite.org/limits.html for more information. +const SQLite3MaxVariables = 999 + +// RunLimitedVariablesQuery split up a query with more variables than the used database can handle in multiple queries. +func RunLimitedVariablesQuery(ctx context.Context, query string, qp QueryProvider, variables []interface{}, limit uint, rowHandler func(*sql.Rows) error) error { + var start int + for start < len(variables) { + n := minOfInts(len(variables)-start, int(limit)) + nextQuery := strings.Replace(query, "($1)", QueryVariadic(n), 1) + rows, err := qp.QueryContext(ctx, nextQuery, variables[start:start+n]...) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("QueryContext returned an error") + return err + } + err = rowHandler(rows) + if closeErr := rows.Close(); closeErr != nil { + util.GetLogger(ctx).WithError(closeErr).Error("RunLimitedVariablesQuery: failed to close rows") + return err + } + if err != nil { + util.GetLogger(ctx).WithError(err).Error("RunLimitedVariablesQuery: rowHandler returned error") + return err + } + start = start + n + } + return nil +} diff --git a/internal/sqlutil/sqlutil_test.go b/internal/sqlutil/sqlutil_test.go new file mode 100644 index 00000000..79469cdd --- /dev/null +++ b/internal/sqlutil/sqlutil_test.go @@ -0,0 +1,173 @@ +package sqlutil + +import ( + "context" + "database/sql" + "reflect" + "testing" + + sqlmock "github.com/DATA-DOG/go-sqlmock" +) + +func TestShouldReturnCorrectAmountOfResulstIfFewerVariablesThanLimit(t *testing.T) { + db, mock, err := sqlmock.New() + assertNoError(t, err, "Failed to make DB") + limit := uint(4) + + r := mock.NewRows([]string{"id"}). + AddRow(1). + AddRow(2). + AddRow(3) + + mock.ExpectQuery(`SELECT id WHERE id IN \(\$1, \$2, \$3\)`).WillReturnRows(r) + // nolint:goconst + q := "SELECT id WHERE id IN ($1)" + v := []int{1, 2, 3} + iKeyIDs := make([]interface{}, len(v)) + for i, d := range v { + iKeyIDs[i] = d + } + + ctx := context.Background() + var result = make([]int, 0) + err = RunLimitedVariablesQuery(ctx, q, db, iKeyIDs, limit, func(rows *sql.Rows) error { + for rows.Next() { + var id int + err = rows.Scan(&id) + assertNoError(t, err, "rows.Scan returned an error") + result = append(result, id) + } + return nil + }) + assertNoError(t, err, "Call returned an error") + if len(result) != len(v) { + t.Fatalf("Result should be 3 long") + } +} + +func TestShouldReturnCorrectAmountOfResulstIfEqualVariablesAsLimit(t *testing.T) { + db, mock, err := sqlmock.New() + assertNoError(t, err, "Failed to make DB") + limit := uint(4) + + r := mock.NewRows([]string{"id"}). + AddRow(1). + AddRow(2). + AddRow(3). + AddRow(4) + + mock.ExpectQuery(`SELECT id WHERE id IN \(\$1, \$2, \$3, \$4\)`).WillReturnRows(r) + // nolint:goconst + q := "SELECT id WHERE id IN ($1)" + v := []int{1, 2, 3, 4} + iKeyIDs := make([]interface{}, len(v)) + for i, d := range v { + iKeyIDs[i] = d + } + + ctx := context.Background() + var result = make([]int, 0) + err = RunLimitedVariablesQuery(ctx, q, db, iKeyIDs, limit, func(rows *sql.Rows) error { + for rows.Next() { + var id int + err = rows.Scan(&id) + assertNoError(t, err, "rows.Scan returned an error") + result = append(result, id) + } + return nil + }) + assertNoError(t, err, "Call returned an error") + if len(result) != len(v) { + t.Fatalf("Result should be 4 long") + } +} + +func TestShouldReturnCorrectAmountOfResultsIfMoreVariablesThanLimit(t *testing.T) { + db, mock, err := sqlmock.New() + assertNoError(t, err, "Failed to make DB") + limit := uint(4) + + r1 := mock.NewRows([]string{"id"}). + AddRow(1). + AddRow(2). + AddRow(3). + AddRow(4) + + r2 := mock.NewRows([]string{"id"}). + AddRow(5) + + mock.ExpectQuery(`SELECT id WHERE id IN \(\$1, \$2, \$3, \$4\)`).WillReturnRows(r1) + mock.ExpectQuery(`SELECT id WHERE id IN \(\$1\)`).WillReturnRows(r2) + // nolint:goconst + q := "SELECT id WHERE id IN ($1)" + v := []int{1, 2, 3, 4, 5} + iKeyIDs := make([]interface{}, len(v)) + for i, d := range v { + iKeyIDs[i] = d + } + + ctx := context.Background() + var result = make([]int, 0) + err = RunLimitedVariablesQuery(ctx, q, db, iKeyIDs, limit, func(rows *sql.Rows) error { + for rows.Next() { + var id int + err = rows.Scan(&id) + assertNoError(t, err, "rows.Scan returned an error") + result = append(result, id) + } + return nil + }) + assertNoError(t, err, "Call returned an error") + if len(result) != len(v) { + t.Fatalf("Result should be 5 long") + } + if !reflect.DeepEqual(v, result) { + t.Fatalf("Result is not as expected: got %v want %v", v, result) + } +} + +func TestShouldReturnErrorIfRowsScanReturnsError(t *testing.T) { + db, mock, err := sqlmock.New() + assertNoError(t, err, "Failed to make DB") + limit := uint(4) + + // adding a string ID should result in rows.Scan returning an error + r := mock.NewRows([]string{"id"}). + AddRow("hej"). + AddRow(2). + AddRow(3) + + mock.ExpectQuery(`SELECT id WHERE id IN \(\$1, \$2, \$3\)`).WillReturnRows(r) + // nolint:goconst + q := "SELECT id WHERE id IN ($1)" + v := []int{-1, -2, 3} + iKeyIDs := make([]interface{}, len(v)) + for i, d := range v { + iKeyIDs[i] = d + } + + ctx := context.Background() + var result = make([]uint, 0) + err = RunLimitedVariablesQuery(ctx, q, db, iKeyIDs, limit, func(rows *sql.Rows) error { + for rows.Next() { + var id uint + err = rows.Scan(&id) + if err != nil { + return err + } + result = append(result, id) + } + return nil + }) + if err == nil { + t.Fatalf("Call did not return an error") + } +} + +func assertNoError(t *testing.T, err error, msg string) { + t.Helper() + if err == nil { + return + } + t.Fatalf(msg) +} diff --git a/serverkeyapi/storage/sqlite3/server_key_table.go b/serverkeyapi/storage/sqlite3/server_key_table.go index f756ef5e..2484d636 100644 --- a/serverkeyapi/storage/sqlite3/server_key_table.go +++ b/serverkeyapi/storage/sqlite3/server_key_table.go @@ -18,9 +18,8 @@ package sqlite3 import ( "context" "database/sql" - "strings" + "fmt" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/gomatrixserverlib" ) @@ -88,48 +87,50 @@ func (s *serverKeyStatements) bulkSelectServerKeys( ctx context.Context, requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, ) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) { - var nameAndKeyIDs []string + nameAndKeyIDs := make([]string, 0, len(requests)) for request := range requests { nameAndKeyIDs = append(nameAndKeyIDs, nameAndKeyID(request)) } - - query := strings.Replace(bulkSelectServerKeysSQL, "($1)", sqlutil.QueryVariadic(len(nameAndKeyIDs)), 1) - + results := make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, len(requests)) iKeyIDs := make([]interface{}, len(nameAndKeyIDs)) for i, v := range nameAndKeyIDs { iKeyIDs[i] = v } - rows, err := s.db.QueryContext(ctx, query, iKeyIDs...) + err := sqlutil.RunLimitedVariablesQuery( + ctx, bulkSelectServerKeysSQL, s.db, iKeyIDs, sqlutil.SQLite3MaxVariables, + func(rows *sql.Rows) error { + for rows.Next() { + var serverName string + var keyID string + var key string + var validUntilTS int64 + var expiredTS int64 + if err := rows.Scan(&serverName, &keyID, &validUntilTS, &expiredTS, &key); err != nil { + return fmt.Errorf("bulkSelectServerKeys: %v", err) + } + r := gomatrixserverlib.PublicKeyLookupRequest{ + ServerName: gomatrixserverlib.ServerName(serverName), + KeyID: gomatrixserverlib.KeyID(keyID), + } + vk := gomatrixserverlib.VerifyKey{} + err := vk.Key.Decode(key) + if err != nil { + return fmt.Errorf("bulkSelectServerKeys: %v", err) + } + results[r] = gomatrixserverlib.PublicKeyLookupResult{ + VerifyKey: vk, + ValidUntilTS: gomatrixserverlib.Timestamp(validUntilTS), + ExpiredTS: gomatrixserverlib.Timestamp(expiredTS), + } + } + return nil + }, + ) + if err != nil { return nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectServerKeys: rows.close() failed") - results := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{} - for rows.Next() { - var serverName string - var keyID string - var key string - var validUntilTS int64 - var expiredTS int64 - if err = rows.Scan(&serverName, &keyID, &validUntilTS, &expiredTS, &key); err != nil { - return nil, err - } - r := gomatrixserverlib.PublicKeyLookupRequest{ - ServerName: gomatrixserverlib.ServerName(serverName), - KeyID: gomatrixserverlib.KeyID(keyID), - } - vk := gomatrixserverlib.VerifyKey{} - err = vk.Key.Decode(key) - if err != nil { - return nil, err - } - results[r] = gomatrixserverlib.PublicKeyLookupResult{ - VerifyKey: vk, - ValidUntilTS: gomatrixserverlib.Timestamp(validUntilTS), - ExpiredTS: gomatrixserverlib.Timestamp(expiredTS), - } - } return results, nil } From 965f068d1a6298b2ec733b0df983773a6ec8b622 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 15 Sep 2020 11:17:46 +0100 Subject: [PATCH 02/12] Handle state with input event as new events (#1415) * SendEventWithState events as new * Use cumulative state IDs for final event * Error wrapping in calculateAndSetState * Handle overwriting same event type and state key * Hacky way to spot historical events * Don't exclude from sync * Don't generate output events when rewriting forward extremities * Update output event check * Historical output events * Define output room event type * Notify key changes on state * Don't send our membership event twice * Deduplicate state entries * Tweaks * Remove unnecessary nolint * Fix current state upsert in sync API * Send auth events as outliers, state events as rewrite * Sync API don't consume state events * Process events actually * Improve outlier check * Fix local room check * Remove extra room check, it seems to break the whole damn world * Fix federated join check * Fix nil pointer exception * Better comments on DeduplicateStateEntries * Reflow forced federated joins * Don't force federated join for possibly even local invites * Comment SendEventWithState better * Rewrite room state in sync API storage * Add TODO * Clean up all room data when receiving create event * Don't generate output events for rewrites, but instead notify that state is rewritten on the final new event * Rename to PurgeRoom * Exclude backfilled messages from /sync * Split out rewriting state from updating state from state res Co-authored-by: Kegan Dougal --- federationsender/internal/perform.go | 7 +- roomserver/api/input.go | 4 + roomserver/api/output.go | 14 + roomserver/api/wrapper.go | 101 ++++++++ roomserver/internal/helpers/auth.go | 2 +- roomserver/internal/input/input_events.go | 26 +- .../internal/input/input_latest_events.go | 12 +- roomserver/internal/perform/perform_join.go | 30 +-- roomserver/roomserver_test.go | 245 +++++++++++++++++- roomserver/types/types.go | 21 ++ roomserver/types/types_test.go | 26 ++ syncapi/consumers/roomserver.go | 6 + syncapi/storage/interface.go | 3 + .../postgres/backwards_extremities_table.go | 15 ++ .../postgres/current_room_state_table.go | 15 ++ .../postgres/output_room_events_table.go | 14 + .../output_room_events_topology_table.go | 15 ++ syncapi/storage/shared/syncserver.go | 23 ++ .../sqlite3/backwards_extremities_table.go | 15 ++ .../sqlite3/current_room_state_table.go | 17 +- .../sqlite3/output_room_events_table.go | 14 + .../output_room_events_topology_table.go | 14 + syncapi/storage/tables/interface.go | 7 + 23 files changed, 616 insertions(+), 30 deletions(-) create mode 100644 roomserver/types/types_test.go diff --git a/federationsender/internal/perform.go b/federationsender/internal/perform.go index 90abae23..a0abf7ff 100644 --- a/federationsender/internal/perform.go +++ b/federationsender/internal/perform.go @@ -98,7 +98,10 @@ func (r *FederationSenderInternalAPI) PerformJoin( response.LastError = &gomatrix.HTTPError{ Code: 0, WrappedError: nil, - Message: lastErr.Error(), + Message: "Unknown HTTP error", + } + if lastErr != nil { + response.LastError.Message = lastErr.Error() } } @@ -195,7 +198,7 @@ func (r *FederationSenderInternalAPI) performJoinUsingServer( // If we successfully performed a send_join above then the other // server now thinks we're a part of the room. Send the newly // returned state to the roomserver to update our local view. - if err = roomserverAPI.SendEventWithState( + if err = roomserverAPI.SendEventWithRewrite( ctx, r.rsAPI, respState, event.Headered(respMakeJoin.RoomVersion), diff --git a/roomserver/api/input.go b/roomserver/api/input.go index 73c4994a..651c0e9f 100644 --- a/roomserver/api/input.go +++ b/roomserver/api/input.go @@ -33,6 +33,10 @@ const ( // KindBackfill event extend the contiguous graph going backwards. // They always have state. KindBackfill = 3 + // KindRewrite events are used when rewriting the head of the room + // graph with entirely new state. The output events generated will + // be state events rather than timeline events. + KindRewrite = 4 ) // DoNotSendToOtherServers tells us not to send the event to other matrix diff --git a/roomserver/api/output.go b/roomserver/api/output.go index 013ebdc8..d57f3b04 100644 --- a/roomserver/api/output.go +++ b/roomserver/api/output.go @@ -68,6 +68,17 @@ type OutputEvent struct { NewPeek *OutputNewPeek `json:"new_peek,omitempty"` } +// Type of the OutputNewRoomEvent. +type OutputRoomEventType int + +const ( + // The event is a timeline event and likely just happened. + OutputRoomTimeline OutputRoomEventType = iota + + // The event is a state event and quite possibly happened in the past. + OutputRoomState +) + // An OutputNewRoomEvent is written when the roomserver receives a new event. // It contains the full matrix room event and enough information for a // consumer to construct the current state of the room and the state before the @@ -80,6 +91,9 @@ type OutputEvent struct { type OutputNewRoomEvent struct { // The Event. Event gomatrixserverlib.HeaderedEvent `json:"event"` + // Does the event completely rewrite the room state? If so, then AddsStateEventIDs + // will contain the entire room state. + RewritesState bool `json:"rewrites_state"` // The latest events in the room after this event. // This can be used to set the prev events for new events in the room. // This also can be used to get the full current state after this event. diff --git a/roomserver/api/wrapper.go b/roomserver/api/wrapper.go index 82a4a571..e5339311 100644 --- a/roomserver/api/wrapper.go +++ b/roomserver/api/wrapper.go @@ -80,6 +80,107 @@ func SendEventWithState( return SendInputRoomEvents(ctx, rsAPI, ires) } +// SendEventWithRewrite writes an event with KindNew to the roomserver along +// with a number of rewrite and outlier events for state and auth events +// respectively. +func SendEventWithRewrite( + ctx context.Context, rsAPI RoomserverInternalAPI, state *gomatrixserverlib.RespState, + event gomatrixserverlib.HeaderedEvent, haveEventIDs map[string]bool, +) error { + isCurrentState := map[string]struct{}{} + for _, se := range state.StateEvents { + isCurrentState[se.EventID()] = struct{}{} + } + + authAndStateEvents, err := state.Events() + if err != nil { + return err + } + + var ires []InputRoomEvent + var stateIDs []string + + // This function generates three things: + // A - A set of "rewrite" events, which will form the newly rewritten + // state before the event, which includes every rewrite event that + // came before it in its state + // B - A set of "outlier" events, which are auth events but not part + // of the rewritten state + // C - A "new" event, which include all of the rewrite events in its + // state + for _, authOrStateEvent := range authAndStateEvents { + if authOrStateEvent.StateKey() == nil { + continue + } + if haveEventIDs[authOrStateEvent.EventID()] { + continue + } + if event.StateKey() == nil { + continue + } + + // We will handle an event as if it's an outlier if one of the + // following conditions is true: + storeAsOutlier := false + if authOrStateEvent.Type() == event.Type() && *authOrStateEvent.StateKey() == *event.StateKey() { + // The event is a state event but the input event is going to + // replace it, therefore it can't be added to the state or we'll + // get duplicate state keys in the state block. We'll send it + // as an outlier because we don't know if something will be + // referring to it as an auth event, but need it to be stored + // just in case. + storeAsOutlier = true + } else if _, ok := isCurrentState[authOrStateEvent.EventID()]; !ok { + // The event is an auth event and isn't a part of the state set. + // We'll send it as an outlier because we need it to be stored + // in case something is referring to it as an auth event. + storeAsOutlier = true + } + + if storeAsOutlier { + ires = append(ires, InputRoomEvent{ + Kind: KindOutlier, + Event: authOrStateEvent.Headered(event.RoomVersion), + AuthEventIDs: authOrStateEvent.AuthEventIDs(), + }) + continue + } + + // If the event isn't an outlier then we'll instead send it as a + // rewrite event, so that it'll form part of the rewritten state. + // These events will go through the membership and latest event + // updaters and we will generate output events, but they will be + // flagged as non-current (i.e. didn't just happen) events. + // Each of these rewrite events includes all of the rewrite events + // that came before in their StateEventIDs. + ires = append(ires, InputRoomEvent{ + Kind: KindRewrite, + Event: authOrStateEvent.Headered(event.RoomVersion), + AuthEventIDs: authOrStateEvent.AuthEventIDs(), + HasState: true, + StateEventIDs: stateIDs, + }) + + // Add the event ID into the StateEventIDs of all subsequent + // rewrite events, and the new event. + stateIDs = append(stateIDs, authOrStateEvent.EventID()) + } + + // Send the final event as a new event, which will generate + // a timeline output event for it. All of the rewrite events + // that came before will be sent as StateEventIDs, forming a + // new clean state before the event. + ires = append(ires, InputRoomEvent{ + Kind: KindNew, + Event: event, + AuthEventIDs: event.AuthEventIDs(), + HasState: true, + StateEventIDs: stateIDs, + }) + + return SendInputRoomEvents(ctx, rsAPI, ires) +} + // SendInputRoomEvents to the roomserver. func SendInputRoomEvents( ctx context.Context, rsAPI RoomserverInternalAPI, ires []InputRoomEvent, diff --git a/roomserver/internal/helpers/auth.go b/roomserver/internal/helpers/auth.go index 060f0a0e..524a5451 100644 --- a/roomserver/internal/helpers/auth.go +++ b/roomserver/internal/helpers/auth.go @@ -36,7 +36,7 @@ func CheckAuthEvents( if err != nil { return nil, err } - // TODO: check for duplicate state keys here. + authStateEntries = types.DeduplicateStateEntries(authStateEntries) // Work out which of the state events we actually need. stateNeeded := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{event.Unwrap()}) diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 6ee679da..daf1afcd 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -86,7 +86,7 @@ func (r *Inputer) processRoomEvent( "event_id": event.EventID(), "type": event.Type(), "room": event.RoomID(), - }).Info("Stored outlier") + }).Debug("Stored outlier") return event.EventID(), nil } @@ -107,6 +107,15 @@ func (r *Inputer) processRoomEvent( } } + if input.Kind == api.KindRewrite { + logrus.WithFields(logrus.Fields{ + "event_id": event.EventID(), + "type": event.Type(), + "room": event.RoomID(), + }).Debug("Stored rewrite") + return event.EventID(), nil + } + if err = r.updateLatestEvents( ctx, // context roomInfo, // room info for the room being updated @@ -114,6 +123,7 @@ func (r *Inputer) processRoomEvent( event, // event input.SendAsServer, // send as server input.TransactionID, // transaction ID + input.HasState, // rewrites state? ); err != nil { return "", fmt.Errorf("r.updateLatestEvents: %w", err) } @@ -167,19 +177,25 @@ func (r *Inputer) calculateAndSetState( // Check that those state events are in the database and store the state. var entries []types.StateEntry if entries, err = r.DB.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil { - return err + return fmt.Errorf("r.DB.StateEntriesForEventIDs: %w", err) } + entries = types.DeduplicateStateEntries(entries) if stateAtEvent.BeforeStateSnapshotNID, err = r.DB.AddState(ctx, roomInfo.RoomNID, nil, entries); err != nil { - return err + return fmt.Errorf("r.DB.AddState: %w", err) } } else { stateAtEvent.Overwrite = false // We haven't been told what the state at the event is so we need to calculate it from the prev_events if stateAtEvent.BeforeStateSnapshotNID, err = roomState.CalculateAndStoreStateBeforeEvent(ctx, event); err != nil { - return err + return fmt.Errorf("roomState.CalculateAndStoreStateBeforeEvent: %w", err) } } - return r.DB.SetState(ctx, stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID) + + err = r.DB.SetState(ctx, stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID) + if err != nil { + return fmt.Errorf("r.DB.SetState: %w", err) + } + return nil } diff --git a/roomserver/internal/input/input_latest_events.go b/roomserver/internal/input/input_latest_events.go index 67a7d8a4..5c2a1de6 100644 --- a/roomserver/internal/input/input_latest_events.go +++ b/roomserver/internal/input/input_latest_events.go @@ -54,6 +54,7 @@ func (r *Inputer) updateLatestEvents( event gomatrixserverlib.Event, sendAsServer string, transactionID *api.TransactionID, + rewritesState bool, ) (err error) { updater, err := r.DB.GetLatestEventsForUpdate(ctx, *roomInfo) if err != nil { @@ -71,6 +72,7 @@ func (r *Inputer) updateLatestEvents( event: event, sendAsServer: sendAsServer, transactionID: transactionID, + rewritesState: rewritesState, } if err = u.doUpdateLatestEvents(); err != nil { @@ -93,6 +95,7 @@ type latestEventsUpdater struct { stateAtEvent types.StateAtEvent event gomatrixserverlib.Event transactionID *api.TransactionID + rewritesState bool // Which server to send this event as. sendAsServer string // The eventID of the event that was processed before this one. @@ -178,7 +181,8 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { return fmt.Errorf("u.api.updateMemberships: %w", err) } - update, err := u.makeOutputNewRoomEvent() + var update *api.OutputEvent + update, err = u.makeOutputNewRoomEvent() if err != nil { return fmt.Errorf("u.makeOutputNewRoomEvent: %w", err) } @@ -305,6 +309,7 @@ func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error) ore := api.OutputNewRoomEvent{ Event: u.event.Headered(u.roomInfo.RoomVersion), + RewritesState: u.rewritesState, LastSentEventID: u.lastEventIDSent, LatestEventIDs: latestEventIDs, TransactionID: u.transactionID, @@ -337,6 +342,11 @@ func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error) return nil, fmt.Errorf("failed to load add_state_events from db: %w", err) } } + // State is rewritten if the input room event HasState and we actually produced a delta on state events. + // Without this check, /get_missing_events which produce events with associated (but not complete) state + // will incorrectly purge the room and set it to no state. TODO: This is likely flakey, as if /gme produced + // a state conflict res which just so happens to include 2+ events we might purge the room state downstream. + ore.RewritesState = len(ore.AddsStateEventIDs) > 1 return &api.OutputEvent{ Type: api.OutputTypeNewRoomEvent, diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go index 3d194227..f76806c7 100644 --- a/roomserver/internal/perform/perform_join.go +++ b/roomserver/internal/perform/perform_join.go @@ -183,33 +183,33 @@ func (r *Joiner) performJoinRoomByID( return "", fmt.Errorf("eb.SetContent: %w", err) } - // First work out if this is in response to an existing invite - // from a federated server. If it is then we avoid the situation - // where we might think we know about a room in the following - // section but don't know the latest state as all of our users - // have left. + // Force a federated join if we aren't in the room and we've been + // given some server names to try joining by. serverInRoom, _ := helpers.IsServerCurrentlyInRoom(ctx, r.DB, r.ServerName, req.RoomIDOrAlias) + forceFederatedJoin := len(req.ServerNames) > 0 && !serverInRoom + + // Force a federated join if we're dealing with a pending invite + // and we aren't in the room. isInvitePending, inviteSender, _, err := helpers.IsInvitePending(ctx, r.DB, req.RoomIDOrAlias, req.UserID) - if err == nil && isInvitePending && !serverInRoom { - // Check if there's an invite pending. + if err == nil && isInvitePending { _, inviterDomain, ierr := gomatrixserverlib.SplitID('@', inviteSender) if ierr != nil { return "", fmt.Errorf("gomatrixserverlib.SplitID: %w", err) } - // Check that the domain isn't ours. If it's local then we don't - // need to do anything as our own copy of the room state will be - // up-to-date. + // If we were invited by someone from another server then we can + // assume they are in the room so we can join via them. if inviterDomain != r.Cfg.Matrix.ServerName { - // Add the server of the person who invited us to the server list, - // as they should be a fairly good bet. req.ServerNames = append(req.ServerNames, inviterDomain) - - // Perform a federated room join. - return req.RoomIDOrAlias, r.performFederatedJoinRoomByID(ctx, req) + forceFederatedJoin = true } } + // If we should do a forced federated join then do that. + if forceFederatedJoin { + return req.RoomIDOrAlias, r.performFederatedJoinRoomByID(ctx, req) + } + // Try to construct an actual join event from the template. // If this succeeds then it is a sign that the room already exists // locally on the homeserver. diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index 786d4f31..5a67a1be 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -1,12 +1,15 @@ package roomserver import ( + "bytes" "context" + "crypto/ed25519" "encoding/json" "fmt" "os" "reflect" "testing" + "time" "github.com/Shopify/sarama" "github.com/matrix-org/dendrite/internal/caching" @@ -80,7 +83,73 @@ func deleteDatabase() { } } -func mustLoadEvents(t *testing.T, ver gomatrixserverlib.RoomVersion, events []json.RawMessage) []gomatrixserverlib.HeaderedEvent { +type fledglingEvent struct { + Type string + StateKey *string + Content interface{} + Sender string + RoomID string +} + +func mustCreateEvents(t *testing.T, roomVer gomatrixserverlib.RoomVersion, events []fledglingEvent) (result []gomatrixserverlib.HeaderedEvent) { + t.Helper() + depth := int64(1) + seed := make([]byte, ed25519.SeedSize) // zero seed + key := ed25519.NewKeyFromSeed(seed) + var prevs []string + roomState := make(map[gomatrixserverlib.StateKeyTuple]string) // state -> event ID + for _, ev := range events { + eb := gomatrixserverlib.EventBuilder{ + Sender: ev.Sender, + Depth: depth, + Type: ev.Type, + StateKey: ev.StateKey, + RoomID: ev.RoomID, + PrevEvents: prevs, + } + err := eb.SetContent(ev.Content) + if err != nil { + t.Fatalf("mustCreateEvent: failed to marshal event content %+v", ev.Content) + } + stateNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(&eb) + if err != nil { + t.Fatalf("mustCreateEvent: failed to work out auth_events : %s", err) + } + var authEvents []string + for _, tuple := range stateNeeded.Tuples() { + eventID := roomState[tuple] + if eventID != "" { + authEvents = append(authEvents, eventID) + } + } + eb.AuthEvents = authEvents + signedEvent, err := eb.Build(time.Now(), testOrigin, "ed25519:test", key, roomVer) + if err != nil { + t.Fatalf("mustCreateEvent: failed to sign event: %s", err) + } + depth++ + prevs = []string{signedEvent.EventID()} + if ev.StateKey != nil { + roomState[gomatrixserverlib.StateKeyTuple{ + EventType: ev.Type, + StateKey: *ev.StateKey, + }] = signedEvent.EventID() + } + result = append(result, signedEvent.Headered(roomVer)) + } + return +} + +func eventsJSON(events []gomatrixserverlib.Event) []json.RawMessage { + result := make([]json.RawMessage, len(events)) + for i := range events { + result[i] = events[i].JSON() + } + return result +} + +func mustLoadRawEvents(t *testing.T, ver gomatrixserverlib.RoomVersion, events []json.RawMessage) []gomatrixserverlib.HeaderedEvent { + t.Helper() hs := make([]gomatrixserverlib.HeaderedEvent, len(events)) for i := range events { e, err := gomatrixserverlib.NewEventFromTrustedJSON(events[i], false, ver) @@ -93,7 +162,8 @@ func mustLoadEvents(t *testing.T, ver gomatrixserverlib.RoomVersion, events []js return hs } -func mustSendEvents(t *testing.T, ver gomatrixserverlib.RoomVersion, events []json.RawMessage) (api.RoomserverInternalAPI, *dummyProducer, []gomatrixserverlib.HeaderedEvent) { +func mustCreateRoomserverAPI(t *testing.T) (api.RoomserverInternalAPI, *dummyProducer) { + t.Helper() cfg := &config.Dendrite{} cfg.Defaults() cfg.Global.ServerName = testOrigin @@ -112,9 +182,14 @@ func mustSendEvents(t *testing.T, ver gomatrixserverlib.RoomVersion, events []js Cfg: cfg, } - rsAPI := NewInternalAPI(base, &test.NopJSONVerifier{}) - hevents := mustLoadEvents(t, ver, events) - if err = api.SendEvents(ctx, rsAPI, hevents, testOrigin, nil); err != nil { + return NewInternalAPI(base, &test.NopJSONVerifier{}), dp +} + +func mustSendEvents(t *testing.T, ver gomatrixserverlib.RoomVersion, events []json.RawMessage) (api.RoomserverInternalAPI, *dummyProducer, []gomatrixserverlib.HeaderedEvent) { + t.Helper() + rsAPI, dp := mustCreateRoomserverAPI(t) + hevents := mustLoadRawEvents(t, ver, events) + if err := api.SendEvents(ctx, rsAPI, hevents, testOrigin, nil); err != nil { t.Errorf("failed to SendEvents: %s", err) } return rsAPI, dp, hevents @@ -170,3 +245,163 @@ func TestOutputRedactedEvent(t *testing.T) { } } } + +// This tests that rewriting state via KindRewrite works correctly. +// This creates a small room with a create/join/name state, then replays it +// with a new room name. We expect the output events to contain the original events, +// followed by a single OutputNewRoomEvent with RewritesState set to true with the +// rewritten state events (with the 2nd room name). +func TestOutputRewritesState(t *testing.T) { + roomID := "!foo:" + string(testOrigin) + alice := "@alice:" + string(testOrigin) + emptyKey := "" + originalEvents := mustCreateEvents(t, gomatrixserverlib.RoomVersionV6, []fledglingEvent{ + { + RoomID: roomID, + Sender: alice, + Content: map[string]interface{}{ + "creator": alice, + "room_version": "6", + }, + StateKey: &emptyKey, + Type: gomatrixserverlib.MRoomCreate, + }, + { + RoomID: roomID, + Sender: alice, + Content: map[string]interface{}{ + "membership": "join", + }, + StateKey: &alice, + Type: gomatrixserverlib.MRoomMember, + }, + { + RoomID: roomID, + Sender: alice, + Content: map[string]interface{}{ + "body": "hello world", + }, + StateKey: nil, + Type: "m.room.message", + }, + { + RoomID: roomID, + Sender: alice, + Content: map[string]interface{}{ + "name": "Room Name", + }, + StateKey: &emptyKey, + Type: "m.room.name", + }, + }) + rewriteEvents := mustCreateEvents(t, gomatrixserverlib.RoomVersionV6, []fledglingEvent{ + { + RoomID: roomID, + Sender: alice, + Content: map[string]interface{}{ + "creator": alice, + }, + StateKey: &emptyKey, + Type: gomatrixserverlib.MRoomCreate, + }, + { + RoomID: roomID, + Sender: alice, + Content: map[string]interface{}{ + "membership": "join", + }, + StateKey: &alice, + Type: gomatrixserverlib.MRoomMember, + }, + { + RoomID: roomID, + Sender: alice, + Content: map[string]interface{}{ + "name": "Room Name 2", + }, + StateKey: &emptyKey, + Type: "m.room.name", + }, + { + RoomID: roomID, + Sender: alice, + Content: map[string]interface{}{ + "body": "hello world 2", + }, + StateKey: nil, + Type: "m.room.message", + }, + }) + deleteDatabase() + rsAPI, producer := mustCreateRoomserverAPI(t) + defer deleteDatabase() + err := api.SendEvents(context.Background(), rsAPI, originalEvents, testOrigin, nil) + if err != nil { + t.Fatalf("failed to send original events: %s", err) + } + // assert we got them produced, this is just a sanity check and isn't the intention of this test + if len(producer.producedMessages) != len(originalEvents) { + t.Fatalf("SendEvents didn't result in same number of produced output events: got %d want %d", len(producer.producedMessages), len(originalEvents)) + } + producer.producedMessages = nil // we aren't actually interested in these events, just the rewrite ones + + var inputEvents []api.InputRoomEvent + // slowly build up the state IDs again, we're basically telling the roomserver what to store as a snapshot + var stateIDs []string + // skip the last event, we'll use this to tie together the rewrite as the KindNew event + for i := 0; i < len(rewriteEvents)-1; i++ { + ev := rewriteEvents[i] + inputEvents = append(inputEvents, api.InputRoomEvent{ + Kind: api.KindRewrite, + Event: ev, + AuthEventIDs: ev.AuthEventIDs(), + HasState: true, + StateEventIDs: stateIDs, + }) + if ev.StateKey() != nil { + stateIDs = append(stateIDs, ev.EventID()) + } + } + lastEv := rewriteEvents[len(rewriteEvents)-1] + inputEvents = append(inputEvents, api.InputRoomEvent{ + Kind: api.KindNew, + Event: lastEv, + AuthEventIDs: lastEv.AuthEventIDs(), + HasState: true, + StateEventIDs: stateIDs, + }) + if err := api.SendInputRoomEvents(context.Background(), rsAPI, inputEvents); err != nil { + t.Fatalf("SendInputRoomEvents returned error for rewrite events: %s", err) + } + // we should just have one output event with the entire state of the room in it + if len(producer.producedMessages) != 1 { + t.Fatalf("Rewritten events got output, want only 1 got %d", len(producer.producedMessages)) + } + outputEvent := producer.producedMessages[0] + if !outputEvent.NewRoomEvent.RewritesState { + t.Errorf("RewritesState flag not set on output event") + } + if !reflect.DeepEqual(stateIDs, outputEvent.NewRoomEvent.AddsStateEventIDs) { + t.Errorf("Output event is missing room state event IDs, got %v want %v", outputEvent.NewRoomEvent.AddsStateEventIDs, stateIDs) + } + if !bytes.Equal(outputEvent.NewRoomEvent.Event.JSON(), lastEv.JSON()) { + t.Errorf( + "Output event isn't the latest KindNew event:\ngot %s\nwant %s", + string(outputEvent.NewRoomEvent.Event.JSON()), + string(lastEv.JSON()), + ) + } + if len(outputEvent.NewRoomEvent.AddStateEvents) != len(stateIDs) { + t.Errorf("Output event is missing room state events themselves, got %d want %d", len(outputEvent.NewRoomEvent.AddStateEvents), len(stateIDs)) + } + // make sure the state got overwritten, check the room name + hasRoomName := false + for _, ev := range outputEvent.NewRoomEvent.AddStateEvents { + if ev.Type() == "m.room.name" { + hasRoomName = string(ev.Content()) == `{"name":"Room Name 2"}` + } + } + if !hasRoomName { + t.Errorf("Output event did not overwrite room state") + } +} diff --git a/roomserver/types/types.go b/roomserver/types/types.go index 60f4b0fd..f5b45763 100644 --- a/roomserver/types/types.go +++ b/roomserver/types/types.go @@ -16,6 +16,8 @@ package types import ( + "sort" + "github.com/matrix-org/gomatrixserverlib" ) @@ -72,6 +74,25 @@ func (a StateEntry) LessThan(b StateEntry) bool { return a.EventNID < b.EventNID } +// Deduplicate takes a set of state entries and ensures that there are no +// duplicate (event type, state key) tuples. If there are then we dedupe +// them, making sure that the latest/highest NIDs are always chosen. +func DeduplicateStateEntries(a []StateEntry) []StateEntry { + if len(a) < 2 { + return a + } + sort.SliceStable(a, func(i, j int) bool { + return a[i].LessThan(a[j]) + }) + for i := 0; i < len(a)-1; i++ { + if a[i].StateKeyTuple == a[i+1].StateKeyTuple { + a = append(a[:i], a[i+1:]...) + i-- + } + } + return a +} + // StateAtEvent is the state before and after a matrix event. type StateAtEvent struct { // Should this state overwrite the latest events and memberships of the room? diff --git a/roomserver/types/types_test.go b/roomserver/types/types_test.go new file mode 100644 index 00000000..b1e84b82 --- /dev/null +++ b/roomserver/types/types_test.go @@ -0,0 +1,26 @@ +package types + +import ( + "testing" +) + +func TestDeduplicateStateEntries(t *testing.T) { + entries := []StateEntry{ + {StateKeyTuple{1, 1}, 1}, + {StateKeyTuple{1, 1}, 2}, + {StateKeyTuple{1, 1}, 3}, + {StateKeyTuple{2, 2}, 4}, + {StateKeyTuple{2, 3}, 5}, + {StateKeyTuple{3, 3}, 6}, + } + expected := []EventNID{3, 4, 5, 6} + entries = DeduplicateStateEntries(entries) + if len(entries) != 4 { + t.Fatalf("Expected 4 entries, got %d entries", len(entries)) + } + for i, v := range entries { + if v.EventNID != expected[i] { + t.Fatalf("Expected position %d to be %d but got %d", i, expected[i], v.EventNID) + } + } +} diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index b6ab9bd5..d8d0a298 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -149,6 +149,12 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent( } } + if msg.RewritesState { + if err = s.db.PurgeRoom(ctx, ev.RoomID()); err != nil { + return fmt.Errorf("s.db.PurgeRoom: %w", err) + } + } + pduPos, err := s.db.WriteEvent( ctx, &ev, diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 807c7f5e..ce7f1c15 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -43,6 +43,9 @@ type Database interface { // Returns an error if there was a problem inserting this event. WriteEvent(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent, addStateEvents []gomatrixserverlib.HeaderedEvent, addStateEventIDs []string, removeStateEventIDs []string, transactionID *api.TransactionID, excludeFromSync bool) (types.StreamPosition, error) + // PurgeRoom completely purges room state from the sync API. This is done when + // receiving an output event that completely resets the state. + PurgeRoom(ctx context.Context, roomID string) error // GetStateEvent returns the Matrix state event of a given type for a given room with a given state key // If no event could be found, returns nil // If there was an issue during the retrieval, returns an error diff --git a/syncapi/storage/postgres/backwards_extremities_table.go b/syncapi/storage/postgres/backwards_extremities_table.go index 71569a10..13056588 100644 --- a/syncapi/storage/postgres/backwards_extremities_table.go +++ b/syncapi/storage/postgres/backwards_extremities_table.go @@ -19,6 +19,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" ) @@ -46,10 +47,14 @@ const selectBackwardExtremitiesForRoomSQL = "" + const deleteBackwardExtremitySQL = "" + "DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2" +const deleteBackwardExtremitiesForRoomSQL = "" + + "DELETE FROM syncapi_backward_extremities WHERE room_id = $1" + type backwardExtremitiesStatements struct { insertBackwardExtremityStmt *sql.Stmt selectBackwardExtremitiesForRoomStmt *sql.Stmt deleteBackwardExtremityStmt *sql.Stmt + deleteBackwardExtremitiesForRoomStmt *sql.Stmt } func NewPostgresBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) { @@ -67,6 +72,9 @@ func NewPostgresBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremiti if s.deleteBackwardExtremityStmt, err = db.Prepare(deleteBackwardExtremitySQL); err != nil { return nil, err } + if s.deleteBackwardExtremitiesForRoomStmt, err = db.Prepare(deleteBackwardExtremitiesForRoomSQL); err != nil { + return nil, err + } return s, nil } @@ -105,3 +113,10 @@ func (s *backwardExtremitiesStatements) DeleteBackwardExtremity( _, err = txn.Stmt(s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID) return } + +func (s *backwardExtremitiesStatements) DeleteBackwardExtremitiesForRoom( + ctx context.Context, txn *sql.Tx, roomID string, +) (err error) { + _, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremitiesForRoomStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/postgres/current_room_state_table.go b/syncapi/storage/postgres/current_room_state_table.go index 5cb7baad..0ca9eed9 100644 --- a/syncapi/storage/postgres/current_room_state_table.go +++ b/syncapi/storage/postgres/current_room_state_table.go @@ -69,6 +69,9 @@ const upsertRoomStateSQL = "" + const deleteRoomStateByEventIDSQL = "" + "DELETE FROM syncapi_current_room_state WHERE event_id = $1" +const DeleteRoomStateForRoomSQL = "" + + "DELETE FROM syncapi_current_room_state WHERE event_id = $1" + const selectRoomIDsWithMembershipSQL = "" + "SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2" @@ -98,6 +101,7 @@ const selectEventsWithEventIDsSQL = "" + type currentRoomStateStatements struct { upsertRoomStateStmt *sql.Stmt deleteRoomStateByEventIDStmt *sql.Stmt + DeleteRoomStateForRoomStmt *sql.Stmt selectRoomIDsWithMembershipStmt *sql.Stmt selectCurrentStateStmt *sql.Stmt selectJoinedUsersStmt *sql.Stmt @@ -117,6 +121,9 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro if s.deleteRoomStateByEventIDStmt, err = db.Prepare(deleteRoomStateByEventIDSQL); err != nil { return nil, err } + if s.DeleteRoomStateForRoomStmt, err = db.Prepare(DeleteRoomStateForRoomSQL); err != nil { + return nil, err + } if s.selectRoomIDsWithMembershipStmt, err = db.Prepare(selectRoomIDsWithMembershipSQL); err != nil { return nil, err } @@ -214,6 +221,14 @@ func (s *currentRoomStateStatements) DeleteRoomStateByEventID( return err } +func (s *currentRoomStateStatements) DeleteRoomStateForRoom( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + stmt := sqlutil.TxStmt(txn, s.DeleteRoomStateForRoomStmt) + _, err := stmt.ExecContext(ctx, roomID) + return err +} + func (s *currentRoomStateStatements) UpsertRoomState( ctx context.Context, txn *sql.Tx, event gomatrixserverlib.HeaderedEvent, membership *string, addedAt types.StreamPosition, diff --git a/syncapi/storage/postgres/output_room_events_table.go b/syncapi/storage/postgres/output_room_events_table.go index 5315de24..4b2101bb 100644 --- a/syncapi/storage/postgres/output_room_events_table.go +++ b/syncapi/storage/postgres/output_room_events_table.go @@ -115,6 +115,9 @@ const selectStateInRangeSQL = "" + " ORDER BY id ASC" + " LIMIT $8" +const deleteEventsForRoomSQL = "" + + "DELETE FROM syncapi_output_room_events WHERE room_id = $1" + type outputRoomEventsStatements struct { insertEventStmt *sql.Stmt selectEventsStmt *sql.Stmt @@ -124,6 +127,7 @@ type outputRoomEventsStatements struct { selectEarlyEventsStmt *sql.Stmt selectStateInRangeStmt *sql.Stmt updateEventJSONStmt *sql.Stmt + deleteEventsForRoomStmt *sql.Stmt } func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) { @@ -156,6 +160,9 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) { if s.updateEventJSONStmt, err = db.Prepare(updateEventJSONSQL); err != nil { return nil, err } + if s.deleteEventsForRoomStmt, err = db.Prepare(deleteEventsForRoomSQL); err != nil { + return nil, err + } return s, nil } @@ -395,6 +402,13 @@ func (s *outputRoomEventsStatements) SelectEvents( return rowsToStreamEvents(rows) } +func (s *outputRoomEventsStatements) DeleteEventsForRoom( + ctx context.Context, txn *sql.Tx, roomID string, +) (err error) { + _, err = sqlutil.TxStmt(txn, s.deleteEventsForRoomStmt).ExecContext(ctx, roomID) + return err +} + func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) { var result []types.StreamEvent for rows.Next() { diff --git a/syncapi/storage/postgres/output_room_events_topology_table.go b/syncapi/storage/postgres/output_room_events_topology_table.go index 1ab3a1dc..cbd20a07 100644 --- a/syncapi/storage/postgres/output_room_events_topology_table.go +++ b/syncapi/storage/postgres/output_room_events_topology_table.go @@ -19,6 +19,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" @@ -71,12 +72,16 @@ const selectMaxPositionInTopologySQL = "" + "SELECT MAX(topological_position) FROM syncapi_output_room_events_topology WHERE room_id=$1" + ") ORDER BY stream_position DESC LIMIT 1" +const deleteTopologyForRoomSQL = "" + + "DELETE FROM syncapi_output_room_events_topology WHERE room_id = $1" + type outputRoomEventsTopologyStatements struct { insertEventInTopologyStmt *sql.Stmt selectEventIDsInRangeASCStmt *sql.Stmt selectEventIDsInRangeDESCStmt *sql.Stmt selectPositionInTopologyStmt *sql.Stmt selectMaxPositionInTopologyStmt *sql.Stmt + deleteTopologyForRoomStmt *sql.Stmt } func NewPostgresTopologyTable(db *sql.DB) (tables.Topology, error) { @@ -100,6 +105,9 @@ func NewPostgresTopologyTable(db *sql.DB) (tables.Topology, error) { if s.selectMaxPositionInTopologyStmt, err = db.Prepare(selectMaxPositionInTopologySQL); err != nil { return nil, err } + if s.deleteTopologyForRoomStmt, err = db.Prepare(deleteTopologyForRoomSQL); err != nil { + return nil, err + } return s, nil } @@ -167,3 +175,10 @@ func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology( err = s.selectMaxPositionInTopologyStmt.QueryRowContext(ctx, roomID).Scan(&pos, &spos) return } + +func (s *outputRoomEventsTopologyStatements) DeleteTopologyForRoom( + ctx context.Context, txn *sql.Tx, roomID string, +) (err error) { + _, err = sqlutil.TxStmt(txn, s.deleteTopologyForRoomStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index 94580adb..05a8768e 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -276,6 +276,29 @@ func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, e return nil } +func (d *Database) PurgeRoom( + ctx context.Context, roomID string, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + // If the event is a create event then we'll delete all of the existing + // data for the room. The only reason that a create event would be replayed + // to us in this way is if we're about to receive the entire room state. + if err := d.CurrentRoomState.DeleteRoomStateForRoom(ctx, txn, roomID); err != nil { + return fmt.Errorf("d.CurrentRoomState.DeleteRoomStateForRoom: %w", err) + } + if err := d.OutputEvents.DeleteEventsForRoom(ctx, txn, roomID); err != nil { + return fmt.Errorf("d.Events.DeleteEventsForRoom: %w", err) + } + if err := d.Topology.DeleteTopologyForRoom(ctx, txn, roomID); err != nil { + return fmt.Errorf("d.Topology.DeleteTopologyForRoom: %w", err) + } + if err := d.BackwardExtremities.DeleteBackwardExtremitiesForRoom(ctx, txn, roomID); err != nil { + return fmt.Errorf("d.BackwardExtremities.DeleteBackwardExtremitiesForRoom: %w", err) + } + return nil + }) +} + func (d *Database) WriteEvent( ctx context.Context, ev *gomatrixserverlib.HeaderedEvent, diff --git a/syncapi/storage/sqlite3/backwards_extremities_table.go b/syncapi/storage/sqlite3/backwards_extremities_table.go index 116c33dc..9a81e8e7 100644 --- a/syncapi/storage/sqlite3/backwards_extremities_table.go +++ b/syncapi/storage/sqlite3/backwards_extremities_table.go @@ -19,6 +19,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" ) @@ -46,11 +47,15 @@ const selectBackwardExtremitiesForRoomSQL = "" + const deleteBackwardExtremitySQL = "" + "DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2" +const deleteBackwardExtremitiesForRoomSQL = "" + + "DELETE FROM syncapi_backward_extremities WHERE room_id = $1" + type backwardExtremitiesStatements struct { db *sql.DB insertBackwardExtremityStmt *sql.Stmt selectBackwardExtremitiesForRoomStmt *sql.Stmt deleteBackwardExtremityStmt *sql.Stmt + deleteBackwardExtremitiesForRoomStmt *sql.Stmt } func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) { @@ -70,6 +75,9 @@ func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities if s.deleteBackwardExtremityStmt, err = db.Prepare(deleteBackwardExtremitySQL); err != nil { return nil, err } + if s.deleteBackwardExtremitiesForRoomStmt, err = db.Prepare(deleteBackwardExtremitiesForRoomSQL); err != nil { + return nil, err + } return s, nil } @@ -108,3 +116,10 @@ func (s *backwardExtremitiesStatements) DeleteBackwardExtremity( _, err = txn.Stmt(s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID) return err } + +func (s *backwardExtremitiesStatements) DeleteBackwardExtremitiesForRoom( + ctx context.Context, txn *sql.Tx, roomID string, +) (err error) { + _, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremitiesForRoomStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go index 6f822c90..13d23be5 100644 --- a/syncapi/storage/sqlite3/current_room_state_table.go +++ b/syncapi/storage/sqlite3/current_room_state_table.go @@ -51,12 +51,15 @@ CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_id_idx ON syncapi_current_room_s const upsertRoomStateSQL = "" + "INSERT INTO syncapi_current_room_state (room_id, event_id, type, sender, contains_url, state_key, headered_event_json, membership, added_at)" + " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" + - " ON CONFLICT (event_id, room_id, type, sender, contains_url)" + + " ON CONFLICT (room_id, type, state_key)" + " DO UPDATE SET event_id = $2, sender=$4, contains_url=$5, headered_event_json = $7, membership = $8, added_at = $9" const deleteRoomStateByEventIDSQL = "" + "DELETE FROM syncapi_current_room_state WHERE event_id = $1" +const DeleteRoomStateForRoomSQL = "" + + "DELETE FROM syncapi_current_room_state WHERE event_id = $1" + const selectRoomIDsWithMembershipSQL = "" + "SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2" @@ -88,6 +91,7 @@ type currentRoomStateStatements struct { streamIDStatements *streamIDStatements upsertRoomStateStmt *sql.Stmt deleteRoomStateByEventIDStmt *sql.Stmt + DeleteRoomStateForRoomStmt *sql.Stmt selectRoomIDsWithMembershipStmt *sql.Stmt selectCurrentStateStmt *sql.Stmt selectJoinedUsersStmt *sql.Stmt @@ -109,6 +113,9 @@ func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (t if s.deleteRoomStateByEventIDStmt, err = db.Prepare(deleteRoomStateByEventIDSQL); err != nil { return nil, err } + if s.DeleteRoomStateForRoomStmt, err = db.Prepare(DeleteRoomStateForRoomSQL); err != nil { + return nil, err + } if s.selectRoomIDsWithMembershipStmt, err = db.Prepare(selectRoomIDsWithMembershipSQL); err != nil { return nil, err } @@ -203,6 +210,14 @@ func (s *currentRoomStateStatements) DeleteRoomStateByEventID( return err } +func (s *currentRoomStateStatements) DeleteRoomStateForRoom( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + stmt := sqlutil.TxStmt(txn, s.DeleteRoomStateForRoomStmt) + _, err := stmt.ExecContext(ctx, roomID) + return err +} + func (s *currentRoomStateStatements) UpsertRoomState( ctx context.Context, txn *sql.Tx, event gomatrixserverlib.HeaderedEvent, membership *string, addedAt types.StreamPosition, diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index f10d0106..587a4072 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -103,6 +103,9 @@ const selectStateInRangeSQL = "" + " ORDER BY id ASC" + " LIMIT $8" // limit +const deleteEventsForRoomSQL = "" + + "DELETE FROM syncapi_output_room_events WHERE room_id = $1" + type outputRoomEventsStatements struct { db *sql.DB streamIDStatements *streamIDStatements @@ -114,6 +117,7 @@ type outputRoomEventsStatements struct { selectEarlyEventsStmt *sql.Stmt selectStateInRangeStmt *sql.Stmt updateEventJSONStmt *sql.Stmt + deleteEventsForRoomStmt *sql.Stmt } func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Events, error) { @@ -149,6 +153,9 @@ func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Even if s.updateEventJSONStmt, err = db.Prepare(updateEventJSONSQL); err != nil { return nil, err } + if s.deleteEventsForRoomStmt, err = db.Prepare(deleteEventsForRoomSQL); err != nil { + return nil, err + } return s, nil } @@ -410,6 +417,13 @@ func (s *outputRoomEventsStatements) SelectEvents( return returnEvents, nil } +func (s *outputRoomEventsStatements) DeleteEventsForRoom( + ctx context.Context, txn *sql.Tx, roomID string, +) (err error) { + _, err = sqlutil.TxStmt(txn, s.deleteEventsForRoomStmt).ExecContext(ctx, roomID) + return err +} + func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) { var result []types.StreamEvent for rows.Next() { diff --git a/syncapi/storage/sqlite3/output_room_events_topology_table.go b/syncapi/storage/sqlite3/output_room_events_topology_table.go index d8c97b7e..d3ba9af6 100644 --- a/syncapi/storage/sqlite3/output_room_events_topology_table.go +++ b/syncapi/storage/sqlite3/output_room_events_topology_table.go @@ -65,6 +65,9 @@ const selectMaxPositionInTopologySQL = "" + "SELECT MAX(topological_position), stream_position FROM syncapi_output_room_events_topology" + " WHERE room_id = $1 ORDER BY stream_position DESC" +const deleteTopologyForRoomSQL = "" + + "DELETE FROM syncapi_output_room_events_topology WHERE room_id = $1" + type outputRoomEventsTopologyStatements struct { db *sql.DB insertEventInTopologyStmt *sql.Stmt @@ -72,6 +75,7 @@ type outputRoomEventsTopologyStatements struct { selectEventIDsInRangeDESCStmt *sql.Stmt selectPositionInTopologyStmt *sql.Stmt selectMaxPositionInTopologyStmt *sql.Stmt + deleteTopologyForRoomStmt *sql.Stmt } func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) { @@ -97,6 +101,9 @@ func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) { if s.selectMaxPositionInTopologyStmt, err = db.Prepare(selectMaxPositionInTopologySQL); err != nil { return nil, err } + if s.deleteTopologyForRoomStmt, err = db.Prepare(deleteTopologyForRoomSQL); err != nil { + return nil, err + } return s, nil } @@ -164,3 +171,10 @@ func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology( err = stmt.QueryRowContext(ctx, roomID).Scan(&pos, &spos) return } + +func (s *outputRoomEventsTopologyStatements) DeleteTopologyForRoom( + ctx context.Context, txn *sql.Tx, roomID string, +) (err error) { + _, err = sqlutil.TxStmt(txn, s.deleteTopologyForRoomStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index 631746c6..da095be5 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -60,6 +60,8 @@ type Events interface { SelectEarlyEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, limit int) ([]types.StreamEvent, error) SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StreamEvent, error) UpdateEventJSON(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error + // DeleteEventsForRoom removes all event information for a room. This should only be done when removing the room entirely. + DeleteEventsForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error) } // Topology keeps track of the depths and stream positions for all events. @@ -77,6 +79,8 @@ type Topology interface { SelectPositionInTopology(ctx context.Context, txn *sql.Tx, eventID string) (depth, spos types.StreamPosition, err error) // SelectMaxPositionInTopology returns the event which has the highest depth, and if there are multiple, the event with the highest stream position. SelectMaxPositionInTopology(ctx context.Context, txn *sql.Tx, roomID string) (depth types.StreamPosition, spos types.StreamPosition, err error) + // DeleteTopologyForRoom removes all topological information for a room. This should only be done when removing the room entirely. + DeleteTopologyForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error) } type CurrentRoomState interface { @@ -84,6 +88,7 @@ type CurrentRoomState interface { SelectEventsWithEventIDs(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StreamEvent, error) UpsertRoomState(ctx context.Context, txn *sql.Tx, event gomatrixserverlib.HeaderedEvent, membership *string, addedAt types.StreamPosition) error DeleteRoomStateByEventID(ctx context.Context, txn *sql.Tx, eventID string) error + DeleteRoomStateForRoom(ctx context.Context, txn *sql.Tx, roomID string) error // SelectCurrentState returns all the current state events for the given room. SelectCurrentState(ctx context.Context, txn *sql.Tx, roomID string, stateFilter *gomatrixserverlib.StateFilter) ([]gomatrixserverlib.HeaderedEvent, error) // SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state. @@ -118,6 +123,8 @@ type BackwardsExtremities interface { SelectBackwardExtremitiesForRoom(ctx context.Context, roomID string) (bwExtrems map[string][]string, err error) // DeleteBackwardExtremity removes a backwards extremity for a room, if one existed. DeleteBackwardExtremity(ctx context.Context, txn *sql.Tx, roomID, knownEventID string) (err error) + // DeleteBackwardExtremitiesFoorRoomID removes all backward extremities for a room. This should only be done when removing the room entirely. + DeleteBackwardExtremitiesForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error) } // SendToDevice tracks send-to-device messages which are sent to individual From ba6c7c4a5c4166b7085343886ab69ef331238ff4 Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Tue, 15 Sep 2020 16:15:34 +0100 Subject: [PATCH 03/12] Disable prometheus to unbreak tests --- roomserver/roomserver_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index 5a67a1be..ef590100 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -172,7 +172,7 @@ func mustCreateRoomserverAPI(t *testing.T) (api.RoomserverInternalAPI, *dummyPro dp := &dummyProducer{ topic: cfg.Global.Kafka.TopicFor(config.TopicOutputRoomEvent), } - cache, err := caching.NewInMemoryLRUCache(true) + cache, err := caching.NewInMemoryLRUCache(false) if err != nil { t.Fatalf("failed to make caches: %s", err) } From 18231f25b437d2f03b3be1e0536fc46d45c8691f Mon Sep 17 00:00:00 2001 From: Kegsay Date: Wed, 16 Sep 2020 13:00:52 +0100 Subject: [PATCH 04/12] Implement rejected events (#1426) * WIP Event rejection * Still send back errors for rejected events Instead, discard them at the federationapi /send layer rather than re-implementing checks at the clientapi/PerformJoin layer. * Implement rejected events Critically, rejected events CAN cause state resolution to happen as it can merge forks in the DAG. This is fine, _provided_ we do not add the rejected event when performing state resolution, which is what this PR does. It also fixes the error handling when NotAllowed happens, as we were checking too early and needlessly handling NotAllowed in more than one place. * Update test to match reality * Modify InputRoomEvents to no longer return an error Errors do not serialise across HTTP boundaries in polylith mode, so instead set fields on the InputRoomEventsResponse. Add `Err()` function to make the API shape basically the same. * Remove redundant returns; linting * Update blacklist --- cmd/roomserver-integration-tests/main.go | 3 +- federationapi/routing/send.go | 9 ++---- federationapi/routing/send_test.go | 14 ++++----- roomserver/api/api.go | 2 +- roomserver/api/api_trace.go | 7 ++--- roomserver/api/input.go | 16 ++++++++++ roomserver/api/wrapper.go | 3 +- roomserver/internal/alias.go | 3 +- roomserver/internal/input/input.go | 8 +++-- roomserver/internal/input/input_events.go | 31 +++++++++++++------ .../internal/perform/perform_backfill.go | 2 +- roomserver/internal/perform/perform_invite.go | 3 +- roomserver/internal/perform/perform_join.go | 3 +- roomserver/internal/perform/perform_leave.go | 3 +- roomserver/internal/query/query.go | 1 + roomserver/inthttp/client.go | 7 +++-- roomserver/inthttp/server.go | 4 +-- roomserver/roomserver_test.go | 8 ----- roomserver/state/state.go | 5 +-- roomserver/storage/interface.go | 1 + roomserver/storage/postgres/events_table.go | 12 ++++--- roomserver/storage/shared/storage.go | 7 +++-- roomserver/storage/sqlite3/events_table.go | 13 +++++--- roomserver/storage/tables/interface.go | 5 ++- roomserver/types/types.go | 3 ++ sytest-blacklist | 11 ++----- sytest-whitelist | 4 ++- 27 files changed, 114 insertions(+), 74 deletions(-) diff --git a/cmd/roomserver-integration-tests/main.go b/cmd/roomserver-integration-tests/main.go index 43574778..41ea6f4d 100644 --- a/cmd/roomserver-integration-tests/main.go +++ b/cmd/roomserver-integration-tests/main.go @@ -215,7 +215,8 @@ func writeToRoomServer(input []string, roomserverURL string) error { if err != nil { return err } - return x.InputRoomEvents(context.Background(), &request, &response) + x.InputRoomEvents(context.Background(), &request, &response) + return response.Err() } // testRoomserver is used to run integration tests against a single roomserver. diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index 9def7c3c..cb7bea6c 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -372,12 +372,9 @@ func (t *txnReq) processEvent(ctx context.Context, e gomatrixserverlib.Event, is return t.processEventWithMissingState(ctx, e, stateResp.RoomVersion, isInboundTxn) } - // Check that the event is allowed by the state at the event. - if err := checkAllowedByState(e, gomatrixserverlib.UnwrapEventHeaders(stateResp.StateEvents)); err != nil { - return err - } - - // pass the event to the roomserver + // pass the event to the roomserver which will do auth checks + // If the event fail auth checks, gmsl.NotAllowed error will be returned which we be silently + // discarded by the caller of this function return api.SendEvents( context.Background(), t.rsAPI, diff --git a/federationapi/routing/send_test.go b/federationapi/routing/send_test.go index a714d07e..4f447f37 100644 --- a/federationapi/routing/send_test.go +++ b/federationapi/routing/send_test.go @@ -89,12 +89,11 @@ func (t *testRoomserverAPI) InputRoomEvents( ctx context.Context, request *api.InputRoomEventsRequest, response *api.InputRoomEventsResponse, -) error { +) { t.inputRoomEvents = append(t.inputRoomEvents, request.InputRoomEvents...) for _, ire := range request.InputRoomEvents { fmt.Println("InputRoomEvents: ", ire.Event.EventID()) } - return nil } func (t *testRoomserverAPI) PerformInvite( @@ -461,7 +460,8 @@ func TestBasicTransaction(t *testing.T) { assertInputRoomEvents(t, rsAPI.inputRoomEvents, []gomatrixserverlib.HeaderedEvent{testEvents[len(testEvents)-1]}) } -// The purpose of this test is to check that if the event received fails auth checks the transaction is failed. +// The purpose of this test is to check that if the event received fails auth checks the event is still sent to the roomserver +// as it does the auth check. func TestTransactionFailAuthChecks(t *testing.T) { rsAPI := &testRoomserverAPI{ queryStateAfterEvents: func(req *api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse { @@ -479,11 +479,9 @@ func TestTransactionFailAuthChecks(t *testing.T) { testData[len(testData)-1], // a message event } txn := mustCreateTransaction(rsAPI, &txnFedClient{}, pdus) - mustProcessTransaction(t, txn, []string{ - // expect the event to have an error - testEvents[len(testEvents)-1].EventID(), - }) - assertInputRoomEvents(t, rsAPI.inputRoomEvents, nil) // expect no messages to be sent to the roomserver + mustProcessTransaction(t, txn, []string{}) + // expect message to be sent to the roomserver + assertInputRoomEvents(t, rsAPI.inputRoomEvents, []gomatrixserverlib.HeaderedEvent{testEvents[len(testEvents)-1]}) } // The purpose of this test is to make sure that when an event is received for which we do not know the prev_events, diff --git a/roomserver/api/api.go b/roomserver/api/api.go index eecefe32..2495157a 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -16,7 +16,7 @@ type RoomserverInternalAPI interface { ctx context.Context, request *InputRoomEventsRequest, response *InputRoomEventsResponse, - ) error + ) PerformInvite( ctx context.Context, diff --git a/roomserver/api/api_trace.go b/roomserver/api/api_trace.go index 64330930..b7accb9a 100644 --- a/roomserver/api/api_trace.go +++ b/roomserver/api/api_trace.go @@ -23,10 +23,9 @@ func (t *RoomserverInternalAPITrace) InputRoomEvents( ctx context.Context, req *InputRoomEventsRequest, res *InputRoomEventsResponse, -) error { - err := t.Impl.InputRoomEvents(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("InputRoomEvents req=%+v res=%+v", js(req), js(res)) - return err +) { + t.Impl.InputRoomEvents(ctx, req, res) + util.GetLogger(ctx).Infof("InputRoomEvents req=%+v res=%+v", js(req), js(res)) } func (t *RoomserverInternalAPITrace) PerformInvite( diff --git a/roomserver/api/input.go b/roomserver/api/input.go index 651c0e9f..862a6fa1 100644 --- a/roomserver/api/input.go +++ b/roomserver/api/input.go @@ -16,6 +16,8 @@ package api import ( + "fmt" + "github.com/matrix-org/gomatrixserverlib" ) @@ -87,4 +89,18 @@ type InputRoomEventsRequest struct { // InputRoomEventsResponse is a response to InputRoomEvents type InputRoomEventsResponse struct { + ErrMsg string // set if there was any error + NotAllowed bool // true if an event in the input was not allowed. +} + +func (r *InputRoomEventsResponse) Err() error { + if r.ErrMsg == "" { + return nil + } + if r.NotAllowed { + return &gomatrixserverlib.NotAllowed{ + Message: r.ErrMsg, + } + } + return fmt.Errorf("InputRoomEventsResponse: %s", r.ErrMsg) } diff --git a/roomserver/api/wrapper.go b/roomserver/api/wrapper.go index e5339311..cc048ddd 100644 --- a/roomserver/api/wrapper.go +++ b/roomserver/api/wrapper.go @@ -187,7 +187,8 @@ func SendInputRoomEvents( ) error { request := InputRoomEventsRequest{InputRoomEvents: ires} var response InputRoomEventsResponse - return rsAPI.InputRoomEvents(ctx, &request, &response) + rsAPI.InputRoomEvents(ctx, &request, &response) + return response.Err() } // SendInvite event to the roomserver. diff --git a/roomserver/internal/alias.go b/roomserver/internal/alias.go index d576a817..3e023d2a 100644 --- a/roomserver/internal/alias.go +++ b/roomserver/internal/alias.go @@ -271,5 +271,6 @@ func (r *RoomserverInternalAPI) sendUpdatedAliasesEvent( var inputRes api.InputRoomEventsResponse // Send the request - return r.InputRoomEvents(ctx, &inputReq, &inputRes) + r.InputRoomEvents(ctx, &inputReq, &inputRes) + return inputRes.Err() } diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index 51d20ad3..d340ac21 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -110,7 +110,7 @@ func (r *Inputer) InputRoomEvents( ctx context.Context, request *api.InputRoomEventsRequest, response *api.InputRoomEventsResponse, -) error { +) { // Create a wait group. Each task that we dispatch will call Done on // this wait group so that we know when all of our events have been // processed. @@ -156,8 +156,10 @@ func (r *Inputer) InputRoomEvents( // that back to the caller. for _, task := range tasks { if task.err != nil { - return task.err + response.ErrMsg = task.err.Error() + _, rejected := task.err.(*gomatrixserverlib.NotAllowed) + response.NotAllowed = rejected + return } } - return nil } diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index daf1afcd..0558cd76 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -46,10 +46,11 @@ func (r *Inputer) processRoomEvent( // Check that the event passes authentication checks and work out // the numeric IDs for the auth events. - authEventNIDs, err := helpers.CheckAuthEvents(ctx, r.DB, headered, input.AuthEventIDs) - if err != nil { - logrus.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", input.AuthEventIDs).Error("processRoomEvent.checkAuthEvents failed for event") - return + isRejected := false + authEventNIDs, rejectionErr := helpers.CheckAuthEvents(ctx, r.DB, headered, input.AuthEventIDs) + if rejectionErr != nil { + logrus.WithError(rejectionErr).WithField("event_id", event.EventID()).WithField("auth_event_ids", input.AuthEventIDs).Error("processRoomEvent.checkAuthEvents failed for event, rejecting event") + isRejected = true } // If we don't have a transaction ID then get one. @@ -65,12 +66,13 @@ func (r *Inputer) processRoomEvent( } // Store the event. - _, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, input.TransactionID, authEventNIDs) + _, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, input.TransactionID, authEventNIDs, isRejected) if err != nil { return "", fmt.Errorf("r.DB.StoreEvent: %w", err) } + // if storing this event results in it being redacted then do so. - if redactedEventID == event.EventID() { + if !isRejected && redactedEventID == event.EventID() { r, rerr := eventutil.RedactEvent(redactionEvent, &event) if rerr != nil { return "", fmt.Errorf("eventutil.RedactEvent: %w", rerr) @@ -101,12 +103,22 @@ func (r *Inputer) processRoomEvent( if stateAtEvent.BeforeStateSnapshotNID == 0 { // We haven't calculated a state for this event yet. // Lets calculate one. - err = r.calculateAndSetState(ctx, input, *roomInfo, &stateAtEvent, event) + err = r.calculateAndSetState(ctx, input, *roomInfo, &stateAtEvent, event, isRejected) if err != nil { return "", fmt.Errorf("r.calculateAndSetState: %w", err) } } + // We stop here if the event is rejected: We've stored it but won't update forward extremities or notify anyone about it. + if isRejected { + logrus.WithFields(logrus.Fields{ + "event_id": event.EventID(), + "type": event.Type(), + "room": event.RoomID(), + }).Debug("Stored rejected event") + return event.EventID(), rejectionErr + } + if input.Kind == api.KindRewrite { logrus.WithFields(logrus.Fields{ "event_id": event.EventID(), @@ -157,11 +169,12 @@ func (r *Inputer) calculateAndSetState( roomInfo types.RoomInfo, stateAtEvent *types.StateAtEvent, event gomatrixserverlib.Event, + isRejected bool, ) error { var err error roomState := state.NewStateResolution(r.DB, roomInfo) - if input.HasState { + if input.HasState && !isRejected { // Check here if we think we're in the room already. stateAtEvent.Overwrite = true var joinEventNIDs []types.EventNID @@ -188,7 +201,7 @@ func (r *Inputer) calculateAndSetState( stateAtEvent.Overwrite = false // We haven't been told what the state at the event is so we need to calculate it from the prev_events - if stateAtEvent.BeforeStateSnapshotNID, err = roomState.CalculateAndStoreStateBeforeEvent(ctx, event); err != nil { + if stateAtEvent.BeforeStateSnapshotNID, err = roomState.CalculateAndStoreStateBeforeEvent(ctx, event, isRejected); err != nil { return fmt.Errorf("roomState.CalculateAndStoreStateBeforeEvent: %w", err) } } diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go index 668c8078..eb1aa99b 100644 --- a/roomserver/internal/perform/perform_backfill.go +++ b/roomserver/internal/perform/perform_backfill.go @@ -535,7 +535,7 @@ func persistEvents(ctx context.Context, db storage.Database, events []gomatrixse var stateAtEvent types.StateAtEvent var redactedEventID string var redactionEvent *gomatrixserverlib.Event - roomNID, stateAtEvent, redactionEvent, redactedEventID, err = db.StoreEvent(ctx, ev.Unwrap(), nil, authNids) + roomNID, stateAtEvent, redactionEvent, redactedEventID, err = db.StoreEvent(ctx, ev.Unwrap(), nil, authNids, false) if err != nil { logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to persist event") continue diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index e06ad062..d6a64e7e 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -183,7 +183,8 @@ func (r *Inviter) PerformInvite( }, } inputRes := &api.InputRoomEventsResponse{} - if err = r.Inputer.InputRoomEvents(context.Background(), inputReq, inputRes); err != nil { + r.Inputer.InputRoomEvents(context.Background(), inputReq, inputRes) + if err = inputRes.Err(); err != nil { return nil, fmt.Errorf("r.InputRoomEvents: %w", err) } } else { diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go index f76806c7..e9aebb83 100644 --- a/roomserver/internal/perform/perform_join.go +++ b/roomserver/internal/perform/perform_join.go @@ -247,7 +247,8 @@ func (r *Joiner) performJoinRoomByID( }, } inputRes := api.InputRoomEventsResponse{} - if err = r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes); err != nil { + r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes) + if err = inputRes.Err(); err != nil { var notAllowed *gomatrixserverlib.NotAllowed if errors.As(err, ¬Allowed) { return "", &api.PerformError{ diff --git a/roomserver/internal/perform/perform_leave.go b/roomserver/internal/perform/perform_leave.go index aaa3b5b1..6aaf1bf3 100644 --- a/roomserver/internal/perform/perform_leave.go +++ b/roomserver/internal/perform/perform_leave.go @@ -139,7 +139,8 @@ func (r *Leaver) performLeaveRoomByID( }, } inputRes := api.InputRoomEventsResponse{} - if err = r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes); err != nil { + r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes) + if err = inputRes.Err(); err != nil { return nil, fmt.Errorf("r.InputRoomEvents: %w", err) } diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index b34ae770..fb981447 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -70,6 +70,7 @@ func (r *Queryer) QueryStateAfterEvents( if err != nil { switch err.(type) { case types.MissingEventError: + util.GetLogger(ctx).Errorf("QueryStateAfterEvents: MissingEventError: %s", err) return nil default: return err diff --git a/roomserver/inthttp/client.go b/roomserver/inthttp/client.go index 1ff1fc82..f2510c75 100644 --- a/roomserver/inthttp/client.go +++ b/roomserver/inthttp/client.go @@ -149,12 +149,15 @@ func (h *httpRoomserverInternalAPI) InputRoomEvents( ctx context.Context, request *api.InputRoomEventsRequest, response *api.InputRoomEventsResponse, -) error { +) { span, ctx := opentracing.StartSpanFromContext(ctx, "InputRoomEvents") defer span.Finish() apiURL := h.roomserverURL + RoomserverInputRoomEventsPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) + if err != nil { + response.ErrMsg = err.Error() + } } func (h *httpRoomserverInternalAPI) PerformInvite( diff --git a/roomserver/inthttp/server.go b/roomserver/inthttp/server.go index 5816d4d8..8ffa9cf9 100644 --- a/roomserver/inthttp/server.go +++ b/roomserver/inthttp/server.go @@ -20,9 +20,7 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) { if err := json.NewDecoder(req.Body).Decode(&request); err != nil { return util.MessageResponse(http.StatusBadRequest, err.Error()) } - if err := r.InputRoomEvents(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } + r.InputRoomEvents(req.Context(), &request, &response) return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index ef590100..912c5852 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -140,14 +140,6 @@ func mustCreateEvents(t *testing.T, roomVer gomatrixserverlib.RoomVersion, event return } -func eventsJSON(events []gomatrixserverlib.Event) []json.RawMessage { - result := make([]json.RawMessage, len(events)) - for i := range events { - result[i] = events[i].JSON() - } - return result -} - func mustLoadRawEvents(t *testing.T, ver gomatrixserverlib.RoomVersion, events []json.RawMessage) []gomatrixserverlib.HeaderedEvent { t.Helper() hs := make([]gomatrixserverlib.HeaderedEvent, len(events)) diff --git a/roomserver/state/state.go b/roomserver/state/state.go index 37e6807a..9ee6f40d 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -159,7 +159,7 @@ func (v StateResolution) LoadCombinedStateAfterEvents( } fullState = append(fullState, entries...) } - if prevState.IsStateEvent() { + if prevState.IsStateEvent() && !prevState.IsRejected { // If the prev event was a state event then add an entry for the event itself // so that we get the state after the event rather than the state before. fullState = append(fullState, prevState.StateEntry) @@ -523,6 +523,7 @@ func init() { func (v StateResolution) CalculateAndStoreStateBeforeEvent( ctx context.Context, event gomatrixserverlib.Event, + isRejected bool, ) (types.StateSnapshotNID, error) { // Load the state at the prev events. prevEventRefs := event.PrevEvents() @@ -561,7 +562,7 @@ func (v StateResolution) CalculateAndStoreStateAfterEvents( if len(prevStates) == 1 { prevState := prevStates[0] - if prevState.EventStateKeyNID == 0 { + if prevState.EventStateKeyNID == 0 || prevState.IsRejected { // 3) None of the previous events were state events and they all // have the same state, so this event has exactly the same state // as the previous events. diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index be724da6..10a380e8 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -70,6 +70,7 @@ type Database interface { // Stores a matrix room event in the database. Returns the room NID, the state snapshot and the redacted event ID if any, or an error. StoreEvent( ctx context.Context, event gomatrixserverlib.Event, txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID, + isRejected bool, ) (types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) // Look up the state entries for a list of string event IDs // Returns an error if the there is an error talking to the database diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go index e66efb09..c8eb8e2d 100644 --- a/roomserver/storage/postgres/events_table.go +++ b/roomserver/storage/postgres/events_table.go @@ -65,13 +65,14 @@ CREATE TABLE IF NOT EXISTS roomserver_events ( -- Needed for setting reference hashes when sending new events. reference_sha256 BYTEA NOT NULL, -- A list of numeric IDs for events that can authenticate this event. - auth_event_nids BIGINT[] NOT NULL + auth_event_nids BIGINT[] NOT NULL, + is_rejected BOOLEAN NOT NULL DEFAULT FALSE ); ` const insertEventSQL = "" + - "INSERT INTO roomserver_events (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids, depth)" + - " VALUES ($1, $2, $3, $4, $5, $6, $7)" + + "INSERT INTO roomserver_events (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids, depth, is_rejected)" + + " VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" + " ON CONFLICT ON CONSTRAINT roomserver_event_id_unique" + " DO NOTHING" + " RETURNING event_nid, state_snapshot_nid" @@ -88,7 +89,7 @@ const bulkSelectStateEventByIDSQL = "" + " ORDER BY event_type_nid, event_state_key_nid ASC" const bulkSelectStateAtEventByIDSQL = "" + - "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid FROM roomserver_events" + + "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, is_rejected FROM roomserver_events" + " WHERE event_id = ANY($1)" const updateEventStateSQL = "" + @@ -174,12 +175,14 @@ func (s *eventStatements) InsertEvent( referenceSHA256 []byte, authEventNIDs []types.EventNID, depth int64, + isRejected bool, ) (types.EventNID, types.StateSnapshotNID, error) { var eventNID int64 var stateNID int64 err := s.insertEventStmt.QueryRowContext( ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth, + isRejected, ).Scan(&eventNID, &stateNID) return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err } @@ -255,6 +258,7 @@ func (s *eventStatements) BulkSelectStateAtEventByID( &result.EventStateKeyNID, &result.EventNID, &result.BeforeStateSnapshotNID, + &result.IsRejected, ); err != nil { return nil, err } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 262b0f2f..e710b99b 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -382,7 +382,7 @@ func (d *Database) GetLatestEventsForUpdate( // nolint:gocyclo func (d *Database) StoreEvent( ctx context.Context, event gomatrixserverlib.Event, - txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID, + txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID, isRejected bool, ) (types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) { var ( roomNID types.RoomNID @@ -446,6 +446,7 @@ func (d *Database) StoreEvent( event.EventReference().EventSHA256, authEventNIDs, event.Depth(), + isRejected, ); err != nil { if err == sql.ErrNoRows { // We've already inserted the event so select the numeric event ID @@ -459,7 +460,9 @@ func (d *Database) StoreEvent( if err = d.EventJSONTable.InsertEventJSON(ctx, txn, eventNID, event.JSON()); err != nil { return fmt.Errorf("d.EventJSONTable.InsertEventJSON: %w", err) } - redactionEvent, redactedEventID, err = d.handleRedactions(ctx, txn, eventNID, event) + if !isRejected { // ignore rejected redaction events + redactionEvent, redactedEventID, err = d.handleRedactions(ctx, txn, eventNID, event) + } return nil }) if err != nil { diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index a866c85d..773e9ade 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -41,13 +41,14 @@ const eventsSchema = ` depth INTEGER NOT NULL, event_id TEXT NOT NULL UNIQUE, reference_sha256 BLOB NOT NULL, - auth_event_nids TEXT NOT NULL DEFAULT '[]' + auth_event_nids TEXT NOT NULL DEFAULT '[]', + is_rejected BOOLEAN NOT NULL DEFAULT FALSE ); ` const insertEventSQL = ` - INSERT INTO roomserver_events (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids, depth) - VALUES ($1, $2, $3, $4, $5, $6, $7) + INSERT INTO roomserver_events (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids, depth, is_rejected) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) ON CONFLICT DO NOTHING; ` @@ -63,7 +64,7 @@ const bulkSelectStateEventByIDSQL = "" + " ORDER BY event_type_nid, event_state_key_nid ASC" const bulkSelectStateAtEventByIDSQL = "" + - "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid FROM roomserver_events" + + "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, is_rejected FROM roomserver_events" + " WHERE event_id IN ($1)" const updateEventStateSQL = "" + @@ -150,13 +151,14 @@ func (s *eventStatements) InsertEvent( referenceSHA256 []byte, authEventNIDs []types.EventNID, depth int64, + isRejected bool, ) (types.EventNID, types.StateSnapshotNID, error) { // attempt to insert: the last_row_id is the event NID var eventNID int64 insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt) result, err := insertStmt.ExecContext( ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), - eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth, + eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth, isRejected, ) if err != nil { return 0, 0, err @@ -261,6 +263,7 @@ func (s *eventStatements) BulkSelectStateAtEventByID( &result.EventStateKeyNID, &result.EventNID, &result.BeforeStateSnapshotNID, + &result.IsRejected, ); err != nil { return nil, err } diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index adb06212..eba878ba 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -34,7 +34,10 @@ type EventStateKeys interface { } type Events interface { - InsertEvent(c context.Context, txn *sql.Tx, i types.RoomNID, j types.EventTypeNID, k types.EventStateKeyNID, eventID string, referenceSHA256 []byte, authEventNIDs []types.EventNID, depth int64) (types.EventNID, types.StateSnapshotNID, error) + InsertEvent( + ctx context.Context, txn *sql.Tx, i types.RoomNID, j types.EventTypeNID, k types.EventStateKeyNID, eventID string, + referenceSHA256 []byte, authEventNIDs []types.EventNID, depth int64, isRejected bool, + ) (types.EventNID, types.StateSnapshotNID, error) SelectEvent(ctx context.Context, txn *sql.Tx, eventID string) (types.EventNID, types.StateSnapshotNID, error) // bulkSelectStateEventByID lookups a list of state events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError diff --git a/roomserver/types/types.go b/roomserver/types/types.go index f5b45763..c0fcef65 100644 --- a/roomserver/types/types.go +++ b/roomserver/types/types.go @@ -101,6 +101,9 @@ type StateAtEvent struct { Overwrite bool // The state before the event. BeforeStateSnapshotNID StateSnapshotNID + // True if this StateEntry is rejected. State resolution should then treat this + // StateEntry as being a message event (not a state event). + IsRejected bool // The state entry for the event itself, allows us to calculate the state after the event. StateEntry } diff --git a/sytest-blacklist b/sytest-blacklist index 705c9ff4..246e6830 100644 --- a/sytest-blacklist +++ b/sytest-blacklist @@ -40,11 +40,6 @@ Ignore invite in incremental sync New room members see their own join event Existing members see new members' join events -# Blacklisted because the federation work for these hasn't been finished yet. -Can recv device messages over federation -Device messages over federation wake up /sync -Wildcard device messages over federation wake up /sync - # See https://github.com/matrix-org/sytest/pull/901 Remote invited user can see room metadata @@ -56,8 +51,8 @@ Inbound federation accepts a second soft-failed event # Caused by https://github.com/matrix-org/sytest/pull/911 Outbound federation requests missing prev_events and then asks for /state_ids and resolves the state -# We don't implement device lists yet -Device list doesn't change if remote server is down - # We don't implement lazy membership loading yet. The only membership state included in a gapped incremental sync is for senders in the timeline + +# flakey since implementing rejected events +Inbound federation correctly soft fails events \ No newline at end of file diff --git a/sytest-whitelist b/sytest-whitelist index 0adeaee6..91516428 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -470,4 +470,6 @@ We can't peek into rooms with shared history_visibility We can't peek into rooms with invited history_visibility We can't peek into rooms with joined history_visibility Local users can peek by room alias -Peeked rooms only turn up in the sync for the device who peeked them \ No newline at end of file +Peeked rooms only turn up in the sync for the device who peeked them +Room state at a rejected message event is the same as its predecessor +Room state at a rejected state event is the same as its predecessor \ No newline at end of file From 880b16449087cdadfa537e6ced4d1bb4ca703f24 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Mon, 21 Sep 2020 13:30:37 +0100 Subject: [PATCH 05/12] Refactor backoff again (#1431) * Tweak backoffs * Refactor backoff some more, remove BackoffIfRequired as it adds unnecessary complexity * Ignore 404s --- federationsender/queue/destinationqueue.go | 13 ++- federationsender/statistics/statistics.go | 90 ++++++++----------- .../statistics/statistics_test.go | 26 +++--- 3 files changed, 59 insertions(+), 70 deletions(-) diff --git a/federationsender/queue/destinationqueue.go b/federationsender/queue/destinationqueue.go index e9e117a7..57612908 100644 --- a/federationsender/queue/destinationqueue.go +++ b/federationsender/queue/destinationqueue.go @@ -231,13 +231,24 @@ func (oq *destinationQueue) backgroundSend() { // If we are backing off this server then wait for the // backoff duration to complete first, or until explicitly // told to retry. - if _, giveUp := oq.statistics.BackoffIfRequired(oq.backingOff, oq.interruptBackoff); giveUp { + until, blacklisted := oq.statistics.BackoffInfo() + if blacklisted { // It's been suggested that we should give up because the backoff // has exceeded a maximum allowable value. Clean up the in-memory // buffers at this point. The PDU clean-up is already on a defer. log.Warnf("Blacklisting %q due to exceeding backoff threshold", oq.destination) return } + if until != nil { + // We haven't backed off yet, so wait for the suggested amount of + // time. + duration := time.Until(*until) + log.Warnf("Backing off %q for %s", oq.destination, duration) + select { + case <-time.After(duration): + case <-oq.interruptBackoff: + } + } // If we have pending PDUs or EDUs then construct a transaction. if pendingPDUs || pendingEDUs { diff --git a/federationsender/statistics/statistics.go b/federationsender/statistics/statistics.go index 03ef64e9..b5fe7513 100644 --- a/federationsender/statistics/statistics.go +++ b/federationsender/statistics/statistics.go @@ -44,6 +44,7 @@ func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerS server = &ServerStatistics{ statistics: s, serverName: serverName, + interrupt: make(chan struct{}), } s.servers[serverName] = server s.mutex.Unlock() @@ -68,6 +69,7 @@ type ServerStatistics struct { backoffStarted atomic.Bool // is the backoff started backoffUntil atomic.Value // time.Time until this backoff interval ends backoffCount atomic.Uint32 // number of times BackoffDuration has been called + interrupt chan struct{} // interrupts the backoff goroutine successCounter atomic.Uint32 // how many times have we succeeded? } @@ -76,15 +78,24 @@ func (s *ServerStatistics) duration(count uint32) time.Duration { return time.Second * time.Duration(math.Exp2(float64(count))) } +// cancel will interrupt the currently active backoff. +func (s *ServerStatistics) cancel() { + s.blacklisted.Store(false) + s.backoffUntil.Store(time.Time{}) + select { + case s.interrupt <- struct{}{}: + default: + } +} + // Success updates the server statistics with a new successful // attempt, which increases the sent counter and resets the idle and // failure counters. If a host was blacklisted at this point then // we will unblacklist it. func (s *ServerStatistics) Success() { - s.successCounter.Add(1) - s.backoffStarted.Store(false) + s.cancel() + s.successCounter.Inc() s.backoffCount.Store(0) - s.blacklisted.Store(false) if s.statistics.DB != nil { if err := s.statistics.DB.RemoveServerFromBlacklist(s.serverName); err != nil { logrus.WithError(err).Errorf("Failed to remove %q from blacklist", s.serverName) @@ -99,10 +110,30 @@ func (s *ServerStatistics) Success() { // whether we have blacklisted and therefore to give up. func (s *ServerStatistics) Failure() (time.Time, bool) { // If we aren't already backing off, this call will start - // a new backoff period. Reset the counter to 0 so that - // we backoff only for short periods of time to start with. + // a new backoff period. Increase the failure counter and + // start a goroutine which will wait out the backoff and + // unset the backoffStarted flag when done. if s.backoffStarted.CAS(false, true) { - s.backoffCount.Store(0) + if s.backoffCount.Inc() >= s.statistics.FailuresUntilBlacklist { + s.blacklisted.Store(true) + if s.statistics.DB != nil { + if err := s.statistics.DB.AddServerToBlacklist(s.serverName); err != nil { + logrus.WithError(err).Errorf("Failed to add %q to blacklist", s.serverName) + } + } + return time.Time{}, true + } + + go func() { + until, ok := s.backoffUntil.Load().(time.Time) + if ok { + select { + case <-time.After(time.Until(until)): + case <-s.interrupt: + } + } + s.backoffStarted.Store(false) + }() } // Check if we have blacklisted this node. @@ -136,53 +167,6 @@ func (s *ServerStatistics) BackoffInfo() (*time.Time, bool) { return nil, s.blacklisted.Load() } -// BackoffIfRequired will block for as long as the current -// backoff requires, if needed. Otherwise it will do nothing. -// Returns the amount of time to backoff for and whether to give up or not. -func (s *ServerStatistics) BackoffIfRequired(backingOff atomic.Bool, interrupt <-chan bool) (time.Duration, bool) { - if started := s.backoffStarted.Load(); !started { - return 0, false - } - - // Work out if we should be blacklisting at this point. - count := s.backoffCount.Inc() - if count >= s.statistics.FailuresUntilBlacklist { - // We've exceeded the maximum amount of times we're willing - // to back off, which is probably in the region of hours by - // now. Mark the host as blacklisted and tell the caller to - // give up. - s.blacklisted.Store(true) - if s.statistics.DB != nil { - if err := s.statistics.DB.AddServerToBlacklist(s.serverName); err != nil { - logrus.WithError(err).Errorf("Failed to add %q to blacklist", s.serverName) - } - } - return 0, true - } - - // Work out when we should wait until. - duration := s.duration(count) - until := time.Now().Add(duration) - s.backoffUntil.Store(until) - - // Notify the destination queue that we're backing off now. - backingOff.Store(true) - defer backingOff.Store(false) - - // Work out how long we should be backing off for. - logrus.Warnf("Backing off %q for %s", s.serverName, duration) - - // Wait for either an interruption or for the backoff to - // complete. - select { - case <-interrupt: - logrus.Debugf("Interrupting backoff for %q", s.serverName) - case <-time.After(duration): - } - - return duration, false -} - // Blacklisted returns true if the server is blacklisted and false // otherwise. func (s *ServerStatistics) Blacklisted() bool { diff --git a/federationsender/statistics/statistics_test.go b/federationsender/statistics/statistics_test.go index 7e083de6..225350b6 100644 --- a/federationsender/statistics/statistics_test.go +++ b/federationsender/statistics/statistics_test.go @@ -4,8 +4,6 @@ import ( "math" "testing" "time" - - "go.uber.org/atomic" ) func TestBackoff(t *testing.T) { @@ -27,34 +25,30 @@ func TestBackoff(t *testing.T) { server.Failure() t.Logf("Backoff counter: %d", server.backoffCount.Load()) - backingOff := atomic.Bool{} // Now we're going to simulate backing off a few times to see // what happens. for i := uint32(1); i <= 10; i++ { - // Interrupt the backoff - it doesn't really matter if it - // completes but we will find out how long the backoff should - // have been. - interrupt := make(chan bool, 1) - close(interrupt) - - // Get the duration. - duration, blacklist := server.BackoffIfRequired(backingOff, interrupt) - // Register another failure for good measure. This should have no // side effects since a backoff is already in progress. If it does // then we'll fail. until, blacklisted := server.Failure() - if time.Until(until) > duration { - t.Fatal("Failure produced unexpected side effect when it shouldn't have") - } + + // Get the duration. + _, blacklist := server.BackoffInfo() + duration := time.Until(until).Round(time.Second) + + // Unset the backoff, or otherwise our next call will think that + // there's a backoff in progress and return the same result. + server.cancel() + server.backoffStarted.Store(false) // Check if we should be blacklisted by now. if i >= stats.FailuresUntilBlacklist { if !blacklist { t.Fatalf("Backoff %d should have resulted in blacklist but didn't", i) } else if blacklist != blacklisted { - t.Fatalf("BackoffIfRequired and Failure returned different blacklist values") + t.Fatalf("BackoffInfo and Failure returned different blacklist values") } else { t.Logf("Backoff %d is blacklisted as expected", i) continue From a06c18bb562749db1a175a6295e995ec877f1c92 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Mon, 21 Sep 2020 14:55:46 +0100 Subject: [PATCH 06/12] Soft-fail (#1364) * Initial work on soft-fail * Fix state block retrieval * Copy-pasta QueryLatestEventsAndState code * Fix state lookup * Clean up * Fix up failing sytest * Linting * Update previous events SQLite insert query * Update SQLite InsertPreviousEvent properly * Hopefully fix the event references updates Co-authored-by: Kegan Dougal --- roomserver/api/wrapper.go | 10 +-- roomserver/internal/helpers/auth.go | 65 +++++++++++++++++++ roomserver/internal/input/input_events.go | 25 +++++-- .../storage/sqlite3/previous_events_table.go | 41 ++++++++++-- sytest-blacklist | 5 +- sytest-whitelist | 4 +- 6 files changed, 128 insertions(+), 22 deletions(-) diff --git a/roomserver/api/wrapper.go b/roomserver/api/wrapper.go index cc048ddd..24949fc6 100644 --- a/roomserver/api/wrapper.go +++ b/roomserver/api/wrapper.go @@ -122,15 +122,7 @@ func SendEventWithRewrite( // We will handle an event as if it's an outlier if one of the // following conditions is true: storeAsOutlier := false - if authOrStateEvent.Type() == event.Type() && *authOrStateEvent.StateKey() == *event.StateKey() { - // The event is a state event but the input event is going to - // replace it, therefore it can't be added to the state or we'll - // get duplicate state keys in the state block. We'll send it - // as an outlier because we don't know if something will be - // referring to it as an auth event, but need it to be stored - // just in case. - storeAsOutlier = true - } else if _, ok := isCurrentState[authOrStateEvent.EventID()]; !ok { + if _, ok := isCurrentState[authOrStateEvent.EventID()]; !ok { // The event is an auth event and isn't a part of the state set. // We'll send it as an outlier because we need it to be stored // in case something is referring to it as an auth event. diff --git a/roomserver/internal/helpers/auth.go b/roomserver/internal/helpers/auth.go index 524a5451..834bc0c6 100644 --- a/roomserver/internal/helpers/auth.go +++ b/roomserver/internal/helpers/auth.go @@ -16,13 +16,78 @@ package helpers import ( "context" + "fmt" "sort" + "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" ) +// CheckForSoftFail returns true if the event should be soft-failed +// and false otherwise. The return error value should be checked before +// the soft-fail bool. +func CheckForSoftFail( + ctx context.Context, + db storage.Database, + event gomatrixserverlib.HeaderedEvent, + stateEventIDs []string, +) (bool, error) { + rewritesState := len(stateEventIDs) > 1 + + var authStateEntries []types.StateEntry + var err error + if rewritesState { + authStateEntries, err = db.StateEntriesForEventIDs(ctx, stateEventIDs) + if err != nil { + return true, fmt.Errorf("StateEntriesForEventIDs failed: %w", err) + } + } else { + // Work out if the room exists. + var roomInfo *types.RoomInfo + roomInfo, err = db.RoomInfo(ctx, event.RoomID()) + if err != nil { + return false, fmt.Errorf("db.RoomNID: %w", err) + } + if roomInfo == nil || roomInfo.IsStub { + return false, nil + } + + // Then get the state entries for the current state snapshot. + // We'll use this to check if the event is allowed right now. + roomState := state.NewStateResolution(db, *roomInfo) + authStateEntries, err = roomState.LoadStateAtSnapshot(ctx, roomInfo.StateSnapshotNID) + if err != nil { + return true, fmt.Errorf("roomState.LoadStateAtSnapshot: %w", err) + } + } + + // As a special case, it's possible that the room will have no + // state because we haven't received a m.room.create event yet. + // If we're now processing the first create event then never + // soft-fail it. + if len(authStateEntries) == 0 && event.Type() == gomatrixserverlib.MRoomCreate { + return false, nil + } + + // Work out which of the state events we actually need. + stateNeeded := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{event.Unwrap()}) + + // Load the actual auth events from the database. + authEvents, err := loadAuthEvents(ctx, db, stateNeeded, authStateEntries) + if err != nil { + return true, fmt.Errorf("loadAuthEvents: %w", err) + } + + // Check if the event is allowed. + if err = gomatrixserverlib.Allowed(event.Event, &authEvents); err != nil { + // return true, nil + return true, fmt.Errorf("gomatrixserverlib.Allowed: %w", err) + } + return false, nil +} + // CheckAuthEvents checks that the event passes authentication checks // Returns the numeric IDs for the auth events. func CheckAuthEvents( diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 0558cd76..f953a925 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -53,6 +53,20 @@ func (r *Inputer) processRoomEvent( isRejected = true } + var softfail bool + if input.Kind == api.KindBackfill || input.Kind == api.KindNew { + // Check that the event passes authentication checks based on the + // current room state. + softfail, err = helpers.CheckForSoftFail(ctx, r.DB, headered, input.StateEventIDs) + if err != nil { + logrus.WithFields(logrus.Fields{ + "event_id": event.EventID(), + "type": event.Type(), + "room": event.RoomID(), + }).WithError(err).Info("Error authing soft-failed event") + } + } + // If we don't have a transaction ID then get one. if input.TransactionID != nil { tdID := input.TransactionID @@ -88,6 +102,7 @@ func (r *Inputer) processRoomEvent( "event_id": event.EventID(), "type": event.Type(), "room": event.RoomID(), + "sender": event.Sender(), }).Debug("Stored outlier") return event.EventID(), nil } @@ -110,11 +125,13 @@ func (r *Inputer) processRoomEvent( } // We stop here if the event is rejected: We've stored it but won't update forward extremities or notify anyone about it. - if isRejected { + if isRejected || softfail { logrus.WithFields(logrus.Fields{ - "event_id": event.EventID(), - "type": event.Type(), - "room": event.RoomID(), + "event_id": event.EventID(), + "type": event.Type(), + "room": event.RoomID(), + "soft_fail": softfail, + "sender": event.Sender(), }).Debug("Stored rejected event") return event.EventID(), rejectionErr } diff --git a/roomserver/storage/sqlite3/previous_events_table.go b/roomserver/storage/sqlite3/previous_events_table.go index d28a42c6..222b53b9 100644 --- a/roomserver/storage/sqlite3/previous_events_table.go +++ b/roomserver/storage/sqlite3/previous_events_table.go @@ -18,6 +18,8 @@ package sqlite3 import ( "context" "database/sql" + "fmt" + "strings" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/shared" @@ -25,10 +27,15 @@ import ( "github.com/matrix-org/dendrite/roomserver/types" ) +// TODO: previous_reference_sha256 was NOT NULL before but it broke sytest because +// sytest sends no SHA256 sums in the prev_events references in the soft-fail tests. +// In Postgres an empty BYTEA field is not NULL so it's fine there. In SQLite it +// seems to care that it's empty and therefore hits a NOT NULL constraint on insert. +// We should really work out what the right thing to do here is. const previousEventSchema = ` CREATE TABLE IF NOT EXISTS roomserver_previous_events ( previous_event_id TEXT NOT NULL, - previous_reference_sha256 BLOB NOT NULL, + previous_reference_sha256 BLOB, event_nids TEXT NOT NULL, UNIQUE (previous_event_id, previous_reference_sha256) ); @@ -45,6 +52,11 @@ const insertPreviousEventSQL = ` VALUES ($1, $2, $3) ` +const selectPreviousEventNIDsSQL = ` + SELECT event_nids FROM roomserver_previous_events + WHERE previous_event_id = $1 AND previous_reference_sha256 = $2 +` + // Check if the event is referenced by another event in the table. // This should only be done while holding a "FOR UPDATE" lock on the row in the rooms table for this room. const selectPreviousEventExistsSQL = ` @@ -55,6 +67,7 @@ const selectPreviousEventExistsSQL = ` type previousEventStatements struct { db *sql.DB insertPreviousEventStmt *sql.Stmt + selectPreviousEventNIDsStmt *sql.Stmt selectPreviousEventExistsStmt *sql.Stmt } @@ -69,6 +82,7 @@ func NewSqlitePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) { return s, shared.StatementList{ {&s.insertPreviousEventStmt, insertPreviousEventSQL}, + {&s.selectPreviousEventNIDsStmt, selectPreviousEventNIDsSQL}, {&s.selectPreviousEventExistsStmt, selectPreviousEventExistsSQL}, }.Prepare(db) } @@ -80,9 +94,28 @@ func (s *previousEventStatements) InsertPreviousEvent( previousEventReferenceSHA256 []byte, eventNID types.EventNID, ) error { - stmt := sqlutil.TxStmt(txn, s.insertPreviousEventStmt) - _, err := stmt.ExecContext( - ctx, previousEventID, previousEventReferenceSHA256, int64(eventNID), + var eventNIDs string + eventNIDAsString := fmt.Sprintf("%d", eventNID) + selectStmt := sqlutil.TxStmt(txn, s.selectPreviousEventExistsStmt) + err := selectStmt.QueryRowContext(ctx, previousEventID, previousEventReferenceSHA256).Scan(&eventNIDs) + if err != sql.ErrNoRows { + return fmt.Errorf("selectStmt.QueryRowContext.Scan: %w", err) + } + var nids []string + if eventNIDs != "" { + nids = strings.Split(eventNIDs, ",") + for _, nid := range nids { + if nid == eventNIDAsString { + return nil + } + } + eventNIDs = strings.Join(append(nids, eventNIDAsString), ",") + } else { + eventNIDs = eventNIDAsString + } + insertStmt := sqlutil.TxStmt(txn, s.insertPreviousEventStmt) + _, err = insertStmt.ExecContext( + ctx, previousEventID, previousEventReferenceSHA256, eventNIDs, ) return err } diff --git a/sytest-blacklist b/sytest-blacklist index 246e6830..2f80fc78 100644 --- a/sytest-blacklist +++ b/sytest-blacklist @@ -52,7 +52,4 @@ Inbound federation accepts a second soft-failed event Outbound federation requests missing prev_events and then asks for /state_ids and resolves the state # We don't implement lazy membership loading yet. -The only membership state included in a gapped incremental sync is for senders in the timeline - -# flakey since implementing rejected events -Inbound federation correctly soft fails events \ No newline at end of file +The only membership state included in a gapped incremental sync is for senders in the timeline \ No newline at end of file diff --git a/sytest-whitelist b/sytest-whitelist index 91516428..553df1f1 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -472,4 +472,6 @@ We can't peek into rooms with joined history_visibility Local users can peek by room alias Peeked rooms only turn up in the sync for the device who peeked them Room state at a rejected message event is the same as its predecessor -Room state at a rejected state event is the same as its predecessor \ No newline at end of file +Room state at a rejected state event is the same as its predecessor +Inbound federation correctly soft fails events +Inbound federation accepts a second soft-failed event \ No newline at end of file From 45de9dc1c04e544a663e198a1107bcddc5712726 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Mon, 21 Sep 2020 16:49:37 +0100 Subject: [PATCH 07/12] Use room version cache in Events() --- roomserver/storage/shared/storage.go | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index e710b99b..f8e733ab 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -320,9 +320,14 @@ func (d *Database) Events( if err != nil { return nil, err } - roomVersion, err = d.RoomsTable.SelectRoomVersionForRoomNID(ctx, roomNID) - if err != nil { - return nil, err + if roomID, ok := d.Cache.GetRoomServerRoomID(roomNID); ok { + roomVersion, _ = d.Cache.GetRoomVersion(roomID) + } + if roomVersion == "" { + roomVersion, err = d.RoomsTable.SelectRoomVersionForRoomNID(ctx, roomNID) + if err != nil { + return nil, err + } } result.Event, err = gomatrixserverlib.NewEventFromTrustedJSON( eventJSON.EventJSON, false, roomVersion, From a7563ede3d61efa626095b8b9069af9f16e7dd3d Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 22 Sep 2020 11:05:45 +0100 Subject: [PATCH 08/12] Process federated joins in background context (#1434) * Return early from federated room join * Synchronous perform-join as long as possible * Don't allow multiple federated joins to the same room by the same user --- federationsender/internal/api.go | 2 + federationsender/internal/perform.go | 78 +++++++++++++++++++++------- 2 files changed, 61 insertions(+), 19 deletions(-) diff --git a/federationsender/internal/api.go b/federationsender/internal/api.go index 2a70f7ed..49c53755 100644 --- a/federationsender/internal/api.go +++ b/federationsender/internal/api.go @@ -2,6 +2,7 @@ package internal import ( "context" + "sync" "time" "github.com/matrix-org/dendrite/federationsender/api" @@ -23,6 +24,7 @@ type FederationSenderInternalAPI struct { federation *gomatrixserverlib.FederationClient keyRing *gomatrixserverlib.KeyRing queues *queue.OutgoingQueues + joins sync.Map // joins currently in progress } func NewFederationSenderInternalAPI( diff --git a/federationsender/internal/perform.go b/federationsender/internal/perform.go index a0abf7ff..6aea296b 100644 --- a/federationsender/internal/perform.go +++ b/federationsender/internal/perform.go @@ -37,12 +37,32 @@ func (r *FederationSenderInternalAPI) PerformDirectoryLookup( return nil } +type federatedJoin struct { + UserID string + RoomID string +} + // PerformJoinRequest implements api.FederationSenderInternalAPI func (r *FederationSenderInternalAPI) PerformJoin( ctx context.Context, request *api.PerformJoinRequest, response *api.PerformJoinResponse, ) { + // Check that a join isn't already in progress for this user/room. + j := federatedJoin{request.UserID, request.RoomID} + if _, found := r.joins.Load(j); found { + response.LastError = &gomatrix.HTTPError{ + Code: 429, + Message: `{ + "errcode": "M_LIMIT_EXCEEDED", + "error": "There is already a federated join to this room in progress. Please wait for it to finish." + }`, // TODO: Why do none of our error types play nicely with each other? + } + return + } + r.joins.Store(j, nil) + defer r.joins.Delete(j) + // Look up the supported room versions. var supportedVersions []gomatrixserverlib.RoomVersion for version := range version.SupportedRoomVersions() { @@ -186,27 +206,47 @@ func (r *FederationSenderInternalAPI) performJoinUsingServer( } r.statistics.ForServer(serverName).Success() - // Check that the send_join response was valid. - joinCtx := perform.JoinContext(r.federation, r.keyRing) - respState, err := joinCtx.CheckSendJoinResponse( - ctx, event, serverName, respMakeJoin, respSendJoin, - ) - if err != nil { - return fmt.Errorf("joinCtx.CheckSendJoinResponse: %w", err) - } + // Process the join response in a goroutine. The idea here is + // that we'll try and wait for as long as possible for the work + // to complete, but if the client does give up waiting, we'll + // still continue to process the join anyway so that we don't + // waste the effort. + var cancel context.CancelFunc + ctx, cancel = context.WithCancel(context.Background()) + go func() { + defer cancel() - // If we successfully performed a send_join above then the other - // server now thinks we're a part of the room. Send the newly - // returned state to the roomserver to update our local view. - if err = roomserverAPI.SendEventWithRewrite( - ctx, r.rsAPI, - respState, - event.Headered(respMakeJoin.RoomVersion), - nil, - ); err != nil { - return fmt.Errorf("r.producer.SendEventWithState: %w", err) - } + // Check that the send_join response was valid. + joinCtx := perform.JoinContext(r.federation, r.keyRing) + respState, err := joinCtx.CheckSendJoinResponse( + ctx, event, serverName, respMakeJoin, respSendJoin, + ) + if err != nil { + logrus.WithFields(logrus.Fields{ + "room_id": roomID, + "user_id": userID, + }).WithError(err).Error("Failed to process room join response") + return + } + // If we successfully performed a send_join above then the other + // server now thinks we're a part of the room. Send the newly + // returned state to the roomserver to update our local view. + if err = roomserverAPI.SendEventWithRewrite( + ctx, r.rsAPI, + respState, + event.Headered(respMakeJoin.RoomVersion), + nil, + ); err != nil { + logrus.WithFields(logrus.Fields{ + "room_id": roomID, + "user_id": userID, + }).WithError(err).Error("Failed to send room join response to roomserver") + return + } + }() + + <-ctx.Done() return nil } From a14b29b52617c06a548145a18b4d7cee6e529b79 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 22 Sep 2020 14:40:54 +0100 Subject: [PATCH 09/12] Initial notary support (#1436) * Initial work on notary support * Somewhat working (but not properly filtered) notary support, other tweaks * Update gomatrixserverlib --- federationapi/routing/keys.go | 62 +++++++++++++++++++++++++ federationapi/routing/routing.go | 22 +++++++++ federationsender/api/api.go | 2 + federationsender/internal/api.go | 24 ++++++++++ federationsender/inthttp/client.go | 72 +++++++++++++++++++++++++++--- federationsender/inthttp/server.go | 44 ++++++++++++++++++ go.mod | 2 +- go.sum | 4 +- serverkeyapi/internal/api.go | 2 +- serverkeyapi/serverkeyapi.go | 6 +-- sytest-whitelist | 4 +- 11 files changed, 229 insertions(+), 15 deletions(-) diff --git a/federationapi/routing/keys.go b/federationapi/routing/keys.go index f1ed4176..785be090 100644 --- a/federationapi/routing/keys.go +++ b/federationapi/routing/keys.go @@ -19,11 +19,14 @@ import ( "net/http" "time" + "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" + federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api" "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" + "github.com/sirupsen/logrus" "golang.org/x/crypto/ed25519" ) @@ -160,3 +163,62 @@ func localKeys(cfg *config.FederationAPI, validUntil time.Time) (*gomatrixserver return &keys, nil } + +func NotaryKeys( + httpReq *http.Request, cfg *config.FederationAPI, + fsAPI federationSenderAPI.FederationSenderInternalAPI, + req *gomatrixserverlib.PublicKeyNotaryLookupRequest, +) util.JSONResponse { + if req == nil { + req = &gomatrixserverlib.PublicKeyNotaryLookupRequest{} + if reqErr := httputil.UnmarshalJSONRequest(httpReq, &req); reqErr != nil { + return *reqErr + } + } + + var response struct { + ServerKeys []json.RawMessage `json:"server_keys"` + } + response.ServerKeys = []json.RawMessage{} + + for serverName := range req.ServerKeys { + var keys *gomatrixserverlib.ServerKeys + if serverName == cfg.Matrix.ServerName { + if k, err := localKeys(cfg, time.Now().Add(cfg.Matrix.KeyValidityPeriod)); err == nil { + keys = k + } else { + return util.ErrorResponse(err) + } + } else { + if k, err := fsAPI.GetServerKeys(httpReq.Context(), serverName); err == nil { + keys = &k + } else { + return util.ErrorResponse(err) + } + } + if keys == nil { + continue + } + + j, err := json.Marshal(keys) + if err != nil { + logrus.WithError(err).Errorf("Failed to marshal %q response", serverName) + return jsonerror.InternalServerError() + } + + js, err := gomatrixserverlib.SignJSON( + string(cfg.Matrix.ServerName), cfg.Matrix.KeyID, cfg.Matrix.PrivateKey, j, + ) + if err != nil { + logrus.WithError(err).Errorf("Failed to sign %q response", serverName) + return jsonerror.InternalServerError() + } + + response.ServerKeys = append(response.ServerKeys, js) + } + + return util.JSONResponse{ + Code: http.StatusOK, + JSON: response, + } +} diff --git a/federationapi/routing/routing.go b/federationapi/routing/routing.go index 71a09d42..06ed57af 100644 --- a/federationapi/routing/routing.go +++ b/federationapi/routing/routing.go @@ -61,6 +61,26 @@ func Setup( return LocalKeys(cfg) }) + notaryKeys := httputil.MakeExternalAPI("notarykeys", func(req *http.Request) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + var pkReq *gomatrixserverlib.PublicKeyNotaryLookupRequest + serverName := gomatrixserverlib.ServerName(vars["serverName"]) + keyID := gomatrixserverlib.KeyID(vars["keyID"]) + if serverName != "" && keyID != "" { + pkReq = &gomatrixserverlib.PublicKeyNotaryLookupRequest{ + ServerKeys: map[gomatrixserverlib.ServerName]map[gomatrixserverlib.KeyID]gomatrixserverlib.PublicKeyNotaryQueryCriteria{ + serverName: { + keyID: gomatrixserverlib.PublicKeyNotaryQueryCriteria{}, + }, + }, + } + } + return NotaryKeys(req, cfg, fsAPI, pkReq) + }) + // Ignore the {keyID} argument as we only have a single server key so we always // return that key. // Even if we had more than one server key, we would probably still ignore the @@ -68,6 +88,8 @@ func Setup( v2keysmux.Handle("/server/{keyID}", localKeys).Methods(http.MethodGet) v2keysmux.Handle("/server/", localKeys).Methods(http.MethodGet) v2keysmux.Handle("/server", localKeys).Methods(http.MethodGet) + v2keysmux.Handle("/query", notaryKeys).Methods(http.MethodPost) + v2keysmux.Handle("/query/{serverName}/{keyID}", notaryKeys).Methods(http.MethodGet) v1fedmux.Handle("/send/{txnID}", httputil.MakeFedAPI( "federation_send", cfg.Matrix.ServerName, keys, wakeup, diff --git a/federationsender/api/api.go b/federationsender/api/api.go index adc3b34c..5ae419be 100644 --- a/federationsender/api/api.go +++ b/federationsender/api/api.go @@ -20,6 +20,8 @@ type FederationClient interface { ClaimKeys(ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (res gomatrixserverlib.RespClaimKeys, err error) QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error) GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) + GetServerKeys(ctx context.Context, matrixServer gomatrixserverlib.ServerName) (gomatrixserverlib.ServerKeys, error) + LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error) } // FederationClientError is returned from FederationClient methods in the event of a problem. diff --git a/federationsender/internal/api.go b/federationsender/internal/api.go index 49c53755..f9d35357 100644 --- a/federationsender/internal/api.go +++ b/federationsender/internal/api.go @@ -189,3 +189,27 @@ func (a *FederationSenderInternalAPI) GetEvent( } return ires.(gomatrixserverlib.Transaction), nil } + +func (a *FederationSenderInternalAPI) GetServerKeys( + ctx context.Context, s gomatrixserverlib.ServerName, +) (gomatrixserverlib.ServerKeys, error) { + ires, err := a.doRequest(s, func() (interface{}, error) { + return a.federation.GetServerKeys(ctx, s) + }) + if err != nil { + return gomatrixserverlib.ServerKeys{}, err + } + return ires.(gomatrixserverlib.ServerKeys), nil +} + +func (a *FederationSenderInternalAPI) LookupServerKeys( + ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, +) ([]gomatrixserverlib.ServerKeys, error) { + ires, err := a.doRequest(s, func() (interface{}, error) { + return a.federation.LookupServerKeys(ctx, s, keyRequests) + }) + if err != nil { + return []gomatrixserverlib.ServerKeys{}, err + } + return ires.([]gomatrixserverlib.ServerKeys), nil +} diff --git a/federationsender/inthttp/client.go b/federationsender/inthttp/client.go index 5bfe6089..e0783ee1 100644 --- a/federationsender/inthttp/client.go +++ b/federationsender/inthttp/client.go @@ -23,13 +23,15 @@ const ( FederationSenderPerformServersAlivePath = "/federationsender/performServersAlive" FederationSenderPerformBroadcastEDUPath = "/federationsender/performBroadcastEDU" - FederationSenderGetUserDevicesPath = "/federationsender/client/getUserDevices" - FederationSenderClaimKeysPath = "/federationsender/client/claimKeys" - FederationSenderQueryKeysPath = "/federationsender/client/queryKeys" - FederationSenderBackfillPath = "/federationsender/client/backfill" - FederationSenderLookupStatePath = "/federationsender/client/lookupState" - FederationSenderLookupStateIDsPath = "/federationsender/client/lookupStateIDs" - FederationSenderGetEventPath = "/federationsender/client/getEvent" + FederationSenderGetUserDevicesPath = "/federationsender/client/getUserDevices" + FederationSenderClaimKeysPath = "/federationsender/client/claimKeys" + FederationSenderQueryKeysPath = "/federationsender/client/queryKeys" + FederationSenderBackfillPath = "/federationsender/client/backfill" + FederationSenderLookupStatePath = "/federationsender/client/lookupState" + FederationSenderLookupStateIDsPath = "/federationsender/client/lookupStateIDs" + FederationSenderGetEventPath = "/federationsender/client/getEvent" + FederationSenderGetServerKeysPath = "/federationsender/client/getServerKeys" + FederationSenderLookupServerKeysPath = "/federationsender/client/lookupServerKeys" ) // NewFederationSenderClient creates a FederationSenderInternalAPI implemented by talking to a HTTP POST API. @@ -358,3 +360,59 @@ func (h *httpFederationSenderInternalAPI) GetEvent( } return *response.Res, nil } + +type getServerKeys struct { + S gomatrixserverlib.ServerName + ServerKeys gomatrixserverlib.ServerKeys + Err *api.FederationClientError +} + +func (h *httpFederationSenderInternalAPI) GetServerKeys( + ctx context.Context, s gomatrixserverlib.ServerName, +) (gomatrixserverlib.ServerKeys, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "GetServerKeys") + defer span.Finish() + + request := getServerKeys{ + S: s, + } + var response getServerKeys + apiURL := h.federationSenderURL + FederationSenderGetServerKeysPath + err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response) + if err != nil { + return gomatrixserverlib.ServerKeys{}, err + } + if response.Err != nil { + return gomatrixserverlib.ServerKeys{}, response.Err + } + return response.ServerKeys, nil +} + +type lookupServerKeys struct { + S gomatrixserverlib.ServerName + KeyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp + ServerKeys []gomatrixserverlib.ServerKeys + Err *api.FederationClientError +} + +func (h *httpFederationSenderInternalAPI) LookupServerKeys( + ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, +) ([]gomatrixserverlib.ServerKeys, error) { + span, ctx := opentracing.StartSpanFromContext(ctx, "LookupServerKeys") + defer span.Finish() + + request := lookupServerKeys{ + S: s, + KeyRequests: keyRequests, + } + var response lookupServerKeys + apiURL := h.federationSenderURL + FederationSenderLookupServerKeysPath + err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response) + if err != nil { + return []gomatrixserverlib.ServerKeys{}, err + } + if response.Err != nil { + return []gomatrixserverlib.ServerKeys{}, response.Err + } + return response.ServerKeys, nil +} diff --git a/federationsender/inthttp/server.go b/federationsender/inthttp/server.go index dfbff1c0..53e1183e 100644 --- a/federationsender/inthttp/server.go +++ b/federationsender/inthttp/server.go @@ -263,4 +263,48 @@ func AddRoutes(intAPI api.FederationSenderInternalAPI, internalAPIMux *mux.Route return util.JSONResponse{Code: http.StatusOK, JSON: request} }), ) + internalAPIMux.Handle( + FederationSenderGetServerKeysPath, + httputil.MakeInternalAPI("GetServerKeys", func(req *http.Request) util.JSONResponse { + var request getServerKeys + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + res, err := intAPI.GetServerKeys(req.Context(), request.S) + if err != nil { + ferr, ok := err.(*api.FederationClientError) + if ok { + request.Err = ferr + } else { + request.Err = &api.FederationClientError{ + Err: err.Error(), + } + } + } + request.ServerKeys = res + return util.JSONResponse{Code: http.StatusOK, JSON: request} + }), + ) + internalAPIMux.Handle( + FederationSenderLookupServerKeysPath, + httputil.MakeInternalAPI("LookupServerKeys", func(req *http.Request) util.JSONResponse { + var request lookupServerKeys + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + res, err := intAPI.LookupServerKeys(req.Context(), request.S, request.KeyRequests) + if err != nil { + ferr, ok := err.(*api.FederationClientError) + if ok { + request.Err = ferr + } else { + request.Err = &api.FederationClientError{ + Err: err.Error(), + } + } + } + request.ServerKeys = res + return util.JSONResponse{Code: http.StatusOK, JSON: request} + }), + ) } diff --git a/go.mod b/go.mod index 6b1c03b5..6d367bda 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3 github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd - github.com/matrix-org/gomatrixserverlib v0.0.0-20200907151926-38f437f2b2a6 + github.com/matrix-org/gomatrixserverlib v0.0.0-20200922131600-dce167edcce4 github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.2 diff --git a/go.sum b/go.sum index 5c4f27a5..990fa21a 100644 --- a/go.sum +++ b/go.sum @@ -569,8 +569,8 @@ github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 h1:Hr3zjRsq2bh github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd h1:xVrqJK3xHREMNjwjljkAUaadalWc0rRbmVuQatzmgwg= github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20200907151926-38f437f2b2a6 h1:43gla6bLt4opWY1mQkAasF/LUCipZl7x2d44TY0wf40= -github.com/matrix-org/gomatrixserverlib v0.0.0-20200907151926-38f437f2b2a6/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20200922131600-dce167edcce4 h1:jBUEVUTgXc5a9luTRvb9vOkuLB+F528CE3Z05nUzGeM= +github.com/matrix-org/gomatrixserverlib v0.0.0-20200922131600-dce167edcce4/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU= github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91 h1:HJ6U3S3ljJqNffYMcIeAncp5qT/i+ZMiJ2JC2F0aXP4= github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91/go.mod h1:sjyPyRxKM5uw1nD2cJ6O2OxI6GOqyVBfNXqKjBZTBZE= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 h1:ntrLa/8xVzeSs8vHFHK25k0C+NV74sYMJnNSg5NoSRo= diff --git a/serverkeyapi/internal/api.go b/serverkeyapi/internal/api.go index 02028c60..bc02ac2d 100644 --- a/serverkeyapi/internal/api.go +++ b/serverkeyapi/internal/api.go @@ -20,7 +20,7 @@ type ServerKeyAPI struct { ServerKeyValidity time.Duration OurKeyRing gomatrixserverlib.KeyRing - FedClient *gomatrixserverlib.FederationClient + FedClient gomatrixserverlib.KeyClient } func (s *ServerKeyAPI) KeyRing() *gomatrixserverlib.KeyRing { diff --git a/serverkeyapi/serverkeyapi.go b/serverkeyapi/serverkeyapi.go index fbaaefad..783402b2 100644 --- a/serverkeyapi/serverkeyapi.go +++ b/serverkeyapi/serverkeyapi.go @@ -26,7 +26,7 @@ func AddInternalRoutes(router *mux.Router, intAPI api.ServerKeyInternalAPI, cach // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. func NewInternalAPI( cfg *config.ServerKeyAPI, - fedClient *gomatrixserverlib.FederationClient, + fedClient gomatrixserverlib.KeyClient, caches *caching.Caches, ) api.ServerKeyInternalAPI { innerDB, err := storage.NewDatabase( @@ -53,7 +53,7 @@ func NewInternalAPI( OurKeyRing: gomatrixserverlib.KeyRing{ KeyFetchers: []gomatrixserverlib.KeyFetcher{ &gomatrixserverlib.DirectKeyFetcher{ - Client: fedClient.Client, + Client: fedClient, }, }, KeyDatabase: serverKeyDB, @@ -65,7 +65,7 @@ func NewInternalAPI( perspective := &gomatrixserverlib.PerspectiveKeyFetcher{ PerspectiveServerName: ps.ServerName, PerspectiveServerKeys: map[gomatrixserverlib.KeyID]ed25519.PublicKey{}, - Client: fedClient.Client, + Client: fedClient, } for _, key := range ps.Keys { diff --git a/sytest-whitelist b/sytest-whitelist index 553df1f1..84706b6c 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -474,4 +474,6 @@ Peeked rooms only turn up in the sync for the device who peeked them Room state at a rejected message event is the same as its predecessor Room state at a rejected state event is the same as its predecessor Inbound federation correctly soft fails events -Inbound federation accepts a second soft-failed event \ No newline at end of file +Inbound federation accepts a second soft-failed event +Federation key API can act as a notary server via a POST request +Federation key API can act as a notary server via a GET request From a854e3aa18ccb9314b5ea1113ce932981c74c805 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 22 Sep 2020 14:53:36 +0100 Subject: [PATCH 10/12] Fix backoff bug --- federationsender/queue/destinationqueue.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/federationsender/queue/destinationqueue.go b/federationsender/queue/destinationqueue.go index 57612908..12a04d4b 100644 --- a/federationsender/queue/destinationqueue.go +++ b/federationsender/queue/destinationqueue.go @@ -239,7 +239,7 @@ func (oq *destinationQueue) backgroundSend() { log.Warnf("Blacklisting %q due to exceeding backoff threshold", oq.destination) return } - if until != nil { + if until != nil && until.After(time.Now()) { // We haven't backed off yet, so wait for the suggested amount of // time. duration := time.Until(*until) From f908f8baab08bdb57e4d726f32182f40084f17c0 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 22 Sep 2020 16:41:46 +0100 Subject: [PATCH 11/12] Update gomatrixserverlib --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 6d367bda..1dd20a54 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3 github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd - github.com/matrix-org/gomatrixserverlib v0.0.0-20200922131600-dce167edcce4 + github.com/matrix-org/gomatrixserverlib v0.0.0-20200922152606-4aa1159e672b github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.2 diff --git a/go.sum b/go.sum index 990fa21a..e3dd32fe 100644 --- a/go.sum +++ b/go.sum @@ -569,8 +569,8 @@ github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 h1:Hr3zjRsq2bh github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd h1:xVrqJK3xHREMNjwjljkAUaadalWc0rRbmVuQatzmgwg= github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20200922131600-dce167edcce4 h1:jBUEVUTgXc5a9luTRvb9vOkuLB+F528CE3Z05nUzGeM= -github.com/matrix-org/gomatrixserverlib v0.0.0-20200922131600-dce167edcce4/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20200922152606-4aa1159e672b h1:I8H9ftkT1K/OA2urt/dfXAYpO3pOiMQL5bvoWm4i0RA= +github.com/matrix-org/gomatrixserverlib v0.0.0-20200922152606-4aa1159e672b/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU= github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91 h1:HJ6U3S3ljJqNffYMcIeAncp5qT/i+ZMiJ2JC2F0aXP4= github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91/go.mod h1:sjyPyRxKM5uw1nD2cJ6O2OxI6GOqyVBfNXqKjBZTBZE= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 h1:ntrLa/8xVzeSs8vHFHK25k0C+NV74sYMJnNSg5NoSRo= From de8b39065ec6d56d6784ce3b704f00432b41e6fb Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 23 Sep 2020 11:07:57 +0100 Subject: [PATCH 12/12] Enforce valid key IDs (#1437) * Enforce valid key IDs * Don't use key_id from dendrite.yaml as it is in matrix_key.pem --- dendrite-config.yaml | 3 --- internal/config/config.go | 6 ++++++ internal/config/config_global.go | 2 +- internal/test/config.go | 7 ++++++- 4 files changed, 13 insertions(+), 5 deletions(-) diff --git a/dendrite-config.yaml b/dendrite-config.yaml index be0972e4..8c737692 100644 --- a/dendrite-config.yaml +++ b/dendrite-config.yaml @@ -38,9 +38,6 @@ global: # The path to the signing private key file, used to sign requests and events. private_key: matrix_key.pem - # A unique identifier for this private key. Must start with the prefix "ed25519:". - key_id: ed25519:auto - # How long a remote server can cache our server signing key before requesting it # again. Increasing this number will reduce the number of requests made by other # servers for our key but increases the period that a compromised key will be diff --git a/internal/config/config.go b/internal/config/config.go index d7470f87..d75500db 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -36,6 +36,9 @@ import ( jaegermetrics "github.com/uber/jaeger-lib/metrics" ) +// keyIDRegexp defines allowable characters in Key IDs. +var keyIDRegexp = regexp.MustCompile("^ed25519:[a-zA-Z0-9_]+$") + // Version is the current version of the config format. // This will change whenever we make breaking changes to the config format. const Version = 1 @@ -459,6 +462,9 @@ func readKeyPEM(path string, data []byte) (gomatrixserverlib.KeyID, ed25519.Priv if !strings.HasPrefix(keyID, "ed25519:") { return "", nil, fmt.Errorf("key ID %q doesn't start with \"ed25519:\" in %q", keyID, path) } + if !keyIDRegexp.MatchString(keyID) { + return "", nil, fmt.Errorf("key ID %q in %q contains illegal characters (use a-z, A-Z, 0-9 and _ only)", keyID, path) + } _, privKey, err := ed25519.GenerateKey(bytes.NewReader(keyBlock.Bytes)) if err != nil { return "", nil, err diff --git a/internal/config/config_global.go b/internal/config/config_global.go index 2b36da2f..03f522be 100644 --- a/internal/config/config_global.go +++ b/internal/config/config_global.go @@ -20,7 +20,7 @@ type Global struct { // An arbitrary string used to uniquely identify the PrivateKey. Must start with the // prefix "ed25519:". - KeyID gomatrixserverlib.KeyID `yaml:"key_id"` + KeyID gomatrixserverlib.KeyID `yaml:"-"` // How long a remote server can cache our server key for before requesting it again. // Increasing this number will reduce the number of requests made by remote servers diff --git a/internal/test/config.go b/internal/test/config.go index 72cd0e6e..8080988f 100644 --- a/internal/test/config.go +++ b/internal/test/config.go @@ -25,6 +25,7 @@ import ( "math/big" "os" "path/filepath" + "strings" "time" "github.com/matrix-org/dendrite/internal/config" @@ -146,10 +147,14 @@ func NewMatrixKey(matrixKeyPath string) (err error) { err = keyOut.Close() })() + keyID := base64.RawURLEncoding.EncodeToString(data[:]) + keyID = strings.ReplaceAll(keyID, "-", "") + keyID = strings.ReplaceAll(keyID, "_", "") + err = pem.Encode(keyOut, &pem.Block{ Type: "MATRIX PRIVATE KEY", Headers: map[string]string{ - "Key-ID": "ed25519:" + base64.RawStdEncoding.EncodeToString(data[:3]), + "Key-ID": fmt.Sprintf("ed25519:%s", keyID[:6]), }, Bytes: data[3:], })