Always defer *sql.Rows.Close and consult with Err (#844)

* Always defer *sql.Rows.Close and consult with Err

database/sql.Rows.Next() makes sure to call Close only after exhausting
result rows which would NOT happen when returning early from a bad Scan.
Close being idempotent makes it a great candidate to get always deferred
regardless of what happens later on the result set.

This change also makes sure call Err() after exhausting Next() and
propagate non-nil results from it as the documentation advises.

Closes #764

Signed-off-by: Kiril Vladimiroff <kiril@vladimiroff.org>

* Override named result parameters in last returns

Signed-off-by: Kiril Vladimiroff <kiril@vladimiroff.org>

* Do the same over new changes that got merged

Signed-off-by: Kiril Vladimiroff <kiril@vladimiroff.org>

Co-authored-by: Neil Alexander <neilalexander@users.noreply.github.com>
This commit is contained in:
Kiril Vladimiroff 2020-02-11 16:12:21 +02:00 committed by GitHub
parent d45f869cdd
commit d5dbe546e4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
24 changed files with 81 additions and 49 deletions

View file

@ -102,5 +102,5 @@ func (s *eventJSONStatements) bulkSelectEventJSON(
}
result.EventNID = types.EventNID(eventNID)
}
return results[:i], nil
return results[:i], rows.Err()
}

View file

@ -125,7 +125,7 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID(
}
result[stateKey] = types.EventStateKeyNID(stateKeyNID)
}
return result, nil
return result, rows.Err()
}
func (s *eventStateKeyStatements) bulkSelectEventStateKey(
@ -150,5 +150,5 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKey(
}
result[types.EventStateKeyNID(stateKeyNID)] = stateKey
}
return result, nil
return result, rows.Err()
}

View file

@ -143,5 +143,5 @@ func (s *eventTypeStatements) bulkSelectEventTypeNID(
}
result[eventType] = types.EventTypeNID(eventTypeNID)
}
return result, nil
return result, rows.Err()
}

View file

@ -209,6 +209,9 @@ func (s *eventStatements) bulkSelectStateEventByID(
return nil, err
}
}
if err = rows.Err(); err != nil {
return nil, err
}
if i != len(eventIDs) {
// If there are fewer rows returned than IDs then we were asked to lookup event IDs we don't have.
// We don't know which ones were missing because we don't return the string IDs in the query.
@ -219,7 +222,7 @@ func (s *eventStatements) bulkSelectStateEventByID(
fmt.Sprintf("storage: state event IDs missing from the database (%d != %d)", i, len(eventIDs)),
)
}
return results, err
return results, nil
}
// bulkSelectStateAtEventByID lookups the state at a list of events by event ID.
@ -251,12 +254,15 @@ func (s *eventStatements) bulkSelectStateAtEventByID(
)
}
}
if err = rows.Err(); err != nil {
return nil, err
}
if i != len(eventIDs) {
return nil, types.MissingEventError(
fmt.Sprintf("storage: event IDs missing from the database (%d != %d)", i, len(eventIDs)),
)
}
return results, err
return results, nil
}
func (s *eventStatements) updateEventState(
@ -321,6 +327,9 @@ func (s *eventStatements) bulkSelectStateAtEventAndReference(
result.EventID = eventID
result.EventSHA256 = eventSHA256
}
if err = rows.Err(); err != nil {
return nil, err
}
if i != len(eventNIDs) {
return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs))
}
@ -343,6 +352,9 @@ func (s *eventStatements) bulkSelectEventReference(
return nil, err
}
}
if err = rows.Err(); err != nil {
return nil, err
}
if i != len(eventNIDs) {
return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs))
}
@ -366,6 +378,9 @@ func (s *eventStatements) bulkSelectEventID(ctx context.Context, eventNIDs []typ
}
results[types.EventNID(eventNID)] = eventID
}
if err = rows.Err(); err != nil {
return nil, err
}
if i != len(eventNIDs) {
return nil, fmt.Errorf("storage: event NIDs missing from the database (%d != %d)", i, len(eventNIDs))
}
@ -389,7 +404,7 @@ func (s *eventStatements) bulkSelectEventNID(ctx context.Context, eventIDs []str
}
results[eventID] = types.EventNID(eventNID)
}
return results, nil
return results, rows.Err()
}
func (s *eventStatements) selectMaxEventDepth(ctx context.Context, eventNIDs []types.EventNID) (int64, error) {

View file

@ -114,21 +114,23 @@ func (s *inviteStatements) insertInviteEvent(
func (s *inviteStatements) updateInviteRetired(
ctx context.Context,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (eventIDs []string, err error) {
) ([]string, error) {
stmt := common.TxStmt(txn, s.updateInviteRetiredStmt)
rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID)
if err != nil {
return nil, err
}
defer (func() { err = rows.Close() })()
defer rows.Close() // nolint: errcheck
var eventIDs []string
for rows.Next() {
var inviteEventID string
if err := rows.Scan(&inviteEventID); err != nil {
if err = rows.Scan(&inviteEventID); err != nil {
return nil, err
}
eventIDs = append(eventIDs, inviteEventID)
}
return
return eventIDs, rows.Err()
}
// selectInviteActiveForUserInRoom returns a list of sender state key NIDs
@ -151,5 +153,5 @@ func (s *inviteStatements) selectInviteActiveForUserInRoom(
}
result = append(result, types.EventStateKeyNID(senderUserNID))
}
return result, nil
return result, rows.Err()
}

View file

@ -151,6 +151,7 @@ func (s *membershipStatements) selectMembershipsFromRoom(
if err != nil {
return
}
defer rows.Close() // nolint: errcheck
for rows.Next() {
var eNID types.EventNID
@ -159,8 +160,9 @@ func (s *membershipStatements) selectMembershipsFromRoom(
}
eventNIDs = append(eventNIDs, eNID)
}
return
return eventNIDs, rows.Err()
}
func (s *membershipStatements) selectMembershipsFromRoomAndMembership(
ctx context.Context,
roomNID types.RoomNID, membership membershipState,
@ -170,6 +172,7 @@ func (s *membershipStatements) selectMembershipsFromRoomAndMembership(
if err != nil {
return
}
defer rows.Close() // nolint: errcheck
for rows.Next() {
var eNID types.EventNID
@ -178,7 +181,7 @@ func (s *membershipStatements) selectMembershipsFromRoomAndMembership(
}
eventNIDs = append(eventNIDs, eNID)
}
return
return eventNIDs, rows.Err()
}
func (s *membershipStatements) updateMembership(

View file

@ -90,23 +90,23 @@ func (s *roomAliasesStatements) selectRoomIDFromAlias(
func (s *roomAliasesStatements) selectAliasesFromRoomID(
ctx context.Context, roomID string,
) (aliases []string, err error) {
aliases = []string{}
) ([]string, error) {
rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID)
if err != nil {
return
return nil, err
}
defer rows.Close() // nolint: errcheck
var aliases []string
for rows.Next() {
var alias string
if err = rows.Scan(&alias); err != nil {
return
return nil, err
}
aliases = append(aliases, alias)
}
return
return aliases, rows.Err()
}
func (s *roomAliasesStatements) selectCreatorIDFromAlias(

View file

@ -152,7 +152,7 @@ func (s *stateBlockStatements) bulkSelectStateBlockEntries(
eventNID int64
entry types.StateEntry
)
if err := rows.Scan(
if err = rows.Scan(
&stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID,
); err != nil {
return nil, err
@ -169,10 +169,13 @@ func (s *stateBlockStatements) bulkSelectStateBlockEntries(
}
current.StateEntries = append(current.StateEntries, entry)
}
if err = rows.Err(); err != nil {
return nil, err
}
if i != len(stateBlockNIDs) {
return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", i, len(stateBlockNIDs))
}
return results, nil
return results, err
}
func (s *stateBlockStatements) bulkSelectFilteredStateBlockEntries(
@ -237,7 +240,7 @@ func (s *stateBlockStatements) bulkSelectFilteredStateBlockEntries(
if current.StateEntries != nil {
results = append(results, current)
}
return results, nil
return results, rows.Err()
}
func stateBlockNIDsAsArray(stateBlockNIDs []types.StateBlockNID) pq.Int64Array {

View file

@ -104,7 +104,7 @@ func (s *stateSnapshotStatements) bulkSelectStateBlockNIDs(
for ; rows.Next(); i++ {
result := &results[i]
var stateBlockNIDs pq.Int64Array
if err := rows.Scan(&result.StateSnapshotNID, &stateBlockNIDs); err != nil {
if err = rows.Scan(&result.StateSnapshotNID, &stateBlockNIDs); err != nil {
return nil, err
}
result.StateBlockNIDs = make([]types.StateBlockNID, len(stateBlockNIDs))
@ -112,6 +112,9 @@ func (s *stateSnapshotStatements) bulkSelectStateBlockNIDs(
result.StateBlockNIDs[k] = types.StateBlockNID(stateBlockNIDs[k])
}
}
if err = rows.Err(); err != nil {
return nil, err
}
if i != len(stateNIDs) {
return nil, fmt.Errorf("storage: state NIDs missing from the database (%d != %d)", i, len(stateNIDs))
}