From 5ada8872bb0a90492fbf84290da97ab9167bf61b Mon Sep 17 00:00:00 2001
From: Mark Haines <mjark@negativecurvature.net>
Date: Mon, 18 Sep 2017 14:15:17 +0100
Subject: [PATCH] Add context to the federationsender database (#231)

---
 .../federationsender/consumers/roomserver.go  |  8 +++++--
 .../storage/joined_hosts_table.go             | 22 ++++++++++++++-----
 .../federationsender/storage/room_table.go    | 21 +++++++++++++-----
 .../federationsender/storage/storage.go       | 18 +++++++++------
 4 files changed, 48 insertions(+), 21 deletions(-)

diff --git a/src/github.com/matrix-org/dendrite/federationsender/consumers/roomserver.go b/src/github.com/matrix-org/dendrite/federationsender/consumers/roomserver.go
index da19364e..629eb63c 100644
--- a/src/github.com/matrix-org/dendrite/federationsender/consumers/roomserver.go
+++ b/src/github.com/matrix-org/dendrite/federationsender/consumers/roomserver.go
@@ -123,8 +123,12 @@ func (s *OutputRoomEvent) processMessage(ore api.OutputNewRoomEvent) error {
 	// TODO: handle EventIDMismatchError and recover the current state by talking
 	// to the roomserver
 	oldJoinedHosts, err := s.db.UpdateRoom(
-		ore.Event.RoomID(), ore.LastSentEventID, ore.Event.EventID(),
-		addsJoinedHosts, ore.RemovesStateEventIDs,
+		context.TODO(),
+		ore.Event.RoomID(),
+		ore.LastSentEventID,
+		ore.Event.EventID(),
+		addsJoinedHosts,
+		ore.RemovesStateEventIDs,
 	)
 	if err != nil {
 		return err
diff --git a/src/github.com/matrix-org/dendrite/federationsender/storage/joined_hosts_table.go b/src/github.com/matrix-org/dendrite/federationsender/storage/joined_hosts_table.go
index fffcc7f3..3b9510d7 100644
--- a/src/github.com/matrix-org/dendrite/federationsender/storage/joined_hosts_table.go
+++ b/src/github.com/matrix-org/dendrite/federationsender/storage/joined_hosts_table.go
@@ -15,6 +15,7 @@
 package storage
 
 import (
+	"context"
 	"database/sql"
 
 	"github.com/lib/pq"
@@ -78,20 +79,29 @@ func (s *joinedHostsStatements) prepare(db *sql.DB) (err error) {
 }
 
 func (s *joinedHostsStatements) insertJoinedHosts(
-	txn *sql.Tx, roomID, eventID string, serverName gomatrixserverlib.ServerName,
+	ctx context.Context,
+	txn *sql.Tx,
+	roomID, eventID string,
+	serverName gomatrixserverlib.ServerName,
 ) error {
-	_, err := common.TxStmt(txn, s.insertJoinedHostsStmt).Exec(roomID, eventID, serverName)
+	stmt := common.TxStmt(txn, s.insertJoinedHostsStmt)
+	_, err := stmt.ExecContext(ctx, roomID, eventID, serverName)
 	return err
 }
 
-func (s *joinedHostsStatements) deleteJoinedHosts(txn *sql.Tx, eventIDs []string) error {
-	_, err := common.TxStmt(txn, s.deleteJoinedHostsStmt).Exec(pq.StringArray(eventIDs))
+func (s *joinedHostsStatements) deleteJoinedHosts(
+	ctx context.Context, txn *sql.Tx, eventIDs []string,
+) error {
+	stmt := common.TxStmt(txn, s.deleteJoinedHostsStmt)
+	_, err := stmt.ExecContext(ctx, pq.StringArray(eventIDs))
 	return err
 }
 
-func (s *joinedHostsStatements) selectJoinedHosts(txn *sql.Tx, roomID string,
+func (s *joinedHostsStatements) selectJoinedHosts(
+	ctx context.Context, txn *sql.Tx, roomID string,
 ) ([]types.JoinedHost, error) {
-	rows, err := common.TxStmt(txn, s.selectJoinedHostsStmt).Query(roomID)
+	stmt := common.TxStmt(txn, s.selectJoinedHostsStmt)
+	rows, err := stmt.QueryContext(ctx, roomID)
 	if err != nil {
 		return nil, err
 	}
diff --git a/src/github.com/matrix-org/dendrite/federationsender/storage/room_table.go b/src/github.com/matrix-org/dendrite/federationsender/storage/room_table.go
index bcc0bb1d..bb52b707 100644
--- a/src/github.com/matrix-org/dendrite/federationsender/storage/room_table.go
+++ b/src/github.com/matrix-org/dendrite/federationsender/storage/room_table.go
@@ -15,6 +15,7 @@
 package storage
 
 import (
+	"context"
 	"database/sql"
 
 	"github.com/matrix-org/dendrite/common"
@@ -66,17 +67,22 @@ func (s *roomStatements) prepare(db *sql.DB) (err error) {
 
 // insertRoom inserts the room if it didn't already exist.
 // If the room didn't exist then last_event_id is set to the empty string.
-func (s *roomStatements) insertRoom(txn *sql.Tx, roomID string) error {
-	_, err := common.TxStmt(txn, s.insertRoomStmt).Exec(roomID)
+func (s *roomStatements) insertRoom(
+	ctx context.Context, txn *sql.Tx, roomID string,
+) error {
+	_, err := common.TxStmt(txn, s.insertRoomStmt).ExecContext(ctx, roomID)
 	return err
 }
 
 // selectRoomForUpdate locks the row for the room and returns the last_event_id.
 // The row must already exist in the table. Callers can ensure that the row
 // exists by calling insertRoom first.
-func (s *roomStatements) selectRoomForUpdate(txn *sql.Tx, roomID string) (string, error) {
+func (s *roomStatements) selectRoomForUpdate(
+	ctx context.Context, txn *sql.Tx, roomID string,
+) (string, error) {
 	var lastEventID string
-	err := common.TxStmt(txn, s.selectRoomForUpdateStmt).QueryRow(roomID).Scan(&lastEventID)
+	stmt := common.TxStmt(txn, s.selectRoomForUpdateStmt)
+	err := stmt.QueryRowContext(ctx, roomID).Scan(&lastEventID)
 	if err != nil {
 		return "", err
 	}
@@ -85,7 +91,10 @@ func (s *roomStatements) selectRoomForUpdate(txn *sql.Tx, roomID string) (string
 
 // updateRoom updates the last_event_id for the room. selectRoomForUpdate should
 // have already been called earlier within the transaction.
-func (s *roomStatements) updateRoom(txn *sql.Tx, roomID, lastEventID string) error {
-	_, err := common.TxStmt(txn, s.updateRoomStmt).Exec(roomID, lastEventID)
+func (s *roomStatements) updateRoom(
+	ctx context.Context, txn *sql.Tx, roomID, lastEventID string,
+) error {
+	stmt := common.TxStmt(txn, s.updateRoomStmt)
+	_, err := stmt.ExecContext(ctx, roomID, lastEventID)
 	return err
 }
diff --git a/src/github.com/matrix-org/dendrite/federationsender/storage/storage.go b/src/github.com/matrix-org/dendrite/federationsender/storage/storage.go
index 87458534..aa836efb 100644
--- a/src/github.com/matrix-org/dendrite/federationsender/storage/storage.go
+++ b/src/github.com/matrix-org/dendrite/federationsender/storage/storage.go
@@ -15,6 +15,7 @@
 package storage
 
 import (
+	"context"
 	"database/sql"
 
 	"github.com/matrix-org/dendrite/common"
@@ -73,35 +74,38 @@ func (d *Database) SetPartitionOffset(topic string, partition int32, offset int6
 // UpdateRoom updates the joined hosts for a room and returns what the joined
 // hosts were before the update.
 func (d *Database) UpdateRoom(
+	ctx context.Context,
 	roomID, oldEventID, newEventID string,
 	addHosts []types.JoinedHost,
 	removeHosts []string,
 ) (joinedHosts []types.JoinedHost, err error) {
 	err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
-		if err = d.insertRoom(txn, roomID); err != nil {
+		if err = d.insertRoom(ctx, txn, roomID); err != nil {
 			return err
 		}
-		lastSentEventID, err := d.selectRoomForUpdate(txn, roomID)
+		lastSentEventID, err := d.selectRoomForUpdate(ctx, txn, roomID)
 		if err != nil {
 			return err
 		}
 		if lastSentEventID != oldEventID {
-			return types.EventIDMismatchError{lastSentEventID, oldEventID}
+			return types.EventIDMismatchError{
+				DatabaseID: lastSentEventID, RoomServerID: oldEventID,
+			}
 		}
-		joinedHosts, err = d.selectJoinedHosts(txn, roomID)
+		joinedHosts, err = d.selectJoinedHosts(ctx, txn, roomID)
 		if err != nil {
 			return err
 		}
 		for _, add := range addHosts {
-			err = d.insertJoinedHosts(txn, roomID, add.MemberEventID, add.ServerName)
+			err = d.insertJoinedHosts(ctx, txn, roomID, add.MemberEventID, add.ServerName)
 			if err != nil {
 				return err
 			}
 		}
-		if err = d.deleteJoinedHosts(txn, removeHosts); err != nil {
+		if err = d.deleteJoinedHosts(ctx, txn, removeHosts); err != nil {
 			return err
 		}
-		return d.updateRoom(txn, roomID, newEventID)
+		return d.updateRoom(ctx, txn, roomID, newEventID)
 	})
 	return
 }