Update database migrations, remove goose (#2264)

* Add new db migration

* Update migrations
Remove goose

* Add possibility to test direct upgrades

* Try to fix WASM test

* Add checks for specific migrations

* Remove AddMigration
Use WithTransaction
Add Dendrite version to table

* Fix linter issues

* Update tests

* Update comments, outdent if

* Namespace migrations

* Add direct upgrade tests, skipping over one version

* Split migrations

* Update go version in CI

* Fix copy&paste mistake

* Use contexts in migrations

Co-authored-by: kegsay <kegan@matrix.org>
Co-authored-by: Neil Alexander <neilalexander@users.noreply.github.com>
This commit is contained in:
Till 2022-07-25 11:39:22 +02:00 committed by GitHub
parent c7d978274d
commit 081f5e7226
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
58 changed files with 734 additions and 839 deletions

View file

@ -1,130 +1,142 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlutil
import (
"context"
"database/sql"
"fmt"
"runtime"
"sort"
"sync"
"time"
"github.com/matrix-org/dendrite/setup/config"
"github.com/pressly/goose"
"github.com/matrix-org/dendrite/internal"
"github.com/sirupsen/logrus"
)
type Migrations struct {
registeredGoMigrations map[int64]*goose.Migration
const createDBMigrationsSQL = "" +
"CREATE TABLE IF NOT EXISTS db_migrations (" +
" version TEXT PRIMARY KEY NOT NULL," +
" time TEXT NOT NULL," +
" dendrite_version TEXT NOT NULL" +
");"
const insertVersionSQL = "" +
"INSERT INTO db_migrations (version, time, dendrite_version)" +
" VALUES ($1, $2, $3)"
const selectDBMigrationsSQL = "SELECT version FROM db_migrations"
// Migration defines a migration to be run.
type Migration struct {
// Version is a simple description/name of this migration.
Version string
// Up defines the function to execute for an upgrade.
Up func(ctx context.Context, txn *sql.Tx) error
// Down defines the function to execute for a downgrade (not implemented yet).
Down func(ctx context.Context, txn *sql.Tx) error
}
func NewMigrations() *Migrations {
return &Migrations{
registeredGoMigrations: make(map[int64]*goose.Migration),
// Migrator
type Migrator struct {
db *sql.DB
migrations []Migration
knownMigrations map[string]struct{}
mutex *sync.Mutex
}
// NewMigrator creates a new DB migrator.
func NewMigrator(db *sql.DB) *Migrator {
return &Migrator{
db: db,
migrations: []Migration{},
knownMigrations: make(map[string]struct{}),
mutex: &sync.Mutex{},
}
}
// Copy-pasted from goose directly to store migrations into a map we control
// AddMigration adds a migration.
func (m *Migrations) AddMigration(up func(*sql.Tx) error, down func(*sql.Tx) error) {
_, filename, _, _ := runtime.Caller(1)
m.AddNamedMigration(filename, up, down)
}
// AddNamedMigration : Add a named migration.
func (m *Migrations) AddNamedMigration(filename string, up func(*sql.Tx) error, down func(*sql.Tx) error) {
v, _ := goose.NumericComponent(filename)
migration := &goose.Migration{Version: v, Next: -1, Previous: -1, Registered: true, UpFn: up, DownFn: down, Source: filename}
if existing, ok := m.registeredGoMigrations[v]; ok {
panic(fmt.Sprintf("failed to add migration %q: version conflicts with %q", filename, existing.Source))
// AddMigrations appends migrations to the list of migrations. Migrations are executed
// in the order they are added to the list. De-duplicates migrations using their Version field.
func (m *Migrator) AddMigrations(migrations ...Migration) {
m.mutex.Lock()
defer m.mutex.Unlock()
for _, mig := range migrations {
if _, ok := m.knownMigrations[mig.Version]; !ok {
m.migrations = append(m.migrations, mig)
m.knownMigrations[mig.Version] = struct{}{}
}
}
m.registeredGoMigrations[v] = migration
}
// RunDeltas up to the latest version.
func (m *Migrations) RunDeltas(db *sql.DB, props *config.DatabaseOptions) error {
maxVer := goose.MaxVersion
minVer := int64(0)
migrations, err := m.collect(minVer, maxVer)
// Up executes all migrations in order they were added.
func (m *Migrator) Up(ctx context.Context) error {
var (
err error
dendriteVersion = internal.VersionString()
)
// ensure there is a table for known migrations
executedMigrations, err := m.ExecutedMigrations(ctx)
if err != nil {
return fmt.Errorf("runDeltas: Failed to collect migrations: %w", err)
return fmt.Errorf("unable to create/get migrations: %w", err)
}
if props.ConnectionString.IsPostgres() {
if err = goose.SetDialect("postgres"); err != nil {
return err
}
} else if props.ConnectionString.IsSQLite() {
if err = goose.SetDialect("sqlite3"); err != nil {
return err
}
} else {
return fmt.Errorf("unknown connection string: %s", props.ConnectionString)
}
for {
current, err := goose.EnsureDBVersion(db)
if err != nil {
return fmt.Errorf("runDeltas: Failed to EnsureDBVersion: %w", err)
}
next, err := migrations.Next(current)
if err != nil {
if err == goose.ErrNoNextVersion {
return nil
return WithTransaction(m.db, func(txn *sql.Tx) error {
for i := range m.migrations {
now := time.Now().UTC().Format(time.RFC3339)
migration := m.migrations[i]
logrus.Debugf("Executing database migration '%s'", migration.Version)
// Skip migration if it was already executed
if _, ok := executedMigrations[migration.Version]; ok {
continue
}
err = migration.Up(ctx, txn)
if err != nil {
return fmt.Errorf("unable to execute migration '%s': %w", migration.Version, err)
}
_, err = txn.ExecContext(ctx, insertVersionSQL,
migration.Version,
now,
dendriteVersion,
)
if err != nil {
return fmt.Errorf("unable to insert executed migrations: %w", err)
}
return fmt.Errorf("runDeltas: Failed to load next migration to %+v : %w", next, err)
}
if err = next.Up(db); err != nil {
return fmt.Errorf("runDeltas: Failed run migration: %w", err)
}
}
return nil
})
}
func (m *Migrations) collect(current, target int64) (goose.Migrations, error) {
var migrations goose.Migrations
// Go migrations registered via goose.AddMigration().
for _, migration := range m.registeredGoMigrations {
v, err := goose.NumericComponent(migration.Source)
if err != nil {
return nil, err
}
if versionFilter(v, current, target) {
migrations = append(migrations, migration)
// ExecutedMigrations returns a map with already executed migrations in addition to creating the
// migrations table, if it doesn't exist.
func (m *Migrator) ExecutedMigrations(ctx context.Context) (map[string]struct{}, error) {
result := make(map[string]struct{})
_, err := m.db.ExecContext(ctx, createDBMigrationsSQL)
if err != nil {
return nil, fmt.Errorf("unable to create db_migrations: %w", err)
}
rows, err := m.db.QueryContext(ctx, selectDBMigrationsSQL)
if err != nil {
return nil, fmt.Errorf("unable to query db_migrations: %w", err)
}
defer internal.CloseAndLogIfError(ctx, rows, "ExecutedMigrations: rows.close() failed")
var version string
for rows.Next() {
if err = rows.Scan(&version); err != nil {
return nil, fmt.Errorf("unable to scan version: %w", err)
}
result[version] = struct{}{}
}
migrations = sortAndConnectMigrations(migrations)
return migrations, nil
}
func sortAndConnectMigrations(migrations goose.Migrations) goose.Migrations {
sort.Sort(migrations)
// now that we're sorted in the appropriate direction,
// populate next and previous for each migration
for i, m := range migrations {
prev := int64(-1)
if i > 0 {
prev = migrations[i-1].Version
migrations[i-1].Next = m.Version
}
migrations[i].Previous = prev
}
return migrations
}
func versionFilter(v, current, target int64) bool {
if target > current {
return v > current && v <= target
}
if target < current {
return v <= current && v > target
}
return false
return result, rows.Err()
}

View file

@ -0,0 +1,112 @@
package sqlutil_test
import (
"context"
"database/sql"
"fmt"
"reflect"
"testing"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/test"
_ "github.com/mattn/go-sqlite3"
)
var dummyMigrations = []sqlutil.Migration{
{
Version: "init",
Up: func(ctx context.Context, txn *sql.Tx) error {
_, err := txn.ExecContext(ctx, "CREATE TABLE IF NOT EXISTS dummy ( test TEXT );")
return err
},
},
{
Version: "v2",
Up: func(ctx context.Context, txn *sql.Tx) error {
_, err := txn.ExecContext(ctx, "ALTER TABLE dummy ADD COLUMN test2 TEXT;")
return err
},
},
{
Version: "v2", // duplicate, this migration will be skipped
Up: func(ctx context.Context, txn *sql.Tx) error {
_, err := txn.ExecContext(ctx, "ALTER TABLE dummy ADD COLUMN test2 TEXT;")
return err
},
},
{
Version: "multiple execs",
Up: func(ctx context.Context, txn *sql.Tx) error {
_, err := txn.ExecContext(ctx, "ALTER TABLE dummy ADD COLUMN test3 TEXT;")
if err != nil {
return err
}
_, err = txn.ExecContext(ctx, "ALTER TABLE dummy ADD COLUMN test4 TEXT;")
return err
},
},
}
var failMigration = sqlutil.Migration{
Version: "iFail",
Up: func(ctx context.Context, txn *sql.Tx) error {
return fmt.Errorf("iFail")
},
Down: nil,
}
func Test_migrations_Up(t *testing.T) {
withFail := append(dummyMigrations, failMigration)
tests := []struct {
name string
migrations []sqlutil.Migration
wantResult map[string]struct{}
wantErr bool
}{
{
name: "dummy migration",
migrations: dummyMigrations,
wantResult: map[string]struct{}{
"init": {},
"v2": {},
"multiple execs": {},
},
},
{
name: "with fail",
migrations: withFail,
wantErr: true,
},
}
ctx := context.Background()
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
conStr, close := test.PrepareDBConnectionString(t, dbType)
defer close()
driverName := "sqlite3"
if dbType == test.DBTypePostgres {
driverName = "postgres"
}
db, err := sql.Open(driverName, conStr)
if err != nil {
t.Errorf("unable to open database: %v", err)
}
m := sqlutil.NewMigrator(db)
m.AddMigrations(tt.migrations...)
if err = m.Up(ctx); (err != nil) != tt.wantErr {
t.Errorf("Up() error = %v, wantErr %v", err, tt.wantErr)
}
result, err := m.ExecutedMigrations(ctx)
if err != nil {
t.Errorf("unable to get executed migrations: %v", err)
}
if !tt.wantErr && !reflect.DeepEqual(result, tt.wantResult) {
t.Errorf("expected: %+v, got %v", tt.wantResult, result)
}
})
})
}
}