Update database migrations, remove goose (#2264)

* Add new db migration

* Update migrations
Remove goose

* Add possibility to test direct upgrades

* Try to fix WASM test

* Add checks for specific migrations

* Remove AddMigration
Use WithTransaction
Add Dendrite version to table

* Fix linter issues

* Update tests

* Update comments, outdent if

* Namespace migrations

* Add direct upgrade tests, skipping over one version

* Split migrations

* Update go version in CI

* Fix copy&paste mistake

* Use contexts in migrations

Co-authored-by: kegsay <kegan@matrix.org>
Co-authored-by: Neil Alexander <neilalexander@users.noreply.github.com>
This commit is contained in:
Till 2022-07-25 11:39:22 +02:00 committed by GitHub
parent c7d978274d
commit 081f5e7226
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
58 changed files with 734 additions and 839 deletions

View file

@ -223,6 +223,31 @@ jobs:
- name: Test upgrade - name: Test upgrade
run: ./dendrite-upgrade-tests --head . run: ./dendrite-upgrade-tests --head .
# run database upgrade tests, skipping over one version
upgrade_test_direct:
name: Upgrade tests from HEAD-2
timeout-minutes: 20
needs: initial-tests-done
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Setup go
uses: actions/setup-go@v2
with:
go-version: "1.18"
- uses: actions/cache@v3
with:
path: |
~/.cache/go-build
~/go/pkg/mod
key: ${{ runner.os }}-go-upgrade-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-upgrade
- name: Build upgrade-tests
run: go build ./cmd/dendrite-upgrade-tests
- name: Test upgrade
run: ./dendrite-upgrade-tests -direct -from HEAD-2 --head .
# run Sytest in different variations # run Sytest in different variations
sytest: sytest:
timeout-minutes: 20 timeout-minutes: 20
@ -359,7 +384,7 @@ jobs:
integration-tests-done: integration-tests-done:
name: Integration tests passed name: Integration tests passed
needs: [initial-tests-done, upgrade_test, sytest, complement] needs: [initial-tests-done, upgrade_test, upgrade_test_direct, sytest, complement]
runs-on: ubuntu-latest runs-on: ubuntu-latest
if: ${{ !cancelled() }} # Run this even if prior jobs were skipped if: ${{ !cancelled() }} # Run this even if prior jobs were skipped
steps: steps:

View file

@ -37,6 +37,7 @@ var (
flagBuildConcurrency = flag.Int("build-concurrency", runtime.NumCPU(), "The amount of build concurrency when building images") flagBuildConcurrency = flag.Int("build-concurrency", runtime.NumCPU(), "The amount of build concurrency when building images")
flagHead = flag.String("head", "", "Location to a dendrite repository to treat as HEAD instead of Github") flagHead = flag.String("head", "", "Location to a dendrite repository to treat as HEAD instead of Github")
flagDockerHost = flag.String("docker-host", "localhost", "The hostname of the docker client. 'localhost' if running locally, 'host.docker.internal' if running in Docker.") flagDockerHost = flag.String("docker-host", "localhost", "The hostname of the docker client. 'localhost' if running locally, 'host.docker.internal' if running in Docker.")
flagDirect = flag.Bool("direct", false, "If a direct upgrade from the defined FROM version to TO should be done")
alphaNumerics = regexp.MustCompile("[^a-zA-Z0-9]+") alphaNumerics = regexp.MustCompile("[^a-zA-Z0-9]+")
) )
@ -229,7 +230,7 @@ func getAndSortVersionsFromGithub(httpClient *http.Client) (semVers []*semver.Ve
return semVers, nil return semVers, nil
} }
func calculateVersions(cli *http.Client, from, to string) []string { func calculateVersions(cli *http.Client, from, to string, direct bool) []string {
semvers, err := getAndSortVersionsFromGithub(cli) semvers, err := getAndSortVersionsFromGithub(cli)
if err != nil { if err != nil {
log.Fatalf("failed to collect semvers from github: %s", err) log.Fatalf("failed to collect semvers from github: %s", err)
@ -284,6 +285,9 @@ func calculateVersions(cli *http.Client, from, to string) []string {
if to == HEAD { if to == HEAD {
versions = append(versions, HEAD) versions = append(versions, HEAD)
} }
if direct {
versions = []string{versions[0], versions[len(versions)-1]}
}
return versions return versions
} }
@ -461,7 +465,7 @@ func main() {
os.Exit(1) os.Exit(1)
} }
cleanup(dockerClient) cleanup(dockerClient)
versions := calculateVersions(httpClient, *flagFrom, *flagTo) versions := calculateVersions(httpClient, *flagFrom, *flagTo, *flagDirect)
log.Printf("Testing dendrite versions: %v\n", versions) log.Printf("Testing dendrite versions: %v\n", versions)
branchToImageID := buildDendriteImages(httpClient, dockerClient, *flagTempDir, *flagBuildConcurrency, versions) branchToImageID := buildDendriteImages(httpClient, dockerClient, *flagTempDir, *flagBuildConcurrency, versions)

View file

@ -1,109 +0,0 @@
## Database migrations
We use [goose](https://github.com/pressly/goose) to handle database migrations. This allows us to execute
both SQL deltas (e.g `ALTER TABLE ...`) as well as manipulate data in the database in Go using Go functions.
To run a migration, the `goose` binary in this directory needs to be built:
```
$ go build ./cmd/goose
```
This binary allows Dendrite databases to be upgraded and downgraded. Sample usage for upgrading the roomserver database:
```
# for sqlite
$ ./goose -dir roomserver/storage/sqlite3/deltas sqlite3 ./roomserver.db up
# for postgres
$ ./goose -dir roomserver/storage/postgres/deltas postgres "user=dendrite dbname=dendrite sslmode=disable" up
```
For a full list of options, including rollbacks, see https://github.com/pressly/goose or use `goose` with no args.
### Rationale
Dendrite creates tables on startup using `CREATE TABLE IF NOT EXISTS`, so you might think that we should also
apply version upgrades on startup as well. This is convenient and doesn't involve an additional binary to run
which complicates upgrades. However, combining the upgrade mechanism and the server binary makes it difficult
to handle rollbacks. Firstly, how do you specify you wish to rollback? We would have to add additional flags
to the main server binary to say "rollback to version X". Secondly, if you roll back the server binary from
version 5 to version 4, the version 4 binary doesn't know how to rollback the database from version 5 to
version 4! For these reasons, we prefer to have a separate "upgrade" binary which is run for database upgrades.
Rather than roll-our-own migration tool, we decided to use [goose](https://github.com/pressly/goose) as it supports
complex migrations in Go code in addition to just executing SQL deltas. Other alternatives like
`github.com/golang-migrate/migrate` [do not support](https://github.com/golang-migrate/migrate/issues/15) these
kinds of complex migrations.
### Adding new deltas
You can add `.sql` or `.go` files manually or you can use goose to create them for you.
If you only want to add a SQL delta then run:
```
$ ./goose -dir serverkeyapi/storage/sqlite3/deltas sqlite3 ./foo.db create new_col sql
2020/09/09 14:37:43 Created new file: serverkeyapi/storage/sqlite3/deltas/20200909143743_new_col.sql
```
In this case, the version number is `20200909143743`. The important thing is that it is always increasing.
Then add up/downgrade SQL commands to the created file which looks like:
```sql
-- +goose Up
-- +goose StatementBegin
SELECT 'up SQL query';
-- +goose StatementEnd
-- +goose Down
-- +goose StatementBegin
SELECT 'down SQL query';
-- +goose StatementEnd
```
You __must__ keep the `+goose` annotations. You'll need to repeat this process for Postgres.
For complex Go migrations:
```
$ ./goose -dir serverkeyapi/storage/sqlite3/deltas sqlite3 ./foo.db create complex_update go
2020/09/09 14:40:38 Created new file: serverkeyapi/storage/sqlite3/deltas/20200909144038_complex_update.go
```
Then modify the created `.go` file which looks like:
```go
package migrations
import (
"database/sql"
"fmt"
"github.com/pressly/goose"
)
func init() {
goose.AddMigration(upComplexUpdate, downComplexUpdate)
}
func upComplexUpdate(tx *sql.Tx) error {
// This code is executed when the migration is applied.
return nil
}
func downComplexUpdate(tx *sql.Tx) error {
// This code is executed when the migration is rolled back.
return nil
}
```
You __must__ import the package in `/cmd/goose/main.go` so `func init()` gets called.
#### Database limitations
- SQLite3 does NOT support `ALTER TABLE table_name DROP COLUMN` - you would have to rename the column or drop the table
entirely and recreate it. ([example](https://github.com/matrix-org/dendrite/blob/master/userapi/storage/accounts/sqlite3/deltas/20200929203058_is_active.sql))
More information: [sqlite.org](https://www.sqlite.org/lang_altertable.html)

View file

@ -1,154 +0,0 @@
// This is custom goose binary
package main
import (
"flag"
"fmt"
"log"
"os"
"github.com/pressly/goose"
pgusers "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas"
slusers "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
)
const (
AppService = "appservice"
FederationSender = "federationapi"
KeyServer = "keyserver"
MediaAPI = "mediaapi"
RoomServer = "roomserver"
SigningKeyServer = "signingkeyserver"
SyncAPI = "syncapi"
UserAPI = "userapi"
)
var (
dir = flags.String("dir", "", "directory with migration files")
flags = flag.NewFlagSet("goose", flag.ExitOnError)
component = flags.String("component", "", "dendrite component name")
knownDBs = []string{
AppService, FederationSender, KeyServer, MediaAPI, RoomServer, SigningKeyServer, SyncAPI, UserAPI,
}
)
// nolint: gocyclo
func main() {
err := flags.Parse(os.Args[1:])
if err != nil {
panic(err.Error())
}
args := flags.Args()
if len(args) < 3 {
fmt.Println(
`Usage: goose [OPTIONS] DRIVER DBSTRING COMMAND
Drivers:
postgres
sqlite3
Examples:
goose -component roomserver sqlite3 ./roomserver.db status
goose -component roomserver sqlite3 ./roomserver.db up
goose -component roomserver postgres "user=dendrite dbname=dendrite sslmode=disable" status
Options:
-component string
Dendrite component name e.g roomserver, signingkeyserver, clientapi, syncapi
-table string
migrations table name (default "goose_db_version")
-h print help
-v enable verbose mode
-dir string
directory with migration files, only relevant when creating new migrations.
-version
print version
Commands:
up Migrate the DB to the most recent version available
up-by-one Migrate the DB up by 1
up-to VERSION Migrate the DB to a specific VERSION
down Roll back the version by 1
down-to VERSION Roll back to a specific VERSION
redo Re-run the latest migration
reset Roll back all migrations
status Dump the migration status for the current DB
version Print the current version of the database
create NAME [sql|go] Creates new migration file with the current timestamp
fix Apply sequential ordering to migrations`,
)
return
}
engine := args[0]
if engine != "sqlite3" && engine != "postgres" {
fmt.Println("engine must be one of 'sqlite3' or 'postgres'")
return
}
knownComponent := false
for _, c := range knownDBs {
if c == *component {
knownComponent = true
break
}
}
if !knownComponent {
fmt.Printf("component must be one of %v\n", knownDBs)
return
}
if engine == "sqlite3" {
loadSQLiteDeltas(*component)
} else {
loadPostgresDeltas(*component)
}
dbstring, command := args[1], args[2]
db, err := goose.OpenDBWithDriver(engine, dbstring)
if err != nil {
log.Fatalf("goose: failed to open DB: %v\n", err)
}
defer func() {
if err := db.Close(); err != nil {
log.Fatalf("goose: failed to close DB: %v\n", err)
}
}()
arguments := []string{}
if len(args) > 3 {
arguments = append(arguments, args[3:]...)
}
// goose demands a directory even though we don't use it for upgrades
d := *dir
if d == "" {
d = os.TempDir()
}
if err := goose.Run(command, db, d, arguments...); err != nil {
log.Fatalf("goose %v: %v", command, err)
}
}
func loadSQLiteDeltas(component string) {
switch component {
case UserAPI:
slusers.LoadFromGoose()
}
}
func loadPostgresDeltas(component string) {
switch component {
case UserAPI:
pgusers.LoadFromGoose()
}
}

View file

@ -15,23 +15,13 @@
package deltas package deltas
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/pressly/goose"
) )
func LoadFromGoose() { func UpRemoveRoomsTable(ctx context.Context, tx *sql.Tx) error {
goose.AddMigration(UpRemoveRoomsTable, DownRemoveRoomsTable) _, err := tx.ExecContext(ctx, `
}
func LoadRemoveRoomsTable(m *sqlutil.Migrations) {
m.AddMigration(UpRemoveRoomsTable, DownRemoveRoomsTable)
}
func UpRemoveRoomsTable(tx *sql.Tx) error {
_, err := tx.Exec(`
DROP TABLE IF EXISTS federationsender_rooms; DROP TABLE IF EXISTS federationsender_rooms;
`) `)
if err != nil { if err != nil {

View file

@ -82,9 +82,13 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
if err != nil { if err != nil {
return nil, err return nil, err
} }
m := sqlutil.NewMigrations() m := sqlutil.NewMigrator(d.db)
deltas.LoadRemoveRoomsTable(m) m.AddMigrations(sqlutil.Migration{
if err = m.RunDeltas(d.db, dbProperties); err != nil { Version: "federationsender: drop federationsender_rooms",
Up: deltas.UpRemoveRoomsTable,
})
err = m.Up(base.Context())
if err != nil {
return nil, err return nil, err
} }
d.Database = shared.Database{ d.Database = shared.Database{

View file

@ -15,23 +15,13 @@
package deltas package deltas
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/pressly/goose"
) )
func LoadFromGoose() { func UpRemoveRoomsTable(ctx context.Context, tx *sql.Tx) error {
goose.AddMigration(UpRemoveRoomsTable, DownRemoveRoomsTable) _, err := tx.ExecContext(ctx, `
}
func LoadRemoveRoomsTable(m *sqlutil.Migrations) {
m.AddMigration(UpRemoveRoomsTable, DownRemoveRoomsTable)
}
func UpRemoveRoomsTable(tx *sql.Tx) error {
_, err := tx.Exec(`
DROP TABLE IF EXISTS federationsender_rooms; DROP TABLE IF EXISTS federationsender_rooms;
`) `)
if err != nil { if err != nil {

View file

@ -81,9 +81,13 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
if err != nil { if err != nil {
return nil, err return nil, err
} }
m := sqlutil.NewMigrations() m := sqlutil.NewMigrator(d.db)
deltas.LoadRemoveRoomsTable(m) m.AddMigrations(sqlutil.Migration{
if err = m.RunDeltas(d.db, dbProperties); err != nil { Version: "federationsender: drop federationsender_rooms",
Up: deltas.UpRemoveRoomsTable,
})
err = m.Up(base.Context())
if err != nil {
return nil, err return nil, err
} }
d.Database = shared.Database{ d.Database = shared.Database{

1
go.mod
View file

@ -37,7 +37,6 @@ require (
github.com/opentracing/opentracing-go v1.2.0 github.com/opentracing/opentracing-go v1.2.0
github.com/patrickmn/go-cache v2.1.0+incompatible github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/pkg/errors v0.9.1 github.com/pkg/errors v0.9.1
github.com/pressly/goose v2.7.0+incompatible
github.com/prometheus/client_golang v1.12.2 github.com/prometheus/client_golang v1.12.2
github.com/sirupsen/logrus v1.8.1 github.com/sirupsen/logrus v1.8.1
github.com/stretchr/testify v1.7.1 github.com/stretchr/testify v1.7.1

2
go.sum
View file

@ -432,8 +432,6 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pressly/goose v2.7.0+incompatible h1:PWejVEv07LCerQEzMMeAtjuyCKbyprZ/LBa6K5P0OCQ=
github.com/pressly/goose v2.7.0+incompatible/go.mod h1:m+QHWCqxR3k8D9l7qfzuC/djtlfzxr34mozWDYEu1z8=
github.com/prometheus/client_golang v0.8.0/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= github.com/prometheus/client_golang v0.8.0/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=

View file

@ -1,130 +1,142 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlutil package sqlutil
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"runtime" "sync"
"sort" "time"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/internal"
"github.com/pressly/goose" "github.com/sirupsen/logrus"
) )
type Migrations struct { const createDBMigrationsSQL = "" +
registeredGoMigrations map[int64]*goose.Migration "CREATE TABLE IF NOT EXISTS db_migrations (" +
" version TEXT PRIMARY KEY NOT NULL," +
" time TEXT NOT NULL," +
" dendrite_version TEXT NOT NULL" +
");"
const insertVersionSQL = "" +
"INSERT INTO db_migrations (version, time, dendrite_version)" +
" VALUES ($1, $2, $3)"
const selectDBMigrationsSQL = "SELECT version FROM db_migrations"
// Migration defines a migration to be run.
type Migration struct {
// Version is a simple description/name of this migration.
Version string
// Up defines the function to execute for an upgrade.
Up func(ctx context.Context, txn *sql.Tx) error
// Down defines the function to execute for a downgrade (not implemented yet).
Down func(ctx context.Context, txn *sql.Tx) error
} }
func NewMigrations() *Migrations { // Migrator
return &Migrations{ type Migrator struct {
registeredGoMigrations: make(map[int64]*goose.Migration), db *sql.DB
migrations []Migration
knownMigrations map[string]struct{}
mutex *sync.Mutex
}
// NewMigrator creates a new DB migrator.
func NewMigrator(db *sql.DB) *Migrator {
return &Migrator{
db: db,
migrations: []Migration{},
knownMigrations: make(map[string]struct{}),
mutex: &sync.Mutex{},
} }
} }
// Copy-pasted from goose directly to store migrations into a map we control // AddMigrations appends migrations to the list of migrations. Migrations are executed
// in the order they are added to the list. De-duplicates migrations using their Version field.
// AddMigration adds a migration. func (m *Migrator) AddMigrations(migrations ...Migration) {
func (m *Migrations) AddMigration(up func(*sql.Tx) error, down func(*sql.Tx) error) { m.mutex.Lock()
_, filename, _, _ := runtime.Caller(1) defer m.mutex.Unlock()
m.AddNamedMigration(filename, up, down) for _, mig := range migrations {
if _, ok := m.knownMigrations[mig.Version]; !ok {
m.migrations = append(m.migrations, mig)
m.knownMigrations[mig.Version] = struct{}{}
}
}
} }
// AddNamedMigration : Add a named migration. // Up executes all migrations in order they were added.
func (m *Migrations) AddNamedMigration(filename string, up func(*sql.Tx) error, down func(*sql.Tx) error) { func (m *Migrator) Up(ctx context.Context) error {
v, _ := goose.NumericComponent(filename) var (
migration := &goose.Migration{Version: v, Next: -1, Previous: -1, Registered: true, UpFn: up, DownFn: down, Source: filename} err error
dendriteVersion = internal.VersionString()
if existing, ok := m.registeredGoMigrations[v]; ok { )
panic(fmt.Sprintf("failed to add migration %q: version conflicts with %q", filename, existing.Source)) // ensure there is a table for known migrations
} executedMigrations, err := m.ExecutedMigrations(ctx)
m.registeredGoMigrations[v] = migration
}
// RunDeltas up to the latest version.
func (m *Migrations) RunDeltas(db *sql.DB, props *config.DatabaseOptions) error {
maxVer := goose.MaxVersion
minVer := int64(0)
migrations, err := m.collect(minVer, maxVer)
if err != nil { if err != nil {
return fmt.Errorf("runDeltas: Failed to collect migrations: %w", err) return fmt.Errorf("unable to create/get migrations: %w", err)
}
if props.ConnectionString.IsPostgres() {
if err = goose.SetDialect("postgres"); err != nil {
return err
}
} else if props.ConnectionString.IsSQLite() {
if err = goose.SetDialect("sqlite3"); err != nil {
return err
}
} else {
return fmt.Errorf("unknown connection string: %s", props.ConnectionString)
}
for {
current, err := goose.EnsureDBVersion(db)
if err != nil {
return fmt.Errorf("runDeltas: Failed to EnsureDBVersion: %w", err)
} }
next, err := migrations.Next(current) return WithTransaction(m.db, func(txn *sql.Tx) error {
for i := range m.migrations {
now := time.Now().UTC().Format(time.RFC3339)
migration := m.migrations[i]
logrus.Debugf("Executing database migration '%s'", migration.Version)
// Skip migration if it was already executed
if _, ok := executedMigrations[migration.Version]; ok {
continue
}
err = migration.Up(ctx, txn)
if err != nil { if err != nil {
if err == goose.ErrNoNextVersion { return fmt.Errorf("unable to execute migration '%s': %w", migration.Version, err)
}
_, err = txn.ExecContext(ctx, insertVersionSQL,
migration.Version,
now,
dendriteVersion,
)
if err != nil {
return fmt.Errorf("unable to insert executed migrations: %w", err)
}
}
return nil return nil
})
} }
return fmt.Errorf("runDeltas: Failed to load next migration to %+v : %w", next, err) // ExecutedMigrations returns a map with already executed migrations in addition to creating the
} // migrations table, if it doesn't exist.
func (m *Migrator) ExecutedMigrations(ctx context.Context) (map[string]struct{}, error) {
if err = next.Up(db); err != nil { result := make(map[string]struct{})
return fmt.Errorf("runDeltas: Failed run migration: %w", err) _, err := m.db.ExecContext(ctx, createDBMigrationsSQL)
}
}
}
func (m *Migrations) collect(current, target int64) (goose.Migrations, error) {
var migrations goose.Migrations
// Go migrations registered via goose.AddMigration().
for _, migration := range m.registeredGoMigrations {
v, err := goose.NumericComponent(migration.Source)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("unable to create db_migrations: %w", err)
} }
if versionFilter(v, current, target) { rows, err := m.db.QueryContext(ctx, selectDBMigrationsSQL)
migrations = append(migrations, migration) if err != nil {
return nil, fmt.Errorf("unable to query db_migrations: %w", err)
} }
defer internal.CloseAndLogIfError(ctx, rows, "ExecutedMigrations: rows.close() failed")
var version string
for rows.Next() {
if err = rows.Scan(&version); err != nil {
return nil, fmt.Errorf("unable to scan version: %w", err)
}
result[version] = struct{}{}
} }
migrations = sortAndConnectMigrations(migrations) return result, rows.Err()
return migrations, nil
}
func sortAndConnectMigrations(migrations goose.Migrations) goose.Migrations {
sort.Sort(migrations)
// now that we're sorted in the appropriate direction,
// populate next and previous for each migration
for i, m := range migrations {
prev := int64(-1)
if i > 0 {
prev = migrations[i-1].Version
migrations[i-1].Next = m.Version
}
migrations[i].Previous = prev
}
return migrations
}
func versionFilter(v, current, target int64) bool {
if target > current {
return v > current && v <= target
}
if target < current {
return v <= current && v > target
}
return false
} }

View file

@ -0,0 +1,112 @@
package sqlutil_test
import (
"context"
"database/sql"
"fmt"
"reflect"
"testing"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/test"
_ "github.com/mattn/go-sqlite3"
)
var dummyMigrations = []sqlutil.Migration{
{
Version: "init",
Up: func(ctx context.Context, txn *sql.Tx) error {
_, err := txn.ExecContext(ctx, "CREATE TABLE IF NOT EXISTS dummy ( test TEXT );")
return err
},
},
{
Version: "v2",
Up: func(ctx context.Context, txn *sql.Tx) error {
_, err := txn.ExecContext(ctx, "ALTER TABLE dummy ADD COLUMN test2 TEXT;")
return err
},
},
{
Version: "v2", // duplicate, this migration will be skipped
Up: func(ctx context.Context, txn *sql.Tx) error {
_, err := txn.ExecContext(ctx, "ALTER TABLE dummy ADD COLUMN test2 TEXT;")
return err
},
},
{
Version: "multiple execs",
Up: func(ctx context.Context, txn *sql.Tx) error {
_, err := txn.ExecContext(ctx, "ALTER TABLE dummy ADD COLUMN test3 TEXT;")
if err != nil {
return err
}
_, err = txn.ExecContext(ctx, "ALTER TABLE dummy ADD COLUMN test4 TEXT;")
return err
},
},
}
var failMigration = sqlutil.Migration{
Version: "iFail",
Up: func(ctx context.Context, txn *sql.Tx) error {
return fmt.Errorf("iFail")
},
Down: nil,
}
func Test_migrations_Up(t *testing.T) {
withFail := append(dummyMigrations, failMigration)
tests := []struct {
name string
migrations []sqlutil.Migration
wantResult map[string]struct{}
wantErr bool
}{
{
name: "dummy migration",
migrations: dummyMigrations,
wantResult: map[string]struct{}{
"init": {},
"v2": {},
"multiple execs": {},
},
},
{
name: "with fail",
migrations: withFail,
wantErr: true,
},
}
ctx := context.Background()
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
conStr, close := test.PrepareDBConnectionString(t, dbType)
defer close()
driverName := "sqlite3"
if dbType == test.DBTypePostgres {
driverName = "postgres"
}
db, err := sql.Open(driverName, conStr)
if err != nil {
t.Errorf("unable to open database: %v", err)
}
m := sqlutil.NewMigrator(db)
m.AddMigrations(tt.migrations...)
if err = m.Up(ctx); (err != nil) != tt.wantErr {
t.Errorf("Up() error = %v, wantErr %v", err, tt.wantErr)
}
result, err := m.ExecutedMigrations(ctx)
if err != nil {
t.Errorf("unable to get executed migrations: %v", err)
}
if !tt.wantErr && !reflect.DeepEqual(result, tt.wantResult) {
t.Errorf("expected: %+v, got %v", tt.wantResult, result)
}
})
})
}
}

View file

@ -21,6 +21,7 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/storage/postgres/deltas"
"github.com/matrix-org/dendrite/keyserver/storage/tables" "github.com/matrix-org/dendrite/keyserver/storage/tables"
"github.com/matrix-org/dendrite/keyserver/types" "github.com/matrix-org/dendrite/keyserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -66,6 +67,16 @@ func NewPostgresCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, erro
if err != nil { if err != nil {
return nil, err return nil, err
} }
m := sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "keyserver: cross signing signature indexes",
Up: deltas.UpFixCrossSigningSignatureIndexes,
})
if err = m.Up(context.Background()); err != nil {
return nil, err
}
return s, sqlutil.StatementList{ return s, sqlutil.StatementList{
{&s.selectCrossSigningSigsForTargetStmt, selectCrossSigningSigsForTargetSQL}, {&s.selectCrossSigningSigsForTargetStmt, selectCrossSigningSigsForTargetSQL},
{&s.upsertCrossSigningSigsForTargetStmt, upsertCrossSigningSigsForTargetSQL}, {&s.upsertCrossSigningSigsForTargetStmt, upsertCrossSigningSigsForTargetSQL},

View file

@ -15,37 +15,27 @@
package deltas package deltas
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/pressly/goose"
) )
func LoadFromGoose() { func UpRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error {
goose.AddMigration(UpRefactorKeyChanges, DownRefactorKeyChanges)
}
func LoadRefactorKeyChanges(m *sqlutil.Migrations) {
m.AddMigration(UpRefactorKeyChanges, DownRefactorKeyChanges)
}
func UpRefactorKeyChanges(tx *sql.Tx) error {
// start counting from the last max offset, else 0. We need to do a count(*) first to see if there // start counting from the last max offset, else 0. We need to do a count(*) first to see if there
// even are entries in this table to know if we can query for log_offset. Without the count then // even are entries in this table to know if we can query for log_offset. Without the count then
// the query to SELECT the max log offset fails on new Dendrite instances as log_offset doesn't // the query to SELECT the max log offset fails on new Dendrite instances as log_offset doesn't
// exist on that table. Even though we discard the error, the txn is tainted and gets aborted :/ // exist on that table. Even though we discard the error, the txn is tainted and gets aborted :/
var count int var count int
_ = tx.QueryRow(`SELECT count(*) FROM keyserver_key_changes`).Scan(&count) _ = tx.QueryRowContext(ctx, `SELECT count(*) FROM keyserver_key_changes`).Scan(&count)
if count > 0 { if count > 0 {
var maxOffset int64 var maxOffset int64
_ = tx.QueryRow(`SELECT coalesce(MAX(log_offset), 0) AS offset FROM keyserver_key_changes`).Scan(&maxOffset) _ = tx.QueryRowContext(ctx, `SELECT coalesce(MAX(log_offset), 0) AS offset FROM keyserver_key_changes`).Scan(&maxOffset)
if _, err := tx.Exec(fmt.Sprintf(`CREATE SEQUENCE IF NOT EXISTS keyserver_key_changes_seq START %d`, maxOffset)); err != nil { if _, err := tx.ExecContext(ctx, fmt.Sprintf(`CREATE SEQUENCE IF NOT EXISTS keyserver_key_changes_seq START %d`, maxOffset)); err != nil {
return fmt.Errorf("failed to CREATE SEQUENCE for key changes, starting at %d: %s", maxOffset, err) return fmt.Errorf("failed to CREATE SEQUENCE for key changes, starting at %d: %s", maxOffset, err)
} }
} }
_, err := tx.Exec(` _, err := tx.ExecContext(ctx, `
-- make the new table -- make the new table
DROP TABLE IF EXISTS keyserver_key_changes; DROP TABLE IF EXISTS keyserver_key_changes;
CREATE TABLE IF NOT EXISTS keyserver_key_changes ( CREATE TABLE IF NOT EXISTS keyserver_key_changes (
@ -60,8 +50,8 @@ func UpRefactorKeyChanges(tx *sql.Tx) error {
return nil return nil
} }
func DownRefactorKeyChanges(tx *sql.Tx) error { func DownRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error {
_, err := tx.Exec(` _, err := tx.ExecContext(ctx, `
-- Drop all data and revert back, we can't keep the data as Kafka offsets determine the numbers -- Drop all data and revert back, we can't keep the data as Kafka offsets determine the numbers
DROP SEQUENCE IF EXISTS keyserver_key_changes_seq; DROP SEQUENCE IF EXISTS keyserver_key_changes_seq;
DROP TABLE IF EXISTS keyserver_key_changes; DROP TABLE IF EXISTS keyserver_key_changes;

View file

@ -15,18 +15,13 @@
package deltas package deltas
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
) )
func LoadFixCrossSigningSignatureIndexes(m *sqlutil.Migrations) { func UpFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error {
m.AddMigration(UpFixCrossSigningSignatureIndexes, DownFixCrossSigningSignatureIndexes) _, err := tx.ExecContext(ctx, `
}
func UpFixCrossSigningSignatureIndexes(tx *sql.Tx) error {
_, err := tx.Exec(`
ALTER TABLE keyserver_cross_signing_sigs DROP CONSTRAINT keyserver_cross_signing_sigs_pkey; ALTER TABLE keyserver_cross_signing_sigs DROP CONSTRAINT keyserver_cross_signing_sigs_pkey;
ALTER TABLE keyserver_cross_signing_sigs ADD PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id); ALTER TABLE keyserver_cross_signing_sigs ADD PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id);
@ -38,8 +33,8 @@ func UpFixCrossSigningSignatureIndexes(tx *sql.Tx) error {
return nil return nil
} }
func DownFixCrossSigningSignatureIndexes(tx *sql.Tx) error { func DownFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error {
_, err := tx.Exec(` _, err := tx.ExecContext(ctx, `
ALTER TABLE keyserver_cross_signing_sigs DROP CONSTRAINT keyserver_cross_signing_sigs_pkey; ALTER TABLE keyserver_cross_signing_sigs DROP CONSTRAINT keyserver_cross_signing_sigs_pkey;
ALTER TABLE keyserver_cross_signing_sigs ADD PRIMARY KEY (origin_user_id, target_user_id, target_key_id); ALTER TABLE keyserver_cross_signing_sigs ADD PRIMARY KEY (origin_user_id, target_user_id, target_key_id);

View file

@ -19,6 +19,8 @@ import (
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/storage/postgres/deltas"
"github.com/matrix-org/dendrite/keyserver/storage/tables" "github.com/matrix-org/dendrite/keyserver/storage/tables"
) )
@ -55,9 +57,25 @@ func NewPostgresKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
db: db, db: db,
} }
_, err := db.Exec(keyChangesSchema) _, err := db.Exec(keyChangesSchema)
if err != nil {
return s, err return s, err
} }
// TODO: Remove when we are sure we are not having goose artefacts in the db
// This forces an error, which indicates the migration is already applied, since the
// column partition was removed from the table
err = db.QueryRow("SELECT partition FROM keyserver_key_changes LIMIT 1;").Scan()
if err == nil {
m := sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "keyserver: refactor key changes",
Up: deltas.UpRefactorKeyChanges,
})
return s, m.Up(context.Background())
}
return s, nil
}
func (s *keyChangesStatements) Prepare() (err error) { func (s *keyChangesStatements) Prepare() (err error) {
if s.upsertKeyChangeStmt, err = s.db.Prepare(upsertKeyChangeSQL); err != nil { if s.upsertKeyChangeStmt, err = s.db.Prepare(upsertKeyChangeSQL); err != nil {
return err return err

View file

@ -16,7 +16,6 @@ package postgres
import ( import (
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/storage/postgres/deltas"
"github.com/matrix-org/dendrite/keyserver/storage/shared" "github.com/matrix-org/dendrite/keyserver/storage/shared"
"github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
@ -53,12 +52,6 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }
m := sqlutil.NewMigrations()
deltas.LoadRefactorKeyChanges(m)
deltas.LoadFixCrossSigningSignatureIndexes(m)
if err = m.RunDeltas(db, dbProperties); err != nil {
return nil, err
}
if err = kc.Prepare(); err != nil { if err = kc.Prepare(); err != nil {
return nil, err return nil, err
} }

View file

@ -21,6 +21,7 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/storage/sqlite3/deltas"
"github.com/matrix-org/dendrite/keyserver/storage/tables" "github.com/matrix-org/dendrite/keyserver/storage/tables"
"github.com/matrix-org/dendrite/keyserver/types" "github.com/matrix-org/dendrite/keyserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -65,6 +66,15 @@ func NewSqliteCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, error)
if err != nil { if err != nil {
return nil, err return nil, err
} }
m := sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "keyserver: cross signing signature indexes",
Up: deltas.UpFixCrossSigningSignatureIndexes,
})
if err = m.Up(context.Background()); err != nil {
return nil, err
}
return s, sqlutil.StatementList{ return s, sqlutil.StatementList{
{&s.selectCrossSigningSigsForTargetStmt, selectCrossSigningSigsForTargetSQL}, {&s.selectCrossSigningSigsForTargetStmt, selectCrossSigningSigsForTargetSQL},
{&s.upsertCrossSigningSigsForTargetStmt, upsertCrossSigningSigsForTargetSQL}, {&s.upsertCrossSigningSigsForTargetStmt, upsertCrossSigningSigsForTargetSQL},

View file

@ -15,28 +15,18 @@
package deltas package deltas
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/pressly/goose"
) )
func LoadFromGoose() { func UpRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error {
goose.AddMigration(UpRefactorKeyChanges, DownRefactorKeyChanges)
}
func LoadRefactorKeyChanges(m *sqlutil.Migrations) {
m.AddMigration(UpRefactorKeyChanges, DownRefactorKeyChanges)
}
func UpRefactorKeyChanges(tx *sql.Tx) error {
// start counting from the last max offset, else 0. // start counting from the last max offset, else 0.
var maxOffset int64 var maxOffset int64
var userID string var userID string
_ = tx.QueryRow(`SELECT user_id, MAX(log_offset) FROM keyserver_key_changes GROUP BY user_id`).Scan(&userID, &maxOffset) _ = tx.QueryRowContext(ctx, `SELECT user_id, MAX(log_offset) FROM keyserver_key_changes GROUP BY user_id`).Scan(&userID, &maxOffset)
_, err := tx.Exec(` _, err := tx.ExecContext(ctx, `
-- make the new table -- make the new table
DROP TABLE IF EXISTS keyserver_key_changes; DROP TABLE IF EXISTS keyserver_key_changes;
CREATE TABLE IF NOT EXISTS keyserver_key_changes ( CREATE TABLE IF NOT EXISTS keyserver_key_changes (
@ -51,14 +41,14 @@ func UpRefactorKeyChanges(tx *sql.Tx) error {
} }
// to start counting from maxOffset, insert a row with that value // to start counting from maxOffset, insert a row with that value
if userID != "" { if userID != "" {
_, err = tx.Exec(`INSERT INTO keyserver_key_changes(change_id, user_id) VALUES($1, $2)`, maxOffset, userID) _, err = tx.ExecContext(ctx, `INSERT INTO keyserver_key_changes(change_id, user_id) VALUES($1, $2)`, maxOffset, userID)
return err return err
} }
return nil return nil
} }
func DownRefactorKeyChanges(tx *sql.Tx) error { func DownRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error {
_, err := tx.Exec(` _, err := tx.ExecContext(ctx, `
-- Drop all data and revert back, we can't keep the data as Kafka offsets determine the numbers -- Drop all data and revert back, we can't keep the data as Kafka offsets determine the numbers
DROP TABLE IF EXISTS keyserver_key_changes; DROP TABLE IF EXISTS keyserver_key_changes;
CREATE TABLE IF NOT EXISTS keyserver_key_changes ( CREATE TABLE IF NOT EXISTS keyserver_key_changes (

View file

@ -15,18 +15,13 @@
package deltas package deltas
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
) )
func LoadFixCrossSigningSignatureIndexes(m *sqlutil.Migrations) { func UpFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error {
m.AddMigration(UpFixCrossSigningSignatureIndexes, DownFixCrossSigningSignatureIndexes) _, err := tx.ExecContext(ctx, `
}
func UpFixCrossSigningSignatureIndexes(tx *sql.Tx) error {
_, err := tx.Exec(`
CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs_tmp ( CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs_tmp (
origin_user_id TEXT NOT NULL, origin_user_id TEXT NOT NULL,
origin_key_id TEXT NOT NULL, origin_key_id TEXT NOT NULL,
@ -50,8 +45,8 @@ func UpFixCrossSigningSignatureIndexes(tx *sql.Tx) error {
return nil return nil
} }
func DownFixCrossSigningSignatureIndexes(tx *sql.Tx) error { func DownFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error {
_, err := tx.Exec(` _, err := tx.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs_tmp ( CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs_tmp (
origin_user_id TEXT NOT NULL, origin_user_id TEXT NOT NULL,
origin_key_id TEXT NOT NULL, origin_key_id TEXT NOT NULL,

View file

@ -19,6 +19,8 @@ import (
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/storage/sqlite3/deltas"
"github.com/matrix-org/dendrite/keyserver/storage/tables" "github.com/matrix-org/dendrite/keyserver/storage/tables"
) )
@ -53,8 +55,24 @@ func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
db: db, db: db,
} }
_, err := db.Exec(keyChangesSchema) _, err := db.Exec(keyChangesSchema)
if err != nil {
return s, err return s, err
} }
// TODO: Remove when we are sure we are not having goose artefacts in the db
// This forces an error, which indicates the migration is already applied, since the
// column partition was removed from the table
err = db.QueryRow("SELECT partition FROM keyserver_key_changes LIMIT 1;").Scan()
if err == nil {
m := sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "keyserver: refactor key changes",
Up: deltas.UpRefactorKeyChanges,
})
return s, m.Up(context.Background())
}
return s, nil
}
func (s *keyChangesStatements) Prepare() (err error) { func (s *keyChangesStatements) Prepare() (err error) {
if s.upsertKeyChangeStmt, err = s.db.Prepare(upsertKeyChangeSQL); err != nil { if s.upsertKeyChangeStmt, err = s.db.Prepare(upsertKeyChangeSQL); err != nil {

View file

@ -17,7 +17,6 @@ package sqlite3
import ( import (
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/storage/shared" "github.com/matrix-org/dendrite/keyserver/storage/shared"
"github.com/matrix-org/dendrite/keyserver/storage/sqlite3/deltas"
"github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
) )
@ -52,12 +51,6 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions)
return nil, err return nil, err
} }
m := sqlutil.NewMigrations()
deltas.LoadRefactorKeyChanges(m)
deltas.LoadFixCrossSigningSignatureIndexes(m)
if err = m.RunDeltas(db, dbProperties); err != nil {
return nil, err
}
if err = kc.Prepare(); err != nil { if err = kc.Prepare(); err != nil {
return nil, err return nil, err
} }

View file

@ -15,32 +15,21 @@
package deltas package deltas
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/pressly/goose"
) )
func LoadFromGoose() { func UpAddForgottenColumn(ctx context.Context, tx *sql.Tx) error {
goose.AddMigration(UpAddForgottenColumn, DownAddForgottenColumn) _, err := tx.ExecContext(ctx, `ALTER TABLE roomserver_membership ADD COLUMN IF NOT EXISTS forgotten BOOLEAN NOT NULL DEFAULT false;`)
goose.AddMigration(UpStateBlocksRefactor, DownStateBlocksRefactor)
}
func LoadAddForgottenColumn(m *sqlutil.Migrations) {
m.AddMigration(UpAddForgottenColumn, DownAddForgottenColumn)
}
func UpAddForgottenColumn(tx *sql.Tx) error {
_, err := tx.Exec(`ALTER TABLE roomserver_membership ADD COLUMN IF NOT EXISTS forgotten BOOLEAN NOT NULL DEFAULT false;`)
if err != nil { if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err) return fmt.Errorf("failed to execute upgrade: %w", err)
} }
return nil return nil
} }
func DownAddForgottenColumn(tx *sql.Tx) error { func DownAddForgottenColumn(ctx context.Context, tx *sql.Tx) error {
_, err := tx.Exec(`ALTER TABLE roomserver_membership DROP COLUMN IF EXISTS forgotten;`) _, err := tx.ExecContext(ctx, `ALTER TABLE roomserver_membership DROP COLUMN IF EXISTS forgotten;`)
if err != nil { if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err) return fmt.Errorf("failed to execute downgrade: %w", err)
} }

View file

@ -15,11 +15,11 @@
package deltas package deltas
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -36,48 +36,44 @@ type stateBlockData struct {
EventNIDs types.EventNIDs EventNIDs types.EventNIDs
} }
func LoadStateBlocksRefactor(m *sqlutil.Migrations) {
m.AddMigration(UpStateBlocksRefactor, DownStateBlocksRefactor)
}
// nolint:gocyclo // nolint:gocyclo
func UpStateBlocksRefactor(tx *sql.Tx) error { func UpStateBlocksRefactor(ctx context.Context, tx *sql.Tx) error {
logrus.Warn("Performing state storage upgrade. Please wait, this may take some time!") logrus.Warn("Performing state storage upgrade. Please wait, this may take some time!")
defer logrus.Warn("State storage upgrade complete") defer logrus.Warn("State storage upgrade complete")
var snapshotcount int var snapshotcount int
var maxsnapshotid int var maxsnapshotid int
var maxblockid int var maxblockid int
if err := tx.QueryRow(`SELECT COUNT(DISTINCT state_snapshot_nid) FROM roomserver_state_snapshots;`).Scan(&snapshotcount); err != nil { if err := tx.QueryRowContext(ctx, `SELECT COUNT(DISTINCT state_snapshot_nid) FROM roomserver_state_snapshots;`).Scan(&snapshotcount); err != nil {
return fmt.Errorf("tx.QueryRow.Scan (count snapshots): %w", err) return fmt.Errorf("tx.QueryRowContext.Scan (count snapshots): %w", err)
} }
if err := tx.QueryRow(`SELECT COALESCE(MAX(state_snapshot_nid),0) FROM roomserver_state_snapshots;`).Scan(&maxsnapshotid); err != nil { if err := tx.QueryRowContext(ctx, `SELECT COALESCE(MAX(state_snapshot_nid),0) FROM roomserver_state_snapshots;`).Scan(&maxsnapshotid); err != nil {
return fmt.Errorf("tx.QueryRow.Scan (count snapshots): %w", err) return fmt.Errorf("tx.QueryRowContext.Scan (count snapshots): %w", err)
} }
if err := tx.QueryRow(`SELECT COALESCE(MAX(state_block_nid),0) FROM roomserver_state_block;`).Scan(&maxblockid); err != nil { if err := tx.QueryRowContext(ctx, `SELECT COALESCE(MAX(state_block_nid),0) FROM roomserver_state_block;`).Scan(&maxblockid); err != nil {
return fmt.Errorf("tx.QueryRow.Scan (count snapshots): %w", err) return fmt.Errorf("tx.QueryRowContext.Scan (count snapshots): %w", err)
} }
maxsnapshotid++ maxsnapshotid++
maxblockid++ maxblockid++
if _, err := tx.Exec(`ALTER TABLE roomserver_state_block RENAME TO _roomserver_state_block;`); err != nil { if _, err := tx.ExecContext(ctx, `ALTER TABLE roomserver_state_block RENAME TO _roomserver_state_block;`); err != nil {
return fmt.Errorf("tx.Exec: %w", err) return fmt.Errorf("tx.ExecContext: %w", err)
} }
if _, err := tx.Exec(`ALTER TABLE roomserver_state_snapshots RENAME TO _roomserver_state_snapshots;`); err != nil { if _, err := tx.ExecContext(ctx, `ALTER TABLE roomserver_state_snapshots RENAME TO _roomserver_state_snapshots;`); err != nil {
return fmt.Errorf("tx.Exec: %w", err) return fmt.Errorf("tx.ExecContext: %w", err)
} }
// We create new sequences starting with the maximum state snapshot and block NIDs. // We create new sequences starting with the maximum state snapshot and block NIDs.
// This means that all newly created snapshots and blocks by the migration will have // This means that all newly created snapshots and blocks by the migration will have
// NIDs higher than these values, so that when we come to update the references to // NIDs higher than these values, so that when we come to update the references to
// these NIDs using UPDATE statements, we can guarantee we are only ever updating old // these NIDs using UPDATE statements, we can guarantee we are only ever updating old
// values and not accidentally overwriting new ones. // values and not accidentally overwriting new ones.
if _, err := tx.Exec(fmt.Sprintf(`CREATE SEQUENCE roomserver_state_block_nid_sequence START WITH %d;`, maxblockid)); err != nil { if _, err := tx.ExecContext(ctx, fmt.Sprintf(`CREATE SEQUENCE roomserver_state_block_nid_sequence START WITH %d;`, maxblockid)); err != nil {
return fmt.Errorf("tx.Exec: %w", err) return fmt.Errorf("tx.ExecContext: %w", err)
} }
if _, err := tx.Exec(fmt.Sprintf(`CREATE SEQUENCE roomserver_state_snapshot_nid_sequence START WITH %d;`, maxsnapshotid)); err != nil { if _, err := tx.ExecContext(ctx, fmt.Sprintf(`CREATE SEQUENCE roomserver_state_snapshot_nid_sequence START WITH %d;`, maxsnapshotid)); err != nil {
return fmt.Errorf("tx.Exec: %w", err) return fmt.Errorf("tx.ExecContext: %w", err)
} }
_, err := tx.Exec(` _, err := tx.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS roomserver_state_block ( CREATE TABLE IF NOT EXISTS roomserver_state_block (
state_block_nid bigint PRIMARY KEY DEFAULT nextval('roomserver_state_block_nid_sequence'), state_block_nid bigint PRIMARY KEY DEFAULT nextval('roomserver_state_block_nid_sequence'),
state_block_hash BYTEA UNIQUE, state_block_hash BYTEA UNIQUE,
@ -87,7 +83,7 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
if err != nil { if err != nil {
return fmt.Errorf("tx.Exec (create blocks table): %w", err) return fmt.Errorf("tx.Exec (create blocks table): %w", err)
} }
_, err = tx.Exec(` _, err = tx.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS roomserver_state_snapshots ( CREATE TABLE IF NOT EXISTS roomserver_state_snapshots (
state_snapshot_nid bigint PRIMARY KEY DEFAULT nextval('roomserver_state_snapshot_nid_sequence'), state_snapshot_nid bigint PRIMARY KEY DEFAULT nextval('roomserver_state_snapshot_nid_sequence'),
state_snapshot_hash BYTEA UNIQUE, state_snapshot_hash BYTEA UNIQUE,
@ -104,7 +100,7 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
// in question a state snapshot NID of 0 to indicate 'no snapshot'. // in question a state snapshot NID of 0 to indicate 'no snapshot'.
// If we don't do this, we'll fail the assertions later on which try to ensure we didn't forget // If we don't do this, we'll fail the assertions later on which try to ensure we didn't forget
// any snapshots. // any snapshots.
_, err = tx.Exec( _, err = tx.ExecContext(ctx,
`UPDATE roomserver_events SET state_snapshot_nid = 0 WHERE event_type_nid = $1 AND event_state_key_nid = $2`, `UPDATE roomserver_events SET state_snapshot_nid = 0 WHERE event_type_nid = $1 AND event_state_key_nid = $2`,
types.MRoomCreateNID, types.EmptyStateKeyNID, types.MRoomCreateNID, types.EmptyStateKeyNID,
) )
@ -115,7 +111,7 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
batchsize := 100 batchsize := 100
for batchoffset := 0; batchoffset < snapshotcount; batchoffset += batchsize { for batchoffset := 0; batchoffset < snapshotcount; batchoffset += batchsize {
var snapshotrows *sql.Rows var snapshotrows *sql.Rows
snapshotrows, err = tx.Query(` snapshotrows, err = tx.QueryContext(ctx, `
SELECT SELECT
state_snapshot_nid, state_snapshot_nid,
room_nid, room_nid,
@ -146,7 +142,7 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
state_block_nid; state_block_nid;
`, batchsize, batchoffset) `, batchsize, batchoffset)
if err != nil { if err != nil {
return fmt.Errorf("tx.Query: %w", err) return fmt.Errorf("tx.QueryContext: %w", err)
} }
logrus.Warnf("Rewriting snapshots %d-%d of %d...", batchoffset, batchoffset+batchsize, snapshotcount) logrus.Warnf("Rewriting snapshots %d-%d of %d...", batchoffset, batchoffset+batchsize, snapshotcount)
@ -183,7 +179,7 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
// fill in bad create snapshots // fill in bad create snapshots
for _, s := range badCreateSnapshots { for _, s := range badCreateSnapshots {
var createEventNID types.EventNID var createEventNID types.EventNID
err = tx.QueryRow( err = tx.QueryRowContext(ctx,
`SELECT event_nid FROM roomserver_events WHERE state_snapshot_nid = $1 AND event_type_nid = 1`, s.StateSnapshotNID, `SELECT event_nid FROM roomserver_events WHERE state_snapshot_nid = $1 AND event_type_nid = 1`, s.StateSnapshotNID,
).Scan(&createEventNID) ).Scan(&createEventNID)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
@ -208,7 +204,7 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
} }
var blocknid types.StateBlockNID var blocknid types.StateBlockNID
err = tx.QueryRow(` err = tx.QueryRowContext(ctx, `
INSERT INTO roomserver_state_block (state_block_hash, event_nids) INSERT INTO roomserver_state_block (state_block_hash, event_nids)
VALUES ($1, $2) VALUES ($1, $2)
ON CONFLICT (state_block_hash) DO UPDATE SET event_nids=$2 ON CONFLICT (state_block_hash) DO UPDATE SET event_nids=$2
@ -227,7 +223,7 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
} }
var newNID types.StateSnapshotNID var newNID types.StateSnapshotNID
err = tx.QueryRow(` err = tx.QueryRowContext(ctx, `
INSERT INTO roomserver_state_snapshots (state_snapshot_hash, room_nid, state_block_nids) INSERT INTO roomserver_state_snapshots (state_snapshot_hash, room_nid, state_block_nids)
VALUES ($1, $2, $3) VALUES ($1, $2, $3)
ON CONFLICT (state_snapshot_hash) DO UPDATE SET room_nid=$2 ON CONFLICT (state_snapshot_hash) DO UPDATE SET room_nid=$2
@ -237,12 +233,12 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
return fmt.Errorf("tx.QueryRow.Scan (insert new snapshot): %w", err) return fmt.Errorf("tx.QueryRow.Scan (insert new snapshot): %w", err)
} }
if _, err = tx.Exec(`UPDATE roomserver_events SET state_snapshot_nid=$1 WHERE state_snapshot_nid=$2 AND state_snapshot_nid<$3`, newNID, snapshotdata.StateSnapshotNID, maxsnapshotid); err != nil { if _, err = tx.ExecContext(ctx, `UPDATE roomserver_events SET state_snapshot_nid=$1 WHERE state_snapshot_nid=$2 AND state_snapshot_nid<$3`, newNID, snapshotdata.StateSnapshotNID, maxsnapshotid); err != nil {
return fmt.Errorf("tx.Exec (update events): %w", err) return fmt.Errorf("tx.ExecContext (update events): %w", err)
} }
if _, err = tx.Exec(`UPDATE roomserver_rooms SET state_snapshot_nid=$1 WHERE state_snapshot_nid=$2 AND state_snapshot_nid<$3`, newNID, snapshotdata.StateSnapshotNID, maxsnapshotid); err != nil { if _, err = tx.ExecContext(ctx, `UPDATE roomserver_rooms SET state_snapshot_nid=$1 WHERE state_snapshot_nid=$2 AND state_snapshot_nid<$3`, newNID, snapshotdata.StateSnapshotNID, maxsnapshotid); err != nil {
return fmt.Errorf("tx.Exec (update rooms): %w", err) return fmt.Errorf("tx.ExecContext (update rooms): %w", err)
} }
} }
} }
@ -252,13 +248,13 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
// in roomserver_state_snapshots // in roomserver_state_snapshots
var count int64 var count int64
if err = tx.QueryRow(`SELECT COUNT(*) FROM roomserver_events WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, maxsnapshotid).Scan(&count); err != nil { if err = tx.QueryRowContext(ctx, `SELECT COUNT(*) FROM roomserver_events WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, maxsnapshotid).Scan(&count); err != nil {
return fmt.Errorf("assertion query failed: %s", err) return fmt.Errorf("assertion query failed: %s", err)
} }
if count > 0 { if count > 0 {
var res sql.Result var res sql.Result
var c int64 var c int64
res, err = tx.Exec(`UPDATE roomserver_events SET state_snapshot_nid = 0 WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, maxsnapshotid) res, err = tx.ExecContext(ctx, `UPDATE roomserver_events SET state_snapshot_nid = 0 WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, maxsnapshotid)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
return fmt.Errorf("failed to reset invalid state snapshots: %w", err) return fmt.Errorf("failed to reset invalid state snapshots: %w", err)
} }
@ -268,13 +264,13 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
return fmt.Errorf("expected to reset %d event(s) but only updated %d event(s)", count, c) return fmt.Errorf("expected to reset %d event(s) but only updated %d event(s)", count, c)
} }
} }
if err = tx.QueryRow(`SELECT COUNT(*) FROM roomserver_rooms WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, maxsnapshotid).Scan(&count); err != nil { if err = tx.QueryRowContext(ctx, `SELECT COUNT(*) FROM roomserver_rooms WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, maxsnapshotid).Scan(&count); err != nil {
return fmt.Errorf("assertion query failed: %s", err) return fmt.Errorf("assertion query failed: %s", err)
} }
if count > 0 { if count > 0 {
var debugRoomID string var debugRoomID string
var debugSnapNID, debugLastEventNID int64 var debugSnapNID, debugLastEventNID int64
err = tx.QueryRow( err = tx.QueryRowContext(ctx,
`SELECT room_id, state_snapshot_nid, last_event_sent_nid FROM roomserver_rooms WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, maxsnapshotid, `SELECT room_id, state_snapshot_nid, last_event_sent_nid FROM roomserver_rooms WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, maxsnapshotid,
).Scan(&debugRoomID, &debugSnapNID, &debugLastEventNID) ).Scan(&debugRoomID, &debugSnapNID, &debugLastEventNID)
if err != nil { if err != nil {
@ -291,13 +287,13 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
return fmt.Errorf("%d rooms exist in roomserver_rooms which have not been converted to a new state_snapshot_nid; this is a bug, please report", count) return fmt.Errorf("%d rooms exist in roomserver_rooms which have not been converted to a new state_snapshot_nid; this is a bug, please report", count)
} }
if _, err = tx.Exec(` if _, err = tx.ExecContext(ctx, `
DROP TABLE _roomserver_state_snapshots; DROP TABLE _roomserver_state_snapshots;
DROP SEQUENCE roomserver_state_snapshot_nid_seq; DROP SEQUENCE roomserver_state_snapshot_nid_seq;
`); err != nil { `); err != nil {
return fmt.Errorf("tx.Exec (delete old snapshot table): %w", err) return fmt.Errorf("tx.Exec (delete old snapshot table): %w", err)
} }
if _, err = tx.Exec(` if _, err = tx.ExecContext(ctx, `
DROP TABLE _roomserver_state_block; DROP TABLE _roomserver_state_block;
DROP SEQUENCE roomserver_state_block_nid_seq; DROP SEQUENCE roomserver_state_block_nid_seq;
`); err != nil { `); err != nil {
@ -307,6 +303,6 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
return nil return nil
} }
func DownStateBlocksRefactor(tx *sql.Tx) error { func DownStateBlocksRefactor(ctx context.Context, tx *sql.Tx) error {
panic("Downgrading state storage is not supported") panic("Downgrading state storage is not supported")
} }

View file

@ -23,6 +23,7 @@ import (
"github.com/lib/pq" "github.com/lib/pq"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/postgres/deltas"
"github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -173,8 +174,16 @@ type membershipStatements struct {
func CreateMembershipTable(db *sql.DB) error { func CreateMembershipTable(db *sql.DB) error {
_, err := db.Exec(membershipSchema) _, err := db.Exec(membershipSchema)
if err != nil {
return err return err
} }
m := sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "roomserver: add forgotten column",
Up: deltas.UpAddForgottenColumn,
})
return m.Up(context.Background())
}
func PrepareMembershipTable(db *sql.DB) (tables.Membership, error) { func PrepareMembershipTable(db *sql.DB) (tables.Membership, error) {
s := &membershipStatements{} s := &membershipStatements{}

View file

@ -21,7 +21,6 @@ import (
// Import the postgres database driver. // Import the postgres database driver.
_ "github.com/lib/pq" _ "github.com/lib/pq"
"github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/postgres/deltas" "github.com/matrix-org/dendrite/roomserver/storage/postgres/deltas"
@ -45,18 +44,26 @@ func Open(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache c
} }
// Create the tables. // Create the tables.
if err := d.create(db); err != nil { if err = d.create(db); err != nil {
return nil, err return nil, err
} }
// Then execute the migrations. By this point the tables are created with the latest // Special case, since this migration uses several tables, so it needs to
// schemas. // be sure that all tables are created first.
m := sqlutil.NewMigrations() // TODO: Remove when we are sure we are not having goose artefacts in the db
deltas.LoadAddForgottenColumn(m) // This forces an error, which indicates the migration is already applied, since the
deltas.LoadStateBlocksRefactor(m) // column event_nid was removed from the table
if err := m.RunDeltas(db, dbProperties); err != nil { err = db.QueryRow("SELECT event_nid FROM roomserver_state_block LIMIT 1;").Scan()
if err == nil {
m := sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "roomserver: state blocks refactor",
Up: deltas.UpStateBlocksRefactor,
})
if err := m.Up(base.Context()); err != nil {
return nil, err return nil, err
} }
}
// Then prepare the statements. Now that the migrations have run, any columns referred // Then prepare the statements. Now that the migrations have run, any columns referred
// to in the database code should now exist. // to in the database code should now exist.

View file

@ -15,24 +15,13 @@
package deltas package deltas
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/pressly/goose"
) )
func LoadFromGoose() { func UpAddForgottenColumn(ctx context.Context, tx *sql.Tx) error {
goose.AddMigration(UpAddForgottenColumn, DownAddForgottenColumn) _, err := tx.ExecContext(ctx, ` ALTER TABLE roomserver_membership RENAME TO roomserver_membership_tmp;
goose.AddMigration(UpStateBlocksRefactor, DownStateBlocksRefactor)
}
func LoadAddForgottenColumn(m *sqlutil.Migrations) {
m.AddMigration(UpAddForgottenColumn, DownAddForgottenColumn)
}
func UpAddForgottenColumn(tx *sql.Tx) error {
_, err := tx.Exec(` ALTER TABLE roomserver_membership RENAME TO roomserver_membership_tmp;
CREATE TABLE IF NOT EXISTS roomserver_membership ( CREATE TABLE IF NOT EXISTS roomserver_membership (
room_nid INTEGER NOT NULL, room_nid INTEGER NOT NULL,
target_nid INTEGER NOT NULL, target_nid INTEGER NOT NULL,
@ -57,8 +46,8 @@ DROP TABLE roomserver_membership_tmp;`)
return nil return nil
} }
func DownAddForgottenColumn(tx *sql.Tx) error { func DownAddForgottenColumn(ctx context.Context, tx *sql.Tx) error {
_, err := tx.Exec(` ALTER TABLE roomserver_membership RENAME TO roomserver_membership_tmp; _, err := tx.ExecContext(ctx, ` ALTER TABLE roomserver_membership RENAME TO roomserver_membership_tmp;
CREATE TABLE IF NOT EXISTS roomserver_membership ( CREATE TABLE IF NOT EXISTS roomserver_membership (
room_nid INTEGER NOT NULL, room_nid INTEGER NOT NULL,
target_nid INTEGER NOT NULL, target_nid INTEGER NOT NULL,

View file

@ -21,40 +21,35 @@ import (
"fmt" "fmt"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
func LoadStateBlocksRefactor(m *sqlutil.Migrations) {
m.AddMigration(UpStateBlocksRefactor, DownStateBlocksRefactor)
}
// nolint:gocyclo // nolint:gocyclo
func UpStateBlocksRefactor(tx *sql.Tx) error { func UpStateBlocksRefactor(ctx context.Context, tx *sql.Tx) error {
logrus.Warn("Performing state storage upgrade. Please wait, this may take some time!") logrus.Warn("Performing state storage upgrade. Please wait, this may take some time!")
defer logrus.Warn("State storage upgrade complete") defer logrus.Warn("State storage upgrade complete")
var maxsnapshotid int var maxsnapshotid int
var maxblockid int var maxblockid int
if err := tx.QueryRow(`SELECT IFNULL(MAX(state_snapshot_nid),0) FROM roomserver_state_snapshots;`).Scan(&maxsnapshotid); err != nil { if err := tx.QueryRowContext(ctx, `SELECT IFNULL(MAX(state_snapshot_nid),0) FROM roomserver_state_snapshots;`).Scan(&maxsnapshotid); err != nil {
return fmt.Errorf("tx.QueryRow.Scan (count snapshots): %w", err) return fmt.Errorf("tx.QueryRowContext.Scan (count snapshots): %w", err)
} }
if err := tx.QueryRow(`SELECT IFNULL(MAX(state_block_nid),0) FROM roomserver_state_block;`).Scan(&maxblockid); err != nil { if err := tx.QueryRowContext(ctx, `SELECT IFNULL(MAX(state_block_nid),0) FROM roomserver_state_block;`).Scan(&maxblockid); err != nil {
return fmt.Errorf("tx.QueryRow.Scan (count snapshots): %w", err) return fmt.Errorf("tx.QueryRowContext.Scan (count snapshots): %w", err)
} }
maxsnapshotid++ maxsnapshotid++
maxblockid++ maxblockid++
oldMaxSnapshotID := maxsnapshotid oldMaxSnapshotID := maxsnapshotid
if _, err := tx.Exec(`ALTER TABLE roomserver_state_block RENAME TO _roomserver_state_block;`); err != nil { if _, err := tx.ExecContext(ctx, `ALTER TABLE roomserver_state_block RENAME TO _roomserver_state_block;`); err != nil {
return fmt.Errorf("tx.Exec: %w", err) return fmt.Errorf("tx.ExecContext: %w", err)
} }
if _, err := tx.Exec(`ALTER TABLE roomserver_state_snapshots RENAME TO _roomserver_state_snapshots;`); err != nil { if _, err := tx.ExecContext(ctx, `ALTER TABLE roomserver_state_snapshots RENAME TO _roomserver_state_snapshots;`); err != nil {
return fmt.Errorf("tx.Exec: %w", err) return fmt.Errorf("tx.ExecContext: %w", err)
} }
_, err := tx.Exec(` _, err := tx.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS roomserver_state_block ( CREATE TABLE IF NOT EXISTS roomserver_state_block (
state_block_nid INTEGER PRIMARY KEY AUTOINCREMENT, state_block_nid INTEGER PRIMARY KEY AUTOINCREMENT,
state_block_hash BLOB UNIQUE, state_block_hash BLOB UNIQUE,
@ -62,9 +57,9 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
); );
`) `)
if err != nil { if err != nil {
return fmt.Errorf("tx.Exec: %w", err) return fmt.Errorf("tx.ExecContext: %w", err)
} }
_, err = tx.Exec(` _, err = tx.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS roomserver_state_snapshots ( CREATE TABLE IF NOT EXISTS roomserver_state_snapshots (
state_snapshot_nid INTEGER PRIMARY KEY AUTOINCREMENT, state_snapshot_nid INTEGER PRIMARY KEY AUTOINCREMENT,
state_snapshot_hash BLOB UNIQUE, state_snapshot_hash BLOB UNIQUE,
@ -73,11 +68,11 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
); );
`) `)
if err != nil { if err != nil {
return fmt.Errorf("tx.Exec: %w", err) return fmt.Errorf("tx.ExecContext: %w", err)
} }
snapshotrows, err := tx.Query(`SELECT state_snapshot_nid, room_nid, state_block_nids FROM _roomserver_state_snapshots;`) snapshotrows, err := tx.QueryContext(ctx, `SELECT state_snapshot_nid, room_nid, state_block_nids FROM _roomserver_state_snapshots;`)
if err != nil { if err != nil {
return fmt.Errorf("tx.Query: %w", err) return fmt.Errorf("tx.QueryContext: %w", err)
} }
defer internal.CloseAndLogIfError(context.TODO(), snapshotrows, "rows.close() failed") defer internal.CloseAndLogIfError(context.TODO(), snapshotrows, "rows.close() failed")
for snapshotrows.Next() { for snapshotrows.Next() {
@ -99,7 +94,7 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
// in question a state snapshot NID of 0 to indicate 'no snapshot'. // in question a state snapshot NID of 0 to indicate 'no snapshot'.
// If we don't do this, we'll fail the assertions later on which try to ensure we didn't forget // If we don't do this, we'll fail the assertions later on which try to ensure we didn't forget
// any snapshots. // any snapshots.
_, err = tx.Exec( _, err = tx.ExecContext(ctx,
`UPDATE roomserver_events SET state_snapshot_nid = 0 WHERE event_type_nid = $1 AND event_state_key_nid = $2 AND state_snapshot_nid = $3`, `UPDATE roomserver_events SET state_snapshot_nid = 0 WHERE event_type_nid = $1 AND event_state_key_nid = $2 AND state_snapshot_nid = $3`,
types.MRoomCreateNID, types.EmptyStateKeyNID, snapshot, types.MRoomCreateNID, types.EmptyStateKeyNID, snapshot,
) )
@ -109,9 +104,9 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
} }
for _, block := range blocks { for _, block := range blocks {
if err = func() error { if err = func() error {
blockrows, berr := tx.Query(`SELECT event_nid FROM _roomserver_state_block WHERE state_block_nid = $1`, block) blockrows, berr := tx.QueryContext(ctx, `SELECT event_nid FROM _roomserver_state_block WHERE state_block_nid = $1`, block)
if berr != nil { if berr != nil {
return fmt.Errorf("tx.Query (event nids from old block): %w", berr) return fmt.Errorf("tx.QueryContext (event nids from old block): %w", berr)
} }
defer internal.CloseAndLogIfError(context.TODO(), blockrows, "rows.close() failed") defer internal.CloseAndLogIfError(context.TODO(), blockrows, "rows.close() failed")
events := types.EventNIDs{} events := types.EventNIDs{}
@ -129,14 +124,14 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
} }
var blocknid types.StateBlockNID var blocknid types.StateBlockNID
err = tx.QueryRow(` err = tx.QueryRowContext(ctx, `
INSERT INTO roomserver_state_block (state_block_nid, state_block_hash, event_nids) INSERT INTO roomserver_state_block (state_block_nid, state_block_hash, event_nids)
VALUES ($1, $2, $3) VALUES ($1, $2, $3)
ON CONFLICT (state_block_hash) DO UPDATE SET event_nids=$3 ON CONFLICT (state_block_hash) DO UPDATE SET event_nids=$3
RETURNING state_block_nid RETURNING state_block_nid
`, maxblockid, events.Hash(), eventjson).Scan(&blocknid) `, maxblockid, events.Hash(), eventjson).Scan(&blocknid)
if err != nil { if err != nil {
return fmt.Errorf("tx.QueryRow.Scan (insert new block): %w", err) return fmt.Errorf("tx.QueryRowContext.Scan (insert new block): %w", err)
} }
maxblockid++ maxblockid++
newblocks = append(newblocks, blocknid) newblocks = append(newblocks, blocknid)
@ -151,22 +146,22 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
} }
var newsnapshot types.StateSnapshotNID var newsnapshot types.StateSnapshotNID
err = tx.QueryRow(` err = tx.QueryRowContext(ctx, `
INSERT INTO roomserver_state_snapshots (state_snapshot_nid, state_snapshot_hash, room_nid, state_block_nids) INSERT INTO roomserver_state_snapshots (state_snapshot_nid, state_snapshot_hash, room_nid, state_block_nids)
VALUES ($1, $2, $3, $4) VALUES ($1, $2, $3, $4)
ON CONFLICT (state_snapshot_hash) DO UPDATE SET room_nid=$3 ON CONFLICT (state_snapshot_hash) DO UPDATE SET room_nid=$3
RETURNING state_snapshot_nid RETURNING state_snapshot_nid
`, maxsnapshotid, newblocks.Hash(), room, newblocksjson).Scan(&newsnapshot) `, maxsnapshotid, newblocks.Hash(), room, newblocksjson).Scan(&newsnapshot)
if err != nil { if err != nil {
return fmt.Errorf("tx.QueryRow.Scan (insert new snapshot): %w", err) return fmt.Errorf("tx.QueryRowContext.Scan (insert new snapshot): %w", err)
} }
maxsnapshotid++ maxsnapshotid++
_, err = tx.Exec(`UPDATE roomserver_events SET state_snapshot_nid=$1 WHERE state_snapshot_nid=$2 AND state_snapshot_nid<$3`, newsnapshot, snapshot, maxsnapshotid) _, err = tx.ExecContext(ctx, `UPDATE roomserver_events SET state_snapshot_nid=$1 WHERE state_snapshot_nid=$2 AND state_snapshot_nid<$3`, newsnapshot, snapshot, maxsnapshotid)
if err != nil { if err != nil {
return fmt.Errorf("tx.Exec (update events): %w", err) return fmt.Errorf("tx.ExecContext (update events): %w", err)
} }
if _, err = tx.Exec(`UPDATE roomserver_rooms SET state_snapshot_nid=$1 WHERE state_snapshot_nid=$2 AND state_snapshot_nid<$3`, newsnapshot, snapshot, maxsnapshotid); err != nil { if _, err = tx.ExecContext(ctx, `UPDATE roomserver_rooms SET state_snapshot_nid=$1 WHERE state_snapshot_nid=$2 AND state_snapshot_nid<$3`, newsnapshot, snapshot, maxsnapshotid); err != nil {
return fmt.Errorf("tx.Exec (update rooms): %w", err) return fmt.Errorf("tx.ExecContext (update rooms): %w", err)
} }
} }
} }
@ -175,13 +170,13 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
// If we do, this is a problem if Dendrite tries to load the snapshot as it will not exist // If we do, this is a problem if Dendrite tries to load the snapshot as it will not exist
// in roomserver_state_snapshots // in roomserver_state_snapshots
var count int64 var count int64
if err = tx.QueryRow(`SELECT COUNT(*) FROM roomserver_events WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, oldMaxSnapshotID).Scan(&count); err != nil { if err = tx.QueryRowContext(ctx, `SELECT COUNT(*) FROM roomserver_events WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, oldMaxSnapshotID).Scan(&count); err != nil {
return fmt.Errorf("assertion query failed: %s", err) return fmt.Errorf("assertion query failed: %s", err)
} }
if count > 0 { if count > 0 {
var res sql.Result var res sql.Result
var c int64 var c int64
res, err = tx.Exec(`UPDATE roomserver_events SET state_snapshot_nid = 0 WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, oldMaxSnapshotID) res, err = tx.ExecContext(ctx, `UPDATE roomserver_events SET state_snapshot_nid = 0 WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, oldMaxSnapshotID)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
return fmt.Errorf("failed to reset invalid state snapshots: %w", err) return fmt.Errorf("failed to reset invalid state snapshots: %w", err)
} }
@ -191,23 +186,23 @@ func UpStateBlocksRefactor(tx *sql.Tx) error {
return fmt.Errorf("expected to reset %d event(s) but only updated %d event(s)", count, c) return fmt.Errorf("expected to reset %d event(s) but only updated %d event(s)", count, c)
} }
} }
if err = tx.QueryRow(`SELECT COUNT(*) FROM roomserver_rooms WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, oldMaxSnapshotID).Scan(&count); err != nil { if err = tx.QueryRowContext(ctx, `SELECT COUNT(*) FROM roomserver_rooms WHERE state_snapshot_nid < $1 AND state_snapshot_nid != 0`, oldMaxSnapshotID).Scan(&count); err != nil {
return fmt.Errorf("assertion query failed: %s", err) return fmt.Errorf("assertion query failed: %s", err)
} }
if count > 0 { if count > 0 {
return fmt.Errorf("%d rooms exist in roomserver_rooms which have not been converted to a new state_snapshot_nid; this is a bug, please report", count) return fmt.Errorf("%d rooms exist in roomserver_rooms which have not been converted to a new state_snapshot_nid; this is a bug, please report", count)
} }
if _, err = tx.Exec(`DROP TABLE _roomserver_state_snapshots;`); err != nil { if _, err = tx.ExecContext(ctx, `DROP TABLE _roomserver_state_snapshots;`); err != nil {
return fmt.Errorf("tx.Exec (delete old snapshot table): %w", err) return fmt.Errorf("tx.Exec (delete old snapshot table): %w", err)
} }
if _, err = tx.Exec(`DROP TABLE _roomserver_state_block;`); err != nil { if _, err = tx.ExecContext(ctx, `DROP TABLE _roomserver_state_block;`); err != nil {
return fmt.Errorf("tx.Exec (delete old block table): %w", err) return fmt.Errorf("tx.Exec (delete old block table): %w", err)
} }
return nil return nil
} }
func DownStateBlocksRefactor(tx *sql.Tx) error { func DownStateBlocksRefactor(ctx context.Context, tx *sql.Tx) error {
panic("Downgrading state storage is not supported") panic("Downgrading state storage is not supported")
} }

View file

@ -23,6 +23,7 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/sqlite3/deltas"
"github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -148,8 +149,16 @@ type membershipStatements struct {
func CreateMembershipTable(db *sql.DB) error { func CreateMembershipTable(db *sql.DB) error {
_, err := db.Exec(membershipSchema) _, err := db.Exec(membershipSchema)
if err != nil {
return err return err
} }
m := sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "roomserver: add forgotten column",
Up: deltas.UpAddForgottenColumn,
})
return m.Up(context.Background())
}
func PrepareMembershipTable(db *sql.DB) (tables.Membership, error) { func PrepareMembershipTable(db *sql.DB) (tables.Membership, error) {
s := &membershipStatements{ s := &membershipStatements{

View file

@ -54,18 +54,26 @@ func Open(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, cache c
// db.SetMaxOpenConns(20) // db.SetMaxOpenConns(20)
// Create the tables. // Create the tables.
if err := d.create(db); err != nil { if err = d.create(db); err != nil {
return nil, err return nil, err
} }
// Then execute the migrations. By this point the tables are created with the latest // Special case, since this migration uses several tables, so it needs to
// schemas. // be sure that all tables are created first.
m := sqlutil.NewMigrations() // TODO: Remove when we are sure we are not having goose artefacts in the db
deltas.LoadAddForgottenColumn(m) // This forces an error, which indicates the migration is already applied, since the
deltas.LoadStateBlocksRefactor(m) // column event_nid was removed from the table
if err := m.RunDeltas(db, dbProperties); err != nil { err = db.QueryRow("SELECT event_nid FROM roomserver_state_block LIMIT 1;").Scan()
if err == nil {
m := sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "roomserver: state blocks refactor",
Up: deltas.UpStateBlocksRefactor,
})
if err := m.Up(base.Context()); err != nil {
return nil, err return nil, err
} }
}
// Then prepare the statements. Now that the migrations have run, any columns referred // Then prepare the statements. Now that the migrations have run, any columns referred
// to in the database code should now exist. // to in the database code should now exist.

View file

@ -23,6 +23,7 @@ import (
"github.com/lib/pq" "github.com/lib/pq"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/postgres/deltas"
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -133,6 +134,17 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro
if err != nil { if err != nil {
return nil, err return nil, err
} }
m := sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "syncapi: add history visibility column (current_room_state)",
Up: deltas.UpAddHistoryVisibilityColumnCurrentRoomState,
})
err = m.Up(context.Background())
if err != nil {
return nil, err
}
if s.upsertRoomStateStmt, err = db.Prepare(upsertRoomStateSQL); err != nil { if s.upsertRoomStateStmt, err = db.Prepare(upsertRoomStateSQL); err != nil {
return nil, err return nil, err
} }

View file

@ -15,24 +15,13 @@
package deltas package deltas
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/pressly/goose"
) )
func LoadFromGoose() { func UpFixSequences(ctx context.Context, tx *sql.Tx) error {
goose.AddMigration(UpFixSequences, DownFixSequences) _, err := tx.ExecContext(ctx, `
goose.AddMigration(UpRemoveSendToDeviceSentColumn, DownRemoveSendToDeviceSentColumn)
}
func LoadFixSequences(m *sqlutil.Migrations) {
m.AddMigration(UpFixSequences, DownFixSequences)
}
func UpFixSequences(tx *sql.Tx) error {
_, err := tx.Exec(`
-- We need to delete all of the existing receipts because the indexes -- We need to delete all of the existing receipts because the indexes
-- will be wrong, and we'll get primary key violations if we try to -- will be wrong, and we'll get primary key violations if we try to
-- reuse existing stream IDs from a different sequence. -- reuse existing stream IDs from a different sequence.
@ -49,8 +38,8 @@ func UpFixSequences(tx *sql.Tx) error {
return nil return nil
} }
func DownFixSequences(tx *sql.Tx) error { func DownFixSequences(ctx context.Context, tx *sql.Tx) error {
_, err := tx.Exec(` _, err := tx.ExecContext(ctx, `
-- We need to delete all of the existing receipts because the indexes -- We need to delete all of the existing receipts because the indexes
-- will be wrong, and we'll get primary key violations if we try to -- will be wrong, and we'll get primary key violations if we try to
-- reuse existing stream IDs from a different sequence. -- reuse existing stream IDs from a different sequence.

View file

@ -15,18 +15,13 @@
package deltas package deltas
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
) )
func LoadRemoveSendToDeviceSentColumn(m *sqlutil.Migrations) { func UpRemoveSendToDeviceSentColumn(ctx context.Context, tx *sql.Tx) error {
m.AddMigration(UpRemoveSendToDeviceSentColumn, DownRemoveSendToDeviceSentColumn) _, err := tx.ExecContext(ctx, `
}
func UpRemoveSendToDeviceSentColumn(tx *sql.Tx) error {
_, err := tx.Exec(`
ALTER TABLE syncapi_send_to_device ALTER TABLE syncapi_send_to_device
DROP COLUMN IF EXISTS sent_by_token; DROP COLUMN IF EXISTS sent_by_token;
`) `)
@ -36,8 +31,8 @@ func UpRemoveSendToDeviceSentColumn(tx *sql.Tx) error {
return nil return nil
} }
func DownRemoveSendToDeviceSentColumn(tx *sql.Tx) error { func DownRemoveSendToDeviceSentColumn(ctx context.Context, tx *sql.Tx) error {
_, err := tx.Exec(` _, err := tx.ExecContext(ctx, `
ALTER TABLE syncapi_send_to_device ALTER TABLE syncapi_send_to_device
ADD COLUMN IF NOT EXISTS sent_by_token TEXT; ADD COLUMN IF NOT EXISTS sent_by_token TEXT;
`) `)

View file

@ -15,20 +15,24 @@
package deltas package deltas
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
) )
func LoadAddHistoryVisibilityColumn(m *sqlutil.Migrations) { func UpAddHistoryVisibilityColumnOutputRoomEvents(ctx context.Context, tx *sql.Tx) error {
m.AddMigration(UpAddHistoryVisibilityColumn, DownAddHistoryVisibilityColumn) _, err := tx.ExecContext(ctx, `
}
func UpAddHistoryVisibilityColumn(tx *sql.Tx) error {
_, err := tx.Exec(`
ALTER TABLE syncapi_output_room_events ADD COLUMN IF NOT EXISTS history_visibility SMALLINT NOT NULL DEFAULT 2; ALTER TABLE syncapi_output_room_events ADD COLUMN IF NOT EXISTS history_visibility SMALLINT NOT NULL DEFAULT 2;
UPDATE syncapi_output_room_events SET history_visibility = 4 WHERE type IN ('m.room.message', 'm.room.encrypted'); UPDATE syncapi_output_room_events SET history_visibility = 4 WHERE type IN ('m.room.message', 'm.room.encrypted');
`)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
func UpAddHistoryVisibilityColumnCurrentRoomState(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, `
ALTER TABLE syncapi_current_room_state ADD COLUMN IF NOT EXISTS history_visibility SMALLINT NOT NULL DEFAULT 2; ALTER TABLE syncapi_current_room_state ADD COLUMN IF NOT EXISTS history_visibility SMALLINT NOT NULL DEFAULT 2;
UPDATE syncapi_current_room_state SET history_visibility = 4 WHERE type IN ('m.room.message', 'm.room.encrypted'); UPDATE syncapi_current_room_state SET history_visibility = 4 WHERE type IN ('m.room.message', 'm.room.encrypted');
`) `)
@ -38,8 +42,8 @@ func UpAddHistoryVisibilityColumn(tx *sql.Tx) error {
return nil return nil
} }
func DownAddHistoryVisibilityColumn(tx *sql.Tx) error { func DownAddHistoryVisibilityColumn(ctx context.Context, tx *sql.Tx) error {
_, err := tx.Exec(` _, err := tx.ExecContext(ctx, `
ALTER TABLE syncapi_output_room_events DROP COLUMN IF EXISTS history_visibility; ALTER TABLE syncapi_output_room_events DROP COLUMN IF EXISTS history_visibility;
ALTER TABLE syncapi_current_room_state DROP COLUMN IF EXISTS history_visibility; ALTER TABLE syncapi_current_room_state DROP COLUMN IF EXISTS history_visibility;
`) `)

View file

@ -23,6 +23,7 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/storage/postgres/deltas"
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
@ -188,6 +189,17 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
m := sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "syncapi: add history visibility column (output_room_events)",
Up: deltas.UpAddHistoryVisibilityColumnOutputRoomEvents,
})
err = m.Up(context.Background())
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{ return s, sqlutil.StatementList{
{&s.insertEventStmt, insertEventSQL}, {&s.insertEventStmt, insertEventSQL},
{&s.selectEventsStmt, selectEventsSQL}, {&s.selectEventsStmt, selectEventsSQL},

View file

@ -23,6 +23,7 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/postgres/deltas"
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -73,6 +74,15 @@ func NewPostgresReceiptsTable(db *sql.DB) (tables.Receipts, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
m := sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "syncapi: fix sequences",
Up: deltas.UpFixSequences,
})
err = m.Up(context.Background())
if err != nil {
return nil, err
}
r := &receiptStatements{ r := &receiptStatements{
db: db, db: db,
} }

View file

@ -21,6 +21,7 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/postgres/deltas"
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -76,6 +77,15 @@ func NewPostgresSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
m := sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "syncapi: drop sent_by_token",
Up: deltas.UpRemoveSendToDeviceSentColumn,
})
err = m.Up(context.Background())
if err != nil {
return nil, err
}
if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil { if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil {
return nil, err return nil, err
} }

View file

@ -23,7 +23,6 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/storage/postgres/deltas"
"github.com/matrix-org/dendrite/syncapi/storage/shared" "github.com/matrix-org/dendrite/syncapi/storage/shared"
) )
@ -42,16 +41,18 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions)
if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()); err != nil { if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()); err != nil {
return nil, err return nil, err
} }
if _, err = d.db.Exec(outputRoomEventsSchema); err != nil {
return nil, err
}
if _, err = d.db.Exec(currentRoomStateSchema); err != nil {
return nil, err
}
accountData, err := NewPostgresAccountDataTable(d.db) accountData, err := NewPostgresAccountDataTable(d.db)
if err != nil { if err != nil {
return nil, err return nil, err
} }
events, err := NewPostgresEventsTable(d.db)
if err != nil {
return nil, err
}
currState, err := NewPostgresCurrentRoomStateTable(d.db)
if err != nil {
return nil, err
}
invites, err := NewPostgresInvitesTable(d.db) invites, err := NewPostgresInvitesTable(d.db)
if err != nil { if err != nil {
return nil, err return nil, err
@ -96,22 +97,6 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }
m := sqlutil.NewMigrations()
deltas.LoadFixSequences(m)
deltas.LoadRemoveSendToDeviceSentColumn(m)
deltas.LoadAddHistoryVisibilityColumn(m)
if err = m.RunDeltas(d.db, dbProperties); err != nil {
return nil, err
}
// prepare statements after the migrations have run
events, err := NewPostgresEventsTable(d.db)
if err != nil {
return nil, err
}
currState, err := NewPostgresCurrentRoomStateTable(d.db)
if err != nil {
return nil, err
}
d.Database = shared.Database{ d.Database = shared.Database{
DB: d.db, DB: d.db,
Writer: d.writer, Writer: d.writer,

View file

@ -24,6 +24,7 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/sqlite3/deltas"
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -120,6 +121,17 @@ func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (t
if err != nil { if err != nil {
return nil, err return nil, err
} }
m := sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "syncapi: add history visibility column (current_room_state)",
Up: deltas.UpAddHistoryVisibilityColumnCurrentRoomState,
})
err = m.Up(context.Background())
if err != nil {
return nil, err
}
if s.upsertRoomStateStmt, err = db.Prepare(upsertRoomStateSQL); err != nil { if s.upsertRoomStateStmt, err = db.Prepare(upsertRoomStateSQL); err != nil {
return nil, err return nil, err
} }

View file

@ -15,24 +15,13 @@
package deltas package deltas
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/pressly/goose"
) )
func LoadFromGoose() { func UpFixSequences(ctx context.Context, tx *sql.Tx) error {
goose.AddMigration(UpFixSequences, DownFixSequences) _, err := tx.ExecContext(ctx, `
goose.AddMigration(UpRemoveSendToDeviceSentColumn, DownRemoveSendToDeviceSentColumn)
}
func LoadFixSequences(m *sqlutil.Migrations) {
m.AddMigration(UpFixSequences, DownFixSequences)
}
func UpFixSequences(tx *sql.Tx) error {
_, err := tx.Exec(`
-- We need to delete all of the existing receipts because the indexes -- We need to delete all of the existing receipts because the indexes
-- will be wrong, and we'll get primary key violations if we try to -- will be wrong, and we'll get primary key violations if we try to
-- reuse existing stream IDs from a different sequence. -- reuse existing stream IDs from a different sequence.
@ -45,8 +34,8 @@ func UpFixSequences(tx *sql.Tx) error {
return nil return nil
} }
func DownFixSequences(tx *sql.Tx) error { func DownFixSequences(ctx context.Context, tx *sql.Tx) error {
_, err := tx.Exec(` _, err := tx.ExecContext(ctx, `
-- We need to delete all of the existing receipts because the indexes -- We need to delete all of the existing receipts because the indexes
-- will be wrong, and we'll get primary key violations if we try to -- will be wrong, and we'll get primary key violations if we try to
-- reuse existing stream IDs from a different sequence. -- reuse existing stream IDs from a different sequence.

View file

@ -15,18 +15,13 @@
package deltas package deltas
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
) )
func LoadRemoveSendToDeviceSentColumn(m *sqlutil.Migrations) { func UpRemoveSendToDeviceSentColumn(ctx context.Context, tx *sql.Tx) error {
m.AddMigration(UpRemoveSendToDeviceSentColumn, DownRemoveSendToDeviceSentColumn) _, err := tx.ExecContext(ctx, `
}
func UpRemoveSendToDeviceSentColumn(tx *sql.Tx) error {
_, err := tx.Exec(`
CREATE TEMPORARY TABLE syncapi_send_to_device_backup(id, user_id, device_id, content); CREATE TEMPORARY TABLE syncapi_send_to_device_backup(id, user_id, device_id, content);
INSERT INTO syncapi_send_to_device_backup SELECT id, user_id, device_id, content FROM syncapi_send_to_device; INSERT INTO syncapi_send_to_device_backup SELECT id, user_id, device_id, content FROM syncapi_send_to_device;
DROP TABLE syncapi_send_to_device; DROP TABLE syncapi_send_to_device;
@ -45,8 +40,8 @@ func UpRemoveSendToDeviceSentColumn(tx *sql.Tx) error {
return nil return nil
} }
func DownRemoveSendToDeviceSentColumn(tx *sql.Tx) error { func DownRemoveSendToDeviceSentColumn(ctx context.Context, tx *sql.Tx) error {
_, err := tx.Exec(` _, err := tx.ExecContext(ctx, `
CREATE TEMPORARY TABLE syncapi_send_to_device_backup(id, user_id, device_id, content); CREATE TEMPORARY TABLE syncapi_send_to_device_backup(id, user_id, device_id, content);
INSERT INTO syncapi_send_to_device_backup SELECT id, user_id, device_id, content FROM syncapi_send_to_device; INSERT INTO syncapi_send_to_device_backup SELECT id, user_id, device_id, content FROM syncapi_send_to_device;
DROP TABLE syncapi_send_to_device; DROP TABLE syncapi_send_to_device;

View file

@ -15,36 +15,36 @@
package deltas package deltas
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
) )
func LoadAddHistoryVisibilityColumn(m *sqlutil.Migrations) { func UpAddHistoryVisibilityColumnOutputRoomEvents(ctx context.Context, tx *sql.Tx) error {
m.AddMigration(UpAddHistoryVisibilityColumn, DownAddHistoryVisibilityColumn)
}
func UpAddHistoryVisibilityColumn(tx *sql.Tx) error {
// SQLite doesn't have "if exists", so check if the column exists. If the query doesn't return an error, it already exists. // SQLite doesn't have "if exists", so check if the column exists. If the query doesn't return an error, it already exists.
// Required for unit tests, as otherwise a duplicate column error will show up. // Required for unit tests, as otherwise a duplicate column error will show up.
_, err := tx.Query("SELECT history_visibility FROM syncapi_output_room_events LIMIT 1") _, err := tx.QueryContext(ctx, "SELECT history_visibility FROM syncapi_output_room_events LIMIT 1")
if err == nil { if err == nil {
return nil return nil
} }
_, err = tx.Exec(` _, err = tx.ExecContext(ctx, `
ALTER TABLE syncapi_output_room_events ADD COLUMN history_visibility SMALLINT NOT NULL DEFAULT 2; ALTER TABLE syncapi_output_room_events ADD COLUMN history_visibility SMALLINT NOT NULL DEFAULT 2;
UPDATE syncapi_output_room_events SET history_visibility = 4 WHERE type IN ('m.room.message', 'm.room.encrypted'); UPDATE syncapi_output_room_events SET history_visibility = 4 WHERE type IN ('m.room.message', 'm.room.encrypted');
`) `)
if err != nil { if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err) return fmt.Errorf("failed to execute upgrade: %w", err)
} }
return nil
}
_, err = tx.Query("SELECT history_visibility FROM syncapi_current_room_state LIMIT 1") func UpAddHistoryVisibilityColumnCurrentRoomState(ctx context.Context, tx *sql.Tx) error {
// SQLite doesn't have "if exists", so check if the column exists. If the query doesn't return an error, it already exists.
// Required for unit tests, as otherwise a duplicate column error will show up.
_, err := tx.QueryContext(ctx, "SELECT history_visibility FROM syncapi_current_room_state LIMIT 1")
if err == nil { if err == nil {
return nil return nil
} }
_, err = tx.Exec(` _, err = tx.ExecContext(ctx, `
ALTER TABLE syncapi_current_room_state ADD COLUMN history_visibility SMALLINT NOT NULL DEFAULT 2; ALTER TABLE syncapi_current_room_state ADD COLUMN history_visibility SMALLINT NOT NULL DEFAULT 2;
UPDATE syncapi_current_room_state SET history_visibility = 4 WHERE type IN ('m.room.message', 'm.room.encrypted'); UPDATE syncapi_current_room_state SET history_visibility = 4 WHERE type IN ('m.room.message', 'm.room.encrypted');
`) `)
@ -54,25 +54,25 @@ func UpAddHistoryVisibilityColumn(tx *sql.Tx) error {
return nil return nil
} }
func DownAddHistoryVisibilityColumn(tx *sql.Tx) error { func DownAddHistoryVisibilityColumn(ctx context.Context, tx *sql.Tx) error {
// SQLite doesn't have "if exists", so check if the column exists. // SQLite doesn't have "if exists", so check if the column exists.
_, err := tx.Query("SELECT history_visibility FROM syncapi_output_room_events LIMIT 1") _, err := tx.QueryContext(ctx, "SELECT history_visibility FROM syncapi_output_room_events LIMIT 1")
if err != nil { if err != nil {
// The column probably doesn't exist // The column probably doesn't exist
return nil return nil
} }
_, err = tx.Exec(` _, err = tx.ExecContext(ctx, `
ALTER TABLE syncapi_output_room_events DROP COLUMN history_visibility; ALTER TABLE syncapi_output_room_events DROP COLUMN history_visibility;
`) `)
if err != nil { if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err) return fmt.Errorf("failed to execute downgrade: %w", err)
} }
_, err = tx.Query("SELECT history_visibility FROM syncapi_current_room_state LIMIT 1") _, err = tx.QueryContext(ctx, "SELECT history_visibility FROM syncapi_current_room_state LIMIT 1")
if err != nil { if err != nil {
// The column probably doesn't exist // The column probably doesn't exist
return nil return nil
} }
_, err = tx.Exec(` _, err = tx.ExecContext(ctx, `
ALTER TABLE syncapi_current_room_state DROP COLUMN history_visibility; ALTER TABLE syncapi_current_room_state DROP COLUMN history_visibility;
`) `)
if err != nil { if err != nil {

View file

@ -25,6 +25,7 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/storage/sqlite3/deltas"
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
@ -136,6 +137,17 @@ func NewSqliteEventsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Even
if err != nil { if err != nil {
return nil, err return nil, err
} }
m := sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "syncapi: add history visibility column (output_room_events)",
Up: deltas.UpAddHistoryVisibilityColumnOutputRoomEvents,
})
err = m.Up(context.Background())
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{ return s, sqlutil.StatementList{
{&s.insertEventStmt, insertEventSQL}, {&s.insertEventStmt, insertEventSQL},
{&s.selectMaxEventIDStmt, selectMaxEventIDSQL}, {&s.selectMaxEventIDStmt, selectMaxEventIDSQL},

View file

@ -22,6 +22,7 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/sqlite3/deltas"
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -70,6 +71,15 @@ func NewSqliteReceiptsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Re
if err != nil { if err != nil {
return nil, err return nil, err
} }
m := sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "syncapi: fix sequences",
Up: deltas.UpFixSequences,
})
err = m.Up(context.Background())
if err != nil {
return nil, err
}
r := &receiptStatements{ r := &receiptStatements{
db: db, db: db,
streamIDStatements: streamID, streamIDStatements: streamID,

View file

@ -21,6 +21,7 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/sqlite3/deltas"
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -76,6 +77,15 @@ func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
m := sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "syncapi: drop sent_by_token",
Up: deltas.UpRemoveSendToDeviceSentColumn,
})
err = m.Up(context.Background())
if err != nil {
return nil, err
}
if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil { if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil {
return nil, err return nil, err
} }

View file

@ -22,7 +22,6 @@ import (
"github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/syncapi/storage/shared" "github.com/matrix-org/dendrite/syncapi/storage/shared"
"github.com/matrix-org/dendrite/syncapi/storage/sqlite3/deltas"
) )
// SyncServerDatasource represents a sync server datasource which manages // SyncServerDatasource represents a sync server datasource which manages
@ -42,26 +41,28 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions)
if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil { if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil {
return nil, err return nil, err
} }
if err = d.prepare(dbProperties); err != nil { if err = d.prepare(); err != nil {
return nil, err return nil, err
} }
return &d, nil return &d, nil
} }
func (d *SyncServerDatasource) prepare(dbProperties *config.DatabaseOptions) (err error) { func (d *SyncServerDatasource) prepare() (err error) {
if err = d.streamID.Prepare(d.db); err != nil { if err = d.streamID.Prepare(d.db); err != nil {
return err return err
} }
if _, err = d.db.Exec(outputRoomEventsSchema); err != nil {
return err
}
if _, err = d.db.Exec(currentRoomStateSchema); err != nil {
return err
}
accountData, err := NewSqliteAccountDataTable(d.db, &d.streamID) accountData, err := NewSqliteAccountDataTable(d.db, &d.streamID)
if err != nil { if err != nil {
return err return err
} }
events, err := NewSqliteEventsTable(d.db, &d.streamID)
if err != nil {
return err
}
roomState, err := NewSqliteCurrentRoomStateTable(d.db, &d.streamID)
if err != nil {
return err
}
invites, err := NewSqliteInvitesTable(d.db, &d.streamID) invites, err := NewSqliteInvitesTable(d.db, &d.streamID)
if err != nil { if err != nil {
return err return err
@ -106,22 +107,6 @@ func (d *SyncServerDatasource) prepare(dbProperties *config.DatabaseOptions) (er
if err != nil { if err != nil {
return err return err
} }
m := sqlutil.NewMigrations()
deltas.LoadFixSequences(m)
deltas.LoadRemoveSendToDeviceSentColumn(m)
deltas.LoadAddHistoryVisibilityColumn(m)
if err = m.RunDeltas(d.db, dbProperties); err != nil {
return err
}
// prepare statements after the migrations have run
events, err := NewSqliteEventsTable(d.db, &d.streamID)
if err != nil {
return err
}
roomState, err := NewSqliteCurrentRoomStateTable(d.db, &d.streamID)
if err != nil {
return err
}
d.Database = shared.Database{ d.Database = shared.Database{
DB: d.db, DB: d.db,
Writer: d.writer, Writer: d.writer,

View file

@ -24,6 +24,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/postgres/deltas"
"github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/storage/tables"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -85,6 +86,23 @@ func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerNam
if err != nil { if err != nil {
return nil, err return nil, err
} }
m := sqlutil.NewMigrator(db)
m.AddMigrations([]sqlutil.Migration{
{
Version: "userapi: add is active",
Up: deltas.UpIsActive,
Down: deltas.DownIsActive,
},
{
Version: "userapi: add account type",
Up: deltas.UpAddAccountType,
Down: deltas.DownAddAccountType,
},
}...)
err = m.Up(context.Background())
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{ return s, sqlutil.StatementList{
{&s.insertAccountStmt, insertAccountSQL}, {&s.insertAccountStmt, insertAccountSQL},
{&s.updatePasswordStmt, updatePasswordSQL}, {&s.updatePasswordStmt, updatePasswordSQL},

View file

@ -1,33 +1,21 @@
package deltas package deltas
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/pressly/goose"
"github.com/matrix-org/dendrite/internal/sqlutil"
) )
func LoadFromGoose() { func UpIsActive(ctx context.Context, tx *sql.Tx) error {
goose.AddMigration(UpIsActive, DownIsActive) _, err := tx.ExecContext(ctx, "ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS is_deactivated BOOLEAN DEFAULT FALSE;")
goose.AddMigration(UpAddAccountType, DownAddAccountType)
}
func LoadIsActive(m *sqlutil.Migrations) {
m.AddMigration(UpIsActive, DownIsActive)
}
func UpIsActive(tx *sql.Tx) error {
_, err := tx.Exec("ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS is_deactivated BOOLEAN DEFAULT FALSE;")
if err != nil { if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err) return fmt.Errorf("failed to execute upgrade: %w", err)
} }
return nil return nil
} }
func DownIsActive(tx *sql.Tx) error { func DownIsActive(ctx context.Context, tx *sql.Tx) error {
_, err := tx.Exec("ALTER TABLE account_accounts DROP COLUMN is_deactivated;") _, err := tx.ExecContext(ctx, "ALTER TABLE account_accounts DROP COLUMN is_deactivated;")
if err != nil { if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err) return fmt.Errorf("failed to execute downgrade: %w", err)
} }

View file

@ -1,18 +1,13 @@
package deltas package deltas
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
) )
func LoadLastSeenTSIP(m *sqlutil.Migrations) { func UpLastSeenTSIP(ctx context.Context, tx *sql.Tx) error {
m.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP) _, err := tx.ExecContext(ctx, `
}
func UpLastSeenTSIP(tx *sql.Tx) error {
_, err := tx.Exec(`
ALTER TABLE device_devices ADD COLUMN IF NOT EXISTS last_seen_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM CURRENT_TIMESTAMP)*1000; ALTER TABLE device_devices ADD COLUMN IF NOT EXISTS last_seen_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM CURRENT_TIMESTAMP)*1000;
ALTER TABLE device_devices ADD COLUMN IF NOT EXISTS ip TEXT; ALTER TABLE device_devices ADD COLUMN IF NOT EXISTS ip TEXT;
ALTER TABLE device_devices ADD COLUMN IF NOT EXISTS user_agent TEXT;`) ALTER TABLE device_devices ADD COLUMN IF NOT EXISTS user_agent TEXT;`)
@ -22,8 +17,8 @@ ALTER TABLE device_devices ADD COLUMN IF NOT EXISTS user_agent TEXT;`)
return nil return nil
} }
func DownLastSeenTSIP(tx *sql.Tx) error { func DownLastSeenTSIP(ctx context.Context, tx *sql.Tx) error {
_, err := tx.Exec(` _, err := tx.ExecContext(ctx, `
ALTER TABLE device_devices DROP COLUMN last_seen_ts; ALTER TABLE device_devices DROP COLUMN last_seen_ts;
ALTER TABLE device_devices DROP COLUMN ip; ALTER TABLE device_devices DROP COLUMN ip;
ALTER TABLE device_devices DROP COLUMN user_agent;`) ALTER TABLE device_devices DROP COLUMN user_agent;`)

View file

@ -1,20 +1,15 @@
package deltas package deltas
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
) )
func LoadAddAccountType(m *sqlutil.Migrations) { func UpAddAccountType(ctx context.Context, tx *sql.Tx) error {
m.AddMigration(UpAddAccountType, DownAddAccountType)
}
func UpAddAccountType(tx *sql.Tx) error {
// initially set every account to useraccount, change appservice and guest accounts afterwards // initially set every account to useraccount, change appservice and guest accounts afterwards
// (user = 1, guest = 2, admin = 3, appservice = 4) // (user = 1, guest = 2, admin = 3, appservice = 4)
_, err := tx.Exec(`ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS account_type SMALLINT NOT NULL DEFAULT 1; _, err := tx.ExecContext(ctx, `ALTER TABLE account_accounts ADD COLUMN IF NOT EXISTS account_type SMALLINT NOT NULL DEFAULT 1;
UPDATE account_accounts SET account_type = 4 WHERE appservice_id <> ''; UPDATE account_accounts SET account_type = 4 WHERE appservice_id <> '';
UPDATE account_accounts SET account_type = 2 WHERE localpart ~ '^[0-9]+$'; UPDATE account_accounts SET account_type = 2 WHERE localpart ~ '^[0-9]+$';
ALTER TABLE account_accounts ALTER COLUMN account_type DROP DEFAULT;`, ALTER TABLE account_accounts ALTER COLUMN account_type DROP DEFAULT;`,
@ -25,8 +20,8 @@ ALTER TABLE account_accounts ALTER COLUMN account_type DROP DEFAULT;`,
return nil return nil
} }
func DownAddAccountType(tx *sql.Tx) error { func DownAddAccountType(ctx context.Context, tx *sql.Tx) error {
_, err := tx.Exec("ALTER TABLE account_accounts DROP COLUMN account_type;") _, err := tx.ExecContext(ctx, "ALTER TABLE account_accounts DROP COLUMN account_type;")
if err != nil { if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err) return fmt.Errorf("failed to execute downgrade: %w", err)
} }

View file

@ -24,6 +24,7 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/postgres/deltas"
"github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -120,6 +121,15 @@ func NewPostgresDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName
if err != nil { if err != nil {
return nil, err return nil, err
} }
m := sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "userapi: add last_seen_ts",
Up: deltas.UpLastSeenTSIP,
})
err = m.Up(context.Background())
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{ return s, sqlutil.StatementList{
{&s.insertDeviceStmt, insertDeviceSQL}, {&s.insertDeviceStmt, insertDeviceSQL},
{&s.selectDeviceByTokenStmt, selectDeviceByTokenSQL}, {&s.selectDeviceByTokenStmt, selectDeviceByTokenSQL},

View file

@ -23,7 +23,6 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/storage/postgres/deltas"
"github.com/matrix-org/dendrite/userapi/storage/shared" "github.com/matrix-org/dendrite/userapi/storage/shared"
// Import the postgres database driver. // Import the postgres database driver.
@ -37,19 +36,6 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
return nil, err return nil, err
} }
m := sqlutil.NewMigrations()
if _, err = db.Exec(accountsSchema); err != nil {
// do this so that the migration can and we don't fail on
// preparing statements for columns that don't exist yet
return nil, err
}
deltas.LoadIsActive(m)
//deltas.LoadLastSeenTSIP(m)
deltas.LoadAddAccountType(m)
if err = m.RunDeltas(db, dbProperties); err != nil {
return nil, err
}
accountDataTable, err := NewPostgresAccountDataTable(db) accountDataTable, err := NewPostgresAccountDataTable(db)
if err != nil { if err != nil {
return nil, fmt.Errorf("NewPostgresAccountDataTable: %w", err) return nil, fmt.Errorf("NewPostgresAccountDataTable: %w", err)

View file

@ -24,6 +24,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
"github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/storage/tables"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -87,6 +88,23 @@ func NewSQLiteAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
m := sqlutil.NewMigrator(db)
m.AddMigrations([]sqlutil.Migration{
{
Version: "userapi: add is active",
Up: deltas.UpIsActive,
Down: deltas.DownIsActive,
},
{
Version: "userapi: add account type",
Up: deltas.UpAddAccountType,
Down: deltas.DownAddAccountType,
},
}...)
err = m.Up(context.Background())
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{ return s, sqlutil.StatementList{
{&s.insertAccountStmt, insertAccountSQL}, {&s.insertAccountStmt, insertAccountSQL},
{&s.updatePasswordStmt, updatePasswordSQL}, {&s.updatePasswordStmt, updatePasswordSQL},

View file

@ -1,25 +1,13 @@
package deltas package deltas
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/pressly/goose"
"github.com/matrix-org/dendrite/internal/sqlutil"
) )
func LoadFromGoose() { func UpIsActive(ctx context.Context, tx *sql.Tx) error {
goose.AddMigration(UpIsActive, DownIsActive) _, err := tx.ExecContext(ctx, `
goose.AddMigration(UpAddAccountType, DownAddAccountType)
}
func LoadIsActive(m *sqlutil.Migrations) {
m.AddMigration(UpIsActive, DownIsActive)
}
func UpIsActive(tx *sql.Tx) error {
_, err := tx.Exec(`
ALTER TABLE account_accounts RENAME TO account_accounts_tmp; ALTER TABLE account_accounts RENAME TO account_accounts_tmp;
CREATE TABLE account_accounts ( CREATE TABLE account_accounts (
localpart TEXT NOT NULL PRIMARY KEY, localpart TEXT NOT NULL PRIMARY KEY,
@ -42,8 +30,8 @@ DROP TABLE account_accounts_tmp;`)
return nil return nil
} }
func DownIsActive(tx *sql.Tx) error { func DownIsActive(ctx context.Context, tx *sql.Tx) error {
_, err := tx.Exec(` _, err := tx.ExecContext(ctx, `
ALTER TABLE account_accounts RENAME TO account_accounts_tmp; ALTER TABLE account_accounts RENAME TO account_accounts_tmp;
CREATE TABLE account_accounts ( CREATE TABLE account_accounts (
localpart TEXT NOT NULL PRIMARY KEY, localpart TEXT NOT NULL PRIMARY KEY,

View file

@ -1,18 +1,13 @@
package deltas package deltas
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
) )
func LoadLastSeenTSIP(m *sqlutil.Migrations) { func UpLastSeenTSIP(ctx context.Context, tx *sql.Tx) error {
m.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP) _, err := tx.ExecContext(ctx, `
}
func UpLastSeenTSIP(tx *sql.Tx) error {
_, err := tx.Exec(`
ALTER TABLE device_devices RENAME TO device_devices_tmp; ALTER TABLE device_devices RENAME TO device_devices_tmp;
CREATE TABLE device_devices ( CREATE TABLE device_devices (
access_token TEXT PRIMARY KEY, access_token TEXT PRIMARY KEY,
@ -39,8 +34,8 @@ func UpLastSeenTSIP(tx *sql.Tx) error {
return nil return nil
} }
func DownLastSeenTSIP(tx *sql.Tx) error { func DownLastSeenTSIP(ctx context.Context, tx *sql.Tx) error {
_, err := tx.Exec(` _, err := tx.ExecContext(ctx, `
ALTER TABLE device_devices RENAME TO device_devices_tmp; ALTER TABLE device_devices RENAME TO device_devices_tmp;
CREATE TABLE IF NOT EXISTS device_devices ( CREATE TABLE IF NOT EXISTS device_devices (
access_token TEXT PRIMARY KEY, access_token TEXT PRIMARY KEY,

View file

@ -1,26 +1,15 @@
package deltas package deltas
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/pressly/goose"
"github.com/matrix-org/dendrite/internal/sqlutil"
) )
func init() { func UpAddAccountType(ctx context.Context, tx *sql.Tx) error {
goose.AddMigration(UpAddAccountType, DownAddAccountType)
}
func LoadAddAccountType(m *sqlutil.Migrations) {
m.AddMigration(UpAddAccountType, DownAddAccountType)
}
func UpAddAccountType(tx *sql.Tx) error {
// initially set every account to useraccount, change appservice and guest accounts afterwards // initially set every account to useraccount, change appservice and guest accounts afterwards
// (user = 1, guest = 2, admin = 3, appservice = 4) // (user = 1, guest = 2, admin = 3, appservice = 4)
_, err := tx.Exec(`ALTER TABLE account_accounts RENAME TO account_accounts_tmp; _, err := tx.ExecContext(ctx, `ALTER TABLE account_accounts RENAME TO account_accounts_tmp;
CREATE TABLE account_accounts ( CREATE TABLE account_accounts (
localpart TEXT NOT NULL PRIMARY KEY, localpart TEXT NOT NULL PRIMARY KEY,
created_ts BIGINT NOT NULL, created_ts BIGINT NOT NULL,
@ -45,8 +34,8 @@ DROP TABLE account_accounts_tmp;`)
return nil return nil
} }
func DownAddAccountType(tx *sql.Tx) error { func DownAddAccountType(ctx context.Context, tx *sql.Tx) error {
_, err := tx.Exec(`ALTER TABLE account_accounts DROP COLUMN account_type;`) _, err := tx.ExecContext(ctx, `ALTER TABLE account_accounts DROP COLUMN account_type;`)
if err != nil { if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err) return fmt.Errorf("failed to execute downgrade: %w", err)
} }

View file

@ -23,6 +23,7 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
"github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
@ -107,6 +108,15 @@ func NewSQLiteDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
m := sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "userapi: add last_seen_ts",
Up: deltas.UpLastSeenTSIP,
})
if err = m.Up(context.Background()); err != nil {
return nil, err
}
return s, sqlutil.StatementList{ return s, sqlutil.StatementList{
{&s.insertDeviceStmt, insertDeviceSQL}, {&s.insertDeviceStmt, insertDeviceSQL},
{&s.selectDevicesCountStmt, selectDevicesCountSQL}, {&s.selectDevicesCountStmt, selectDevicesCountSQL},

View file

@ -25,10 +25,6 @@ import (
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/storage/shared" "github.com/matrix-org/dendrite/userapi/storage/shared"
"github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
// Import the postgres database driver.
_ "github.com/lib/pq"
) )
// NewDatabase creates a new accounts and profiles database // NewDatabase creates a new accounts and profiles database
@ -38,19 +34,6 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
return nil, err return nil, err
} }
m := sqlutil.NewMigrations()
if _, err = db.Exec(accountsSchema); err != nil {
// do this so that the migration can and we don't fail on
// preparing statements for columns that don't exist yet
return nil, err
}
deltas.LoadIsActive(m)
//deltas.LoadLastSeenTSIP(m)
deltas.LoadAddAccountType(m)
if err = m.RunDeltas(db, dbProperties); err != nil {
return nil, err
}
accountDataTable, err := NewSQLiteAccountDataTable(db) accountDataTable, err := NewSQLiteAccountDataTable(db)
if err != nil { if err != nil {
return nil, fmt.Errorf("NewSQLiteAccountDataTable: %w", err) return nil, fmt.Errorf("NewSQLiteAccountDataTable: %w", err)