refactor: update GMSL (#3058)

Sister PR to https://github.com/matrix-org/gomatrixserverlib/pull/364

Read this commit by commit to avoid going insane.
This commit is contained in:
kegsay 2023-04-19 15:50:33 +01:00 committed by GitHub
parent 9fa39263c0
commit 72285b2659
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
306 changed files with 2117 additions and 1934 deletions

View file

@ -19,6 +19,7 @@ import (
"github.com/matrix-org/dendrite/federationapi/storage/shared/receipt"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
)
type Database interface {
@ -29,19 +30,19 @@ type Database interface {
// Adds a new transaction_id: server_name mapping with associated json table nid to the queue
// entry table for each provided destination.
AssociateTransactionWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.UserID]struct{}, transactionID gomatrixserverlib.TransactionID, dbReceipt *receipt.Receipt) error
AssociateTransactionWithDestinations(ctx context.Context, destinations map[spec.UserID]struct{}, transactionID gomatrixserverlib.TransactionID, dbReceipt *receipt.Receipt) error
// Removes every server_name: receipt pair provided from the queue entries table.
// Will then remove every entry for each receipt provided from the queue json table.
// If any of the entries don't exist in either table, nothing will happen for that entry and
// an error will not be generated.
CleanTransactions(ctx context.Context, userID gomatrixserverlib.UserID, receipts []*receipt.Receipt) error
CleanTransactions(ctx context.Context, userID spec.UserID, receipts []*receipt.Receipt) error
// Gets the oldest transaction for the provided server_name.
// If no transactions exist, returns nil and no error.
GetTransaction(ctx context.Context, userID gomatrixserverlib.UserID) (*gomatrixserverlib.Transaction, *receipt.Receipt, error)
GetTransaction(ctx context.Context, userID spec.UserID) (*gomatrixserverlib.Transaction, *receipt.Receipt, error)
// Gets the number of transactions being stored for the provided server_name.
// If the server doesn't exist in the database then 0 is returned with no error.
GetTransactionCount(ctx context.Context, userID gomatrixserverlib.UserID) (int64, error)
GetTransactionCount(ctx context.Context, userID spec.UserID) (int64, error)
}

View file

@ -22,6 +22,7 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
)
const relayQueueSchema = `
@ -90,7 +91,7 @@ func (s *relayQueueStatements) InsertQueueEntry(
ctx context.Context,
txn *sql.Tx,
transactionID gomatrixserverlib.TransactionID,
serverName gomatrixserverlib.ServerName,
serverName spec.ServerName,
nid int64,
) error {
stmt := sqlutil.TxStmt(txn, s.insertQueueEntryStmt)
@ -106,7 +107,7 @@ func (s *relayQueueStatements) InsertQueueEntry(
func (s *relayQueueStatements) DeleteQueueEntries(
ctx context.Context,
txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
serverName spec.ServerName,
jsonNIDs []int64,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteQueueEntriesStmt)
@ -117,7 +118,7 @@ func (s *relayQueueStatements) DeleteQueueEntries(
func (s *relayQueueStatements) SelectQueueEntries(
ctx context.Context,
txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
serverName spec.ServerName,
limit int,
) ([]int64, error) {
stmt := sqlutil.TxStmt(txn, s.selectQueueEntriesStmt)
@ -141,7 +142,7 @@ func (s *relayQueueStatements) SelectQueueEntries(
func (s *relayQueueStatements) SelectQueueEntryCount(
ctx context.Context,
txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
serverName spec.ServerName,
) (int64, error) {
var count int64
stmt := sqlutil.TxStmt(txn, s.selectQueueEntryCountStmt)

View file

@ -21,7 +21,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/relayapi/storage/shared"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
)
// Database stores information needed by the relayapi
@ -36,7 +36,7 @@ func NewDatabase(
conMan sqlutil.Connections,
dbProperties *config.DatabaseOptions,
cache caching.FederationCache,
isLocalServerName func(gomatrixserverlib.ServerName) bool,
isLocalServerName func(spec.ServerName) bool,
) (*Database, error) {
var d Database
var err error

View file

@ -25,11 +25,12 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/relayapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
)
type Database struct {
DB *sql.DB
IsLocalServerName func(gomatrixserverlib.ServerName) bool
IsLocalServerName func(spec.ServerName) bool
Cache caching.FederationCache
Writer sqlutil.Writer
RelayQueue tables.RelayQueue
@ -61,7 +62,7 @@ func (d *Database) StoreTransaction(
func (d *Database) AssociateTransactionWithDestinations(
ctx context.Context,
destinations map[gomatrixserverlib.UserID]struct{},
destinations map[spec.UserID]struct{},
transactionID gomatrixserverlib.TransactionID,
dbReceipt *receipt.Receipt,
) error {
@ -88,7 +89,7 @@ func (d *Database) AssociateTransactionWithDestinations(
func (d *Database) CleanTransactions(
ctx context.Context,
userID gomatrixserverlib.UserID,
userID spec.UserID,
receipts []*receipt.Receipt,
) error {
nids := make([]int64, len(receipts))
@ -123,7 +124,7 @@ func (d *Database) CleanTransactions(
func (d *Database) GetTransaction(
ctx context.Context,
userID gomatrixserverlib.UserID,
userID spec.UserID,
) (*gomatrixserverlib.Transaction, *receipt.Receipt, error) {
entriesRequested := 1
nids, err := d.RelayQueue.SelectQueueEntries(ctx, nil, userID.Domain(), entriesRequested)
@ -160,7 +161,7 @@ func (d *Database) GetTransaction(
func (d *Database) GetTransactionCount(
ctx context.Context,
userID gomatrixserverlib.UserID,
userID spec.UserID,
) (int64, error) {
count, err := d.RelayQueue.SelectQueueEntryCount(ctx, nil, userID.Domain())
if err != nil {

View file

@ -23,6 +23,7 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
)
const relayQueueSchema = `
@ -90,7 +91,7 @@ func (s *relayQueueStatements) InsertQueueEntry(
ctx context.Context,
txn *sql.Tx,
transactionID gomatrixserverlib.TransactionID,
serverName gomatrixserverlib.ServerName,
serverName spec.ServerName,
nid int64,
) error {
stmt := sqlutil.TxStmt(txn, s.insertQueueEntryStmt)
@ -106,7 +107,7 @@ func (s *relayQueueStatements) InsertQueueEntry(
func (s *relayQueueStatements) DeleteQueueEntries(
ctx context.Context,
txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
serverName spec.ServerName,
jsonNIDs []int64,
) error {
deleteSQL := strings.Replace(deleteQueueEntriesSQL, "($2)", sqlutil.QueryVariadicOffset(len(jsonNIDs), 1), 1)
@ -129,7 +130,7 @@ func (s *relayQueueStatements) DeleteQueueEntries(
func (s *relayQueueStatements) SelectQueueEntries(
ctx context.Context,
txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
serverName spec.ServerName,
limit int,
) ([]int64, error) {
stmt := sqlutil.TxStmt(txn, s.selectQueueEntriesStmt)
@ -153,7 +154,7 @@ func (s *relayQueueStatements) SelectQueueEntries(
func (s *relayQueueStatements) SelectQueueEntryCount(
ctx context.Context,
txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
serverName spec.ServerName,
) (int64, error) {
var count int64
stmt := sqlutil.TxStmt(txn, s.selectQueueEntryCountStmt)

View file

@ -21,7 +21,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/relayapi/storage/shared"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
)
// Database stores information needed by the federation sender
@ -36,7 +36,7 @@ func NewDatabase(
conMan sqlutil.Connections,
dbProperties *config.DatabaseOptions,
cache caching.FederationCache,
isLocalServerName func(gomatrixserverlib.ServerName) bool,
isLocalServerName func(spec.ServerName) bool,
) (*Database, error) {
var d Database
var err error

View file

@ -25,7 +25,7 @@ import (
"github.com/matrix-org/dendrite/relayapi/storage/postgres"
"github.com/matrix-org/dendrite/relayapi/storage/sqlite3"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
)
// NewDatabase opens a new database
@ -33,7 +33,7 @@ func NewDatabase(
conMan sqlutil.Connections,
dbProperties *config.DatabaseOptions,
cache caching.FederationCache,
isLocalServerName func(gomatrixserverlib.ServerName) bool,
isLocalServerName func(spec.ServerName) bool,
) (Database, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():

View file

@ -29,7 +29,7 @@ func NewDatabase(
conMan sqlutil.Connections,
dbProperties *config.DatabaseOptions,
cache caching.FederationCache,
isLocalServerName func(gomatrixserverlib.ServerName) bool,
isLocalServerName func(spec.ServerName) bool,
) (Database, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():

View file

@ -19,6 +19,7 @@ import (
"database/sql"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
)
// RelayQueue table contains a mapping of server name to transaction id and the corresponding nid.
@ -28,21 +29,21 @@ type RelayQueue interface {
// Adds a new transaction_id: server_name mapping with associated json table nid to the table.
// Will ensure only one transaction id is present for each server_name: nid mapping.
// Adding duplicates will silently do nothing.
InsertQueueEntry(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error
InsertQueueEntry(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName spec.ServerName, nid int64) error
// Removes multiple entries from the table corresponding the the list of nids provided.
// If any of the provided nids don't match a row in the table, that deletion is considered
// successful.
DeleteQueueEntries(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error
DeleteQueueEntries(ctx context.Context, txn *sql.Tx, serverName spec.ServerName, jsonNIDs []int64) error
// Get a list of nids associated with the provided server name.
// Returns up to `limit` nids. The entries are returned oldest first.
// Will return an empty list if no matches were found.
SelectQueueEntries(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error)
SelectQueueEntries(ctx context.Context, txn *sql.Tx, serverName spec.ServerName, limit int) ([]int64, error)
// Get the number of entries in the table associated with the provided server name.
// If there are no matching rows, a count of 0 is returned with err set to nil.
SelectQueueEntryCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error)
SelectQueueEntryCount(ctx context.Context, txn *sql.Tx, serverName spec.ServerName) (int64, error)
}
// RelayQueueJSON table contains a map of nid to the raw transaction json.

View file

@ -27,11 +27,12 @@ import (
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
"github.com/stretchr/testify/assert"
)
const (
testOrigin = gomatrixserverlib.ServerName("kaer.morhen")
testOrigin = spec.ServerName("kaer.morhen")
)
func mustCreateTransaction() gomatrixserverlib.Transaction {

View file

@ -28,6 +28,7 @@ import (
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
"github.com/stretchr/testify/assert"
)
@ -73,7 +74,7 @@ func TestShoudInsertQueueTransaction(t *testing.T) {
defer close()
transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
serverName := gomatrixserverlib.ServerName("domain")
serverName := spec.ServerName("domain")
nid := int64(1)
err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid)
if err != nil {
@ -89,7 +90,7 @@ func TestShouldRetrieveInsertedQueueTransaction(t *testing.T) {
defer close()
transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
serverName := gomatrixserverlib.ServerName("domain")
serverName := spec.ServerName("domain")
nid := int64(1)
err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid)
@ -114,7 +115,7 @@ func TestShouldRetrieveOldestInsertedQueueTransaction(t *testing.T) {
defer close()
transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
serverName := gomatrixserverlib.ServerName("domain")
serverName := spec.ServerName("domain")
nid := int64(2)
err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid)
if err != nil {
@ -122,7 +123,7 @@ func TestShouldRetrieveOldestInsertedQueueTransaction(t *testing.T) {
}
transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
serverName = gomatrixserverlib.ServerName("domain")
serverName = spec.ServerName("domain")
oldestNID := int64(1)
err = db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, oldestNID)
if err != nil {
@ -155,7 +156,7 @@ func TestShouldDeleteQueueTransaction(t *testing.T) {
defer close()
transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
serverName := gomatrixserverlib.ServerName("domain")
serverName := spec.ServerName("domain")
nid := int64(1)
err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid)
@ -186,10 +187,10 @@ func TestShouldDeleteOnlySpecifiedQueueTransaction(t *testing.T) {
defer close()
transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
serverName := gomatrixserverlib.ServerName("domain")
serverName := spec.ServerName("domain")
nid := int64(1)
transactionID2 := gomatrixserverlib.TransactionID(fmt.Sprintf("%d2", time.Now().UnixNano()))
serverName2 := gomatrixserverlib.ServerName("domain2")
serverName2 := spec.ServerName("domain2")
nid2 := int64(2)
transactionID3 := gomatrixserverlib.TransactionID(fmt.Sprintf("%d3", time.Now().UnixNano()))