diff --git a/federationapi/federationapi_keys_test.go b/federationapi/federationapi_keys_test.go index 4469da35..85cc43aa 100644 --- a/federationapi/federationapi_keys_test.go +++ b/federationapi/federationapi_keys_test.go @@ -12,12 +12,13 @@ import ( "testing" "time" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/routing" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/gomatrixserverlib" ) type server struct { @@ -86,7 +87,12 @@ func TestMain(m *testing.M) { cfg.Global.JetStream.StoragePath = config.Path(d) cfg.Global.KeyID = serverKeyID cfg.Global.KeyValidityPeriod = s.validity - cfg.FederationAPI.Database.ConnectionString = config.DataSource("file::memory:") + f, err := os.CreateTemp(d, "federation_keys_test*.db") + if err != nil { + return -1 + } + defer f.Close() + cfg.FederationAPI.Database.ConnectionString = config.DataSource("file:" + f.Name()) s.config = &cfg.FederationAPI // Create a transport which redirects federation requests to diff --git a/federationapi/federationapi_test.go b/federationapi/federationapi_test.go index 15f7a684..e923143a 100644 --- a/federationapi/federationapi_test.go +++ b/federationapi/federationapi_test.go @@ -10,6 +10,10 @@ import ( "testing" "time" + "github.com/matrix-org/gomatrix" + "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" + "github.com/matrix-org/dendrite/federationapi" "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/internal" @@ -20,9 +24,6 @@ import ( "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" - "github.com/matrix-org/gomatrix" - "github.com/matrix-org/gomatrixserverlib" - "github.com/nats-io/nats.go" ) type fedRoomserverAPI struct { @@ -271,7 +272,6 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) { cfg.Global.ServerName = gomatrixserverlib.ServerName("localhost") cfg.Global.PrivateKey = privKey cfg.Global.JetStream.InMemory = true - cfg.FederationAPI.Database.ConnectionString = config.DataSource("file::memory:") base := base.NewBaseDendrite(cfg, "Monolith") keyRing := &test.NopJSONVerifier{} // TODO: This is pretty fragile, as if anything calls anything on these nils this test will break. diff --git a/internal/sqlutil/migrate.go b/internal/sqlutil/migrate.go index b6a8b1f2..a66a7582 100644 --- a/internal/sqlutil/migrate.go +++ b/internal/sqlutil/migrate.go @@ -49,12 +49,13 @@ type Migration struct { Down func(ctx context.Context, txn *sql.Tx) error } -// Migrator +// Migrator contains fields required to run migrations. type Migrator struct { db *sql.DB migrations []Migration knownMigrations map[string]struct{} mutex *sync.Mutex + insertStmt *sql.Stmt } // NewMigrator creates a new DB migrator. @@ -82,35 +83,26 @@ func (m *Migrator) AddMigrations(migrations ...Migration) { // Up executes all migrations in order they were added. func (m *Migrator) Up(ctx context.Context) error { - var ( - err error - dendriteVersion = internal.VersionString() - ) // ensure there is a table for known migrations executedMigrations, err := m.ExecutedMigrations(ctx) if err != nil { return fmt.Errorf("unable to create/get migrations: %w", err) } - + // ensure we close the insert statement, as it's not needed anymore + defer m.close() return WithTransaction(m.db, func(txn *sql.Tx) error { for i := range m.migrations { - now := time.Now().UTC().Format(time.RFC3339) migration := m.migrations[i] // Skip migration if it was already executed if _, ok := executedMigrations[migration.Version]; ok { continue } logrus.Debugf("Executing database migration '%s'", migration.Version) - err = migration.Up(ctx, txn) - if err != nil { + + if err = migration.Up(ctx, txn); err != nil { return fmt.Errorf("unable to execute migration '%s': %w", migration.Version, err) } - _, err = txn.ExecContext(ctx, insertVersionSQL, - migration.Version, - now, - dendriteVersion, - ) - if err != nil { + if err = m.insertMigration(ctx, txn, migration.Version); err != nil { return fmt.Errorf("unable to insert executed migrations: %w", err) } } @@ -118,6 +110,23 @@ func (m *Migrator) Up(ctx context.Context) error { }) } +func (m *Migrator) insertMigration(ctx context.Context, txn *sql.Tx, migrationName string) error { + if m.insertStmt == nil { + stmt, err := m.db.Prepare(insertVersionSQL) + if err != nil { + return fmt.Errorf("unable to prepare insert statement: %w", err) + } + m.insertStmt = stmt + } + stmt := TxStmtContext(ctx, txn, m.insertStmt) + _, err := stmt.ExecContext(ctx, + migrationName, + time.Now().Format(time.RFC3339), + internal.VersionString(), + ) + return err +} + // ExecutedMigrations returns a map with already executed migrations in addition to creating the // migrations table, if it doesn't exist. func (m *Migrator) ExecutedMigrations(ctx context.Context) (map[string]struct{}, error) { @@ -146,19 +155,20 @@ func (m *Migrator) ExecutedMigrations(ctx context.Context) (map[string]struct{}, // inserts a migration given their name to the database. // This should only be used when manually inserting migrations. func InsertMigration(ctx context.Context, db *sql.DB, migrationName string) error { - _, err := db.ExecContext(ctx, createDBMigrationsSQL) + m := NewMigrator(db) + defer m.close() + existingMigrations, err := m.ExecutedMigrations(ctx) if err != nil { - return fmt.Errorf("unable to create db_migrations: %w", err) + return err } - _, err = db.ExecContext(ctx, insertVersionSQL, - migrationName, - time.Now().Format(time.RFC3339), - internal.VersionString(), - ) - // If the migration was already executed, we'll get a unique constraint error, - // return nil instead, to avoid unnecessary logging. - if IsUniqueConstraintViolationErr(err) { + if _, ok := existingMigrations[migrationName]; ok { return nil } - return err + return m.insertMigration(ctx, nil, migrationName) +} + +func (m *Migrator) close() { + if m.insertStmt != nil { + internal.CloseAndLogIfError(context.Background(), m.insertStmt, "unable to close insert statement") + } } diff --git a/internal/sqlutil/migrate_test.go b/internal/sqlutil/migrate_test.go index d8bcae19..5116237a 100644 --- a/internal/sqlutil/migrate_test.go +++ b/internal/sqlutil/migrate_test.go @@ -7,9 +7,10 @@ import ( "reflect" "testing" + _ "github.com/mattn/go-sqlite3" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/test" - _ "github.com/mattn/go-sqlite3" ) var dummyMigrations = []sqlutil.Migration{ @@ -81,11 +82,12 @@ func Test_migrations_Up(t *testing.T) { } ctx := context.Background() - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - conStr, close := test.PrepareDBConnectionString(t, dbType) - defer close() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + conStr, close := test.PrepareDBConnectionString(t, dbType) + defer close() + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { driverName := "sqlite3" if dbType == test.DBTypePostgres { driverName = "postgres" @@ -107,6 +109,30 @@ func Test_migrations_Up(t *testing.T) { t.Errorf("expected: %+v, got %v", tt.wantResult, result) } }) - }) - } + } + }) +} + +func Test_insertMigration(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + conStr, close := test.PrepareDBConnectionString(t, dbType) + defer close() + driverName := "sqlite3" + if dbType == test.DBTypePostgres { + driverName = "postgres" + } + + db, err := sql.Open(driverName, conStr) + if err != nil { + t.Errorf("unable to open database: %v", err) + } + + if err := sqlutil.InsertMigration(context.Background(), db, "testing"); err != nil { + t.Fatalf("unable to insert migration: %s", err) + } + // Second insert should not return an error, as it was already executed. + if err := sqlutil.InsertMigration(context.Background(), db, "testing"); err != nil { + t.Fatalf("unable to insert migration: %s", err) + } + }) }