diff --git a/roomserver/storage/postgres/state_snapshot_table.go b/roomserver/storage/postgres/state_snapshot_table.go index 63175955..2a8528d5 100644 --- a/roomserver/storage/postgres/state_snapshot_table.go +++ b/roomserver/storage/postgres/state_snapshot_table.go @@ -46,13 +46,18 @@ CREATE TABLE IF NOT EXISTS roomserver_state_snapshots ( -- Unused in normal operation, but useful for background work or ad-hoc debugging. room_nid bigint NOT NULL, -- List of state_block_nids, stored sorted by state_block_nid. - state_block_nids bigint[] NOT NULL + state_block_nids bigint[] NOT NULL, + -- Deduplicate state snapshots if we can. + UNIQUE (room_nid, state_block_nids) ); ` const insertStateSQL = "" + "INSERT INTO roomserver_state_snapshots (room_nid, state_block_nids)" + " VALUES ($1, $2)" + + " ON CONFLICT (room_nid, state_block_nids) DO UPDATE SET room_nid=$1" + + // Performing an update, above, ensures that the RETURNING statement + // below will always return a valid state snapshot ID " RETURNING state_snapshot_nid" // Bulk state data NID lookup. @@ -87,7 +92,12 @@ func (s *stateSnapshotStatements) InsertState( for i := range stateBlockNIDs { nids[i] = int64(stateBlockNIDs[i]) } - err = sqlutil.TxStmt(txn, s.insertStateStmt).QueryRowContext(ctx, int64(roomNID), pq.Int64Array(nids)).Scan(&stateNID) + var id int64 + err = sqlutil.TxStmt(txn, s.insertStateStmt).QueryRowContext(ctx, int64(roomNID), pq.Int64Array(nids), int64(roomNID)).Scan(&id) + if err != nil { + return 0, err + } + stateNID = types.StateSnapshotNID(id) return } diff --git a/roomserver/storage/sqlite3/state_snapshot_table.go b/roomserver/storage/sqlite3/state_snapshot_table.go index bf49f62c..99e69113 100644 --- a/roomserver/storage/sqlite3/state_snapshot_table.go +++ b/roomserver/storage/sqlite3/state_snapshot_table.go @@ -33,13 +33,17 @@ const stateSnapshotSchema = ` CREATE TABLE IF NOT EXISTS roomserver_state_snapshots ( state_snapshot_nid INTEGER PRIMARY KEY AUTOINCREMENT, room_nid INTEGER NOT NULL, - state_block_nids TEXT NOT NULL DEFAULT '[]' + state_block_nids TEXT NOT NULL DEFAULT '[]', + UNIQUE (room_nid, state_block_nids) ); ` const insertStateSQL = ` INSERT INTO roomserver_state_snapshots (room_nid, state_block_nids) - VALUES ($1, $2);` + VALUES ($1, $2) + ON CONFLICT (room_nid, state_block_nids) DO UPDATE SET room_nid=$3 + RETURNING state_snapshot_nid +` // Bulk state data NID lookup. // Sorting by state_snapshot_nid means we can use binary search over the result @@ -77,15 +81,12 @@ func (s *stateSnapshotStatements) InsertState( return } insertStmt := sqlutil.TxStmt(txn, s.insertStateStmt) - res, err := insertStmt.ExecContext(ctx, int64(roomNID), string(stateBlockNIDsJSON)) + var id int64 + err = insertStmt.QueryRowContext(ctx, int64(roomNID), string(stateBlockNIDsJSON), int64(roomNID)).Scan(&id) if err != nil { return 0, err } - lastRowID, err := res.LastInsertId() - if err != nil { - return 0, err - } - stateNID = types.StateSnapshotNID(lastRowID) + stateNID = types.StateSnapshotNID(id) return }