From 4e64c270dbe5d438325903e4404ed4b9ec43c039 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 17 Mar 2022 17:05:21 +0000 Subject: [PATCH] Various bug fixes and tweaks around invites and membership --- roomserver/internal/helpers/helpers.go | 2 +- roomserver/internal/perform/perform_invite.go | 7 ++-- roomserver/internal/perform/perform_leave.go | 4 +-- .../storage/postgres/membership_table.go | 10 ++++-- .../storage/shared/membership_updater.go | 10 +++--- .../storage/sqlite3/membership_table.go | 10 ++++-- roomserver/storage/tables/interface.go | 2 +- syncapi/storage/postgres/invites_table.go | 2 +- syncapi/storage/sqlite3/invites_table.go | 2 +- syncapi/storage/sqlite3/stream_id_table.go | 34 ++++--------------- 10 files changed, 36 insertions(+), 47 deletions(-) diff --git a/roomserver/internal/helpers/helpers.go b/roomserver/internal/helpers/helpers.go index 78a875c7..e67bbfca 100644 --- a/roomserver/internal/helpers/helpers.go +++ b/roomserver/internal/helpers/helpers.go @@ -28,7 +28,7 @@ func UpdateToInviteMembership( // reprocessing this event, or because the we received this invite from a // remote server via the federation invite API. In those cases we don't need // to send the event. - needsSending, err := mu.SetToInvite(*add) + needsSending, err := mu.SetToInvite(add) if err != nil { return nil, err } diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index 6559cd08..6111372d 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -24,6 +24,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/internal/input" "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/storage" + "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" @@ -139,13 +140,15 @@ func (r *Inviter) PerformInvite( // will never pass auth checks due to lacking room state, but we // still need to tell the client about the invite so we can accept // it, hence we return an output event to send to the sync api. - updater, err := r.DB.MembershipUpdater(ctx, roomID, targetUserID, isTargetLocal, req.RoomVersion) + var updater *shared.MembershipUpdater + updater, err = r.DB.MembershipUpdater(ctx, roomID, targetUserID, isTargetLocal, req.RoomVersion) if err != nil { return nil, fmt.Errorf("r.DB.MembershipUpdater: %w", err) } unwrapped := event.Unwrap() - outputUpdates, err := helpers.UpdateToInviteMembership(updater, unwrapped, nil, req.Event.RoomVersion) + var outputUpdates []api.OutputEvent + outputUpdates, err = helpers.UpdateToInviteMembership(updater, unwrapped, nil, req.Event.RoomVersion) if err != nil { return nil, fmt.Errorf("updateToInviteMembership: %w", err) } diff --git a/roomserver/internal/perform/perform_leave.go b/roomserver/internal/perform/perform_leave.go index 49ddd481..1e5fb9f1 100644 --- a/roomserver/internal/perform/perform_leave.go +++ b/roomserver/internal/perform/perform_leave.go @@ -91,12 +91,12 @@ func (r *Leaver) performLeaveRoomByID( } // check that this is not a "server notice room" accData := &userapi.QueryAccountDataResponse{} - if err := r.UserAPI.QueryAccountData(ctx, &userapi.QueryAccountDataRequest{ + if err = r.UserAPI.QueryAccountData(ctx, &userapi.QueryAccountDataRequest{ UserID: req.UserID, RoomID: req.RoomID, DataType: "m.tag", }, accData); err != nil { - return nil, fmt.Errorf("unable to query account data") + return nil, fmt.Errorf("unable to query account data: %w", err) } if roomData, ok := accData.RoomAccountData[req.RoomID]; ok { diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go index 12717874..6ed5293e 100644 --- a/roomserver/storage/postgres/membership_table.go +++ b/roomserver/storage/postgres/membership_table.go @@ -276,11 +276,15 @@ func (s *membershipStatements) UpdateMembership( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership tables.MembershipState, eventNID types.EventNID, forgotten bool, -) error { - _, err := sqlutil.TxStmt(txn, s.updateMembershipStmt).ExecContext( +) (bool, error) { + res, err := sqlutil.TxStmt(txn, s.updateMembershipStmt).ExecContext( ctx, roomNID, targetUserNID, senderUserNID, membership, eventNID, forgotten, ) - return err + if err != nil { + return false, err + } + rows, err := res.RowsAffected() + return rows > 0, err } func (s *membershipStatements) SelectRoomsWithMembership( diff --git a/roomserver/storage/shared/membership_updater.go b/roomserver/storage/shared/membership_updater.go index 8f3f3d63..b7db9f81 100644 --- a/roomserver/storage/shared/membership_updater.go +++ b/roomserver/storage/shared/membership_updater.go @@ -92,7 +92,7 @@ func (u *MembershipUpdater) IsKnock() bool { } // SetToInvite implements types.MembershipUpdater -func (u *MembershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, error) { +func (u *MembershipUpdater) SetToInvite(event *gomatrixserverlib.Event) (bool, error) { var inserted bool err := u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.Sender()) @@ -106,7 +106,7 @@ func (u *MembershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, er return fmt.Errorf("u.d.InvitesTable.InsertInviteEvent: %w", err) } if u.membership != tables.MembershipStateInvite { - if err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateInvite, 0, false); err != nil { + if inserted, err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateInvite, 0, false); err != nil { return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err) } } @@ -142,7 +142,7 @@ func (u *MembershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd } if u.membership != tables.MembershipStateJoin || isUpdate { - if err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateJoin, nIDs[eventID], false); err != nil { + if _, err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateJoin, nIDs[eventID], false); err != nil { return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err) } } @@ -176,7 +176,7 @@ func (u *MembershipUpdater) SetToLeave(senderUserID string, eventID string) ([]s } if u.membership != tables.MembershipStateLeaveOrBan { - if err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateLeaveOrBan, nIDs[eventID], false); err != nil { + if _, err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateLeaveOrBan, nIDs[eventID], false); err != nil { return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err) } } @@ -201,7 +201,7 @@ func (u *MembershipUpdater) SetToKnock(event *gomatrixserverlib.Event) (bool, er return fmt.Errorf("u.d.EventNIDs: %w", err) } - if err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateKnock, nIDs[event.EventID()], false); err != nil { + if inserted, err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateKnock, nIDs[event.EventID()], false); err != nil { return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err) } } diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index 43567a94..7ed86b61 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -253,12 +253,16 @@ func (s *membershipStatements) UpdateMembership( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership tables.MembershipState, eventNID types.EventNID, forgotten bool, -) error { +) (bool, error) { stmt := sqlutil.TxStmt(txn, s.updateMembershipStmt) - _, err := stmt.ExecContext( + res, err := stmt.ExecContext( ctx, senderUserNID, membership, eventNID, forgotten, roomNID, targetUserNID, ) - return err + if err != nil { + return false, err + } + rows, err := res.RowsAffected() + return rows > 0, err } func (s *membershipStatements) SelectRoomsWithMembership( diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 04e3c96c..97e4afcf 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -125,7 +125,7 @@ type Membership interface { SelectMembershipFromRoomAndTarget(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (types.EventNID, MembershipState, bool, error) SelectMembershipsFromRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error) SelectMembershipsFromRoomAndMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error) - UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) error + UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) (bool, error) SelectRoomsWithMembership(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error) // SelectJoinedUsersSetForRooms returns how many times each of the given users appears across the given rooms. SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]int, error) diff --git a/syncapi/storage/postgres/invites_table.go b/syncapi/storage/postgres/invites_table.go index 48ad58c0..97001ae2 100644 --- a/syncapi/storage/postgres/invites_table.go +++ b/syncapi/storage/postgres/invites_table.go @@ -52,7 +52,7 @@ const insertInviteEventSQL = "" + ") VALUES ($1, $2, $3, $4, FALSE) RETURNING id" const deleteInviteEventSQL = "" + - "UPDATE syncapi_invite_events SET deleted=TRUE, id=nextval('syncapi_stream_id') WHERE event_id = $1 RETURNING id" + "UPDATE syncapi_invite_events SET deleted=TRUE, id=nextval('syncapi_stream_id') WHERE event_id = $1 AND deleted=FALSE RETURNING id" const selectInviteEventsInRangeSQL = "" + "SELECT room_id, headered_event_json, deleted FROM syncapi_invite_events" + diff --git a/syncapi/storage/sqlite3/invites_table.go b/syncapi/storage/sqlite3/invites_table.go index 7498fd68..0a6823cc 100644 --- a/syncapi/storage/sqlite3/invites_table.go +++ b/syncapi/storage/sqlite3/invites_table.go @@ -47,7 +47,7 @@ const insertInviteEventSQL = "" + " VALUES ($1, $2, $3, $4, $5, false)" const deleteInviteEventSQL = "" + - "UPDATE syncapi_invite_events SET deleted=true, id=$1 WHERE event_id = $2" + "UPDATE syncapi_invite_events SET deleted=true, id=$1 WHERE event_id = $2 AND deleted=false" const selectInviteEventsInRangeSQL = "" + "SELECT room_id, headered_event_json, deleted FROM syncapi_invite_events" + diff --git a/syncapi/storage/sqlite3/stream_id_table.go b/syncapi/storage/sqlite3/stream_id_table.go index b614271d..2be3ae93 100644 --- a/syncapi/storage/sqlite3/stream_id_table.go +++ b/syncapi/storage/sqlite3/stream_id_table.go @@ -27,15 +27,12 @@ INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("invite", 0) ` const increaseStreamIDStmt = "" + - "UPDATE syncapi_stream_id SET stream_id = stream_id + 1 WHERE stream_name = $1" - -const selectStreamIDStmt = "" + - "SELECT stream_id FROM syncapi_stream_id WHERE stream_name = $1" + "UPDATE syncapi_stream_id SET stream_id = stream_id + 1 WHERE stream_name = $1" + + " RETURNING stream_id" type streamIDStatements struct { db *sql.DB increaseStreamIDStmt *sql.Stmt - selectStreamIDStmt *sql.Stmt } func (s *streamIDStatements) prepare(db *sql.DB) (err error) { @@ -47,48 +44,29 @@ func (s *streamIDStatements) prepare(db *sql.DB) (err error) { if s.increaseStreamIDStmt, err = db.Prepare(increaseStreamIDStmt); err != nil { return } - if s.selectStreamIDStmt, err = db.Prepare(selectStreamIDStmt); err != nil { - return - } return } func (s *streamIDStatements) nextPDUID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) - selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt) - if _, err = increaseStmt.ExecContext(ctx, "global"); err != nil { - return - } - err = selectStmt.QueryRowContext(ctx, "global").Scan(&pos) + err = increaseStmt.QueryRowContext(ctx, "global").Scan(&pos) return } func (s *streamIDStatements) nextReceiptID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) - selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt) - if _, err = increaseStmt.ExecContext(ctx, "receipt"); err != nil { - return - } - err = selectStmt.QueryRowContext(ctx, "receipt").Scan(&pos) + err = increaseStmt.QueryRowContext(ctx, "receipt").Scan(&pos) return } func (s *streamIDStatements) nextInviteID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) - selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt) - if _, err = increaseStmt.ExecContext(ctx, "invite"); err != nil { - return - } - err = selectStmt.QueryRowContext(ctx, "invite").Scan(&pos) + err = increaseStmt.QueryRowContext(ctx, "invite").Scan(&pos) return } func (s *streamIDStatements) nextAccountDataID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) - selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt) - if _, err = increaseStmt.ExecContext(ctx, "accountdata"); err != nil { - return - } - err = selectStmt.QueryRowContext(ctx, "accountdata").Scan(&pos) + err = increaseStmt.QueryRowContext(ctx, "accountdata").Scan(&pos) return }