Enforce uniqueness for state snapshot to state block mappings

This commit is contained in:
Neil Alexander 2021-04-15 13:28:39 +01:00
parent 4a90bc86dd
commit dbd53fa9ff
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
2 changed files with 21 additions and 10 deletions

View file

@ -46,13 +46,18 @@ CREATE TABLE IF NOT EXISTS roomserver_state_snapshots (
-- Unused in normal operation, but useful for background work or ad-hoc debugging. -- Unused in normal operation, but useful for background work or ad-hoc debugging.
room_nid bigint NOT NULL, room_nid bigint NOT NULL,
-- List of state_block_nids, stored sorted by state_block_nid. -- List of state_block_nids, stored sorted by state_block_nid.
state_block_nids bigint[] NOT NULL state_block_nids bigint[] NOT NULL,
-- Deduplicate state snapshots if we can.
UNIQUE (room_nid, state_block_nids)
); );
` `
const insertStateSQL = "" + const insertStateSQL = "" +
"INSERT INTO roomserver_state_snapshots (room_nid, state_block_nids)" + "INSERT INTO roomserver_state_snapshots (room_nid, state_block_nids)" +
" VALUES ($1, $2)" + " VALUES ($1, $2)" +
" ON CONFLICT (room_nid, state_block_nids) DO UPDATE SET room_nid=$1" +
// Performing an update, above, ensures that the RETURNING statement
// below will always return a valid state snapshot ID
" RETURNING state_snapshot_nid" " RETURNING state_snapshot_nid"
// Bulk state data NID lookup. // Bulk state data NID lookup.
@ -87,7 +92,12 @@ func (s *stateSnapshotStatements) InsertState(
for i := range stateBlockNIDs { for i := range stateBlockNIDs {
nids[i] = int64(stateBlockNIDs[i]) nids[i] = int64(stateBlockNIDs[i])
} }
err = sqlutil.TxStmt(txn, s.insertStateStmt).QueryRowContext(ctx, int64(roomNID), pq.Int64Array(nids)).Scan(&stateNID) var id int64
err = sqlutil.TxStmt(txn, s.insertStateStmt).QueryRowContext(ctx, int64(roomNID), pq.Int64Array(nids), int64(roomNID)).Scan(&id)
if err != nil {
return 0, err
}
stateNID = types.StateSnapshotNID(id)
return return
} }

View file

@ -33,13 +33,17 @@ const stateSnapshotSchema = `
CREATE TABLE IF NOT EXISTS roomserver_state_snapshots ( CREATE TABLE IF NOT EXISTS roomserver_state_snapshots (
state_snapshot_nid INTEGER PRIMARY KEY AUTOINCREMENT, state_snapshot_nid INTEGER PRIMARY KEY AUTOINCREMENT,
room_nid INTEGER NOT NULL, room_nid INTEGER NOT NULL,
state_block_nids TEXT NOT NULL DEFAULT '[]' state_block_nids TEXT NOT NULL DEFAULT '[]',
UNIQUE (room_nid, state_block_nids)
); );
` `
const insertStateSQL = ` const insertStateSQL = `
INSERT INTO roomserver_state_snapshots (room_nid, state_block_nids) INSERT INTO roomserver_state_snapshots (room_nid, state_block_nids)
VALUES ($1, $2);` VALUES ($1, $2)
ON CONFLICT (room_nid, state_block_nids) DO UPDATE SET room_nid=$3
RETURNING state_snapshot_nid
`
// Bulk state data NID lookup. // Bulk state data NID lookup.
// Sorting by state_snapshot_nid means we can use binary search over the result // Sorting by state_snapshot_nid means we can use binary search over the result
@ -77,15 +81,12 @@ func (s *stateSnapshotStatements) InsertState(
return return
} }
insertStmt := sqlutil.TxStmt(txn, s.insertStateStmt) insertStmt := sqlutil.TxStmt(txn, s.insertStateStmt)
res, err := insertStmt.ExecContext(ctx, int64(roomNID), string(stateBlockNIDsJSON)) var id int64
err = insertStmt.QueryRowContext(ctx, int64(roomNID), string(stateBlockNIDsJSON), int64(roomNID)).Scan(&id)
if err != nil { if err != nil {
return 0, err return 0, err
} }
lastRowID, err := res.LastInsertId() stateNID = types.StateSnapshotNID(id)
if err != nil {
return 0, err
}
stateNID = types.StateSnapshotNID(lastRowID)
return return
} }