diff --git a/roomserver/state/state.go b/roomserver/state/state.go index c1b21c48..352871b1 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -463,7 +463,7 @@ func (v *StateResolution) CalculateAndStoreStateBeforeEvent( // Load the state at the prev events. prevStates, err := v.db.StateAtEventIDs(ctx, event.PrevEventIDs()) if err != nil { - return 0, err + return 0, fmt.Errorf("v.db.StateAtEventIDs: %w", err) } // The state before this event will be the state after the events that came before it. diff --git a/roomserver/storage/postgres/state_table.go b/roomserver/storage/postgres/state_table.go index b1573f63..e6e285d4 100644 --- a/roomserver/storage/postgres/state_table.go +++ b/roomserver/storage/postgres/state_table.go @@ -99,6 +99,7 @@ func (s *stateStatements) BulkSelectState( if err = rows.Scan(&stateNID, &eventNIDs); err != nil { return nil, err } + results[types.StateSnapshotNID(stateNID)] = []types.EventNID{} for _, id := range eventNIDs { results[types.StateSnapshotNID(stateNID)] = append( results[types.StateSnapshotNID(stateNID)], diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 1da63846..0da8d8ee 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -215,7 +215,7 @@ func (d *Database) StateEntries( ) ([]types.StateEntry, error) { nids, err := d.StateTable.BulkSelectState(ctx, []types.StateSnapshotNID{stateSnapshotNID}) if err != nil { - return nil, fmt.Errorf("d.StateTable.BulkSelectState: %w", err) + return nil, fmt.Errorf("d.StateTable.BulkSelectState: %w (ID %d)", err, stateSnapshotNID) } state, ok := nids[stateSnapshotNID] if !ok { diff --git a/roomserver/storage/sqlite3/state_table.go b/roomserver/storage/sqlite3/state_table.go index a7c33205..48220892 100644 --- a/roomserver/storage/sqlite3/state_table.go +++ b/roomserver/storage/sqlite3/state_table.go @@ -111,16 +111,11 @@ func (s *stateStatements) BulkSelectState( if err = rows.Scan(&stateNID, &eventNIDJSON); err != nil { return nil, fmt.Errorf("rows.Scan: %w", err) } - var eventNIDs []int64 + var eventNIDs []types.EventNID if err = json.Unmarshal(eventNIDJSON, &eventNIDs); err != nil { return nil, fmt.Errorf("json.Unmarshal: %w", err) } - for _, id := range eventNIDs { - results[types.StateSnapshotNID(stateNID)] = append( - results[types.StateSnapshotNID(stateNID)], - types.EventNID(id), - ) - } + results[types.StateSnapshotNID(stateNID)] = eventNIDs } if err = rows.Err(); err != nil { return nil, fmt.Errorf("rows.Err: %w", err)