diff --git a/common/basecomponent/base.go b/common/basecomponent/base.go index b05ec42d..50fc2d5c 100644 --- a/common/basecomponent/base.go +++ b/common/basecomponent/base.go @@ -138,7 +138,7 @@ func (b *BaseDendrite) CreateAccountsDB() *accounts.Database { // CreateKeyDB creates a new instance of the key database. Should only be called // once per component. -func (b *BaseDendrite) CreateKeyDB() *keydb.Database { +func (b *BaseDendrite) CreateKeyDB() keydb.Database { db, err := keydb.NewDatabase(string(b.Cfg.Database.ServerKey)) if err != nil { logrus.WithError(err).Panicf("failed to connect to keys db") diff --git a/common/keydb/keydb.go b/common/keydb/keydb.go index 9e8b6a6f..b9fa884e 100644 --- a/common/keydb/keydb.go +++ b/common/keydb/keydb.go @@ -1,4 +1,4 @@ -// Copyright 2017 Vector Creations Ltd +// Copyright 2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -16,67 +16,29 @@ package keydb import ( "context" - "database/sql" + "errors" + "net/url" + "github.com/matrix-org/dendrite/common/keydb/postgres" "github.com/matrix-org/gomatrixserverlib" ) -// A Database implements gomatrixserverlib.KeyDatabase and is used to store -// the public keys for other matrix servers. -type Database struct { - statements serverKeyStatements +type Database interface { + FetcherName() string + FetchKeys(ctx context.Context, requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) + StoreKeys(ctx context.Context, keyMap map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult) error } -// NewDatabase prepares a new key database. -// It creates the necessary tables if they don't already exist. -// It prepares all the SQL statements that it will use. -// Returns an error if there was a problem talking to the database. -func NewDatabase(dataSourceName string) (*Database, error) { - db, err := sql.Open("postgres", dataSourceName) +// NewDatabase opens a database connection. +func NewDatabase(dataSourceName string) (Database, error) { + uri, err := url.Parse(dataSourceName) if err != nil { return nil, err } - d := &Database{} - err = d.statements.prepare(db) - if err != nil { - return nil, err + switch uri.Scheme { + case "postgres": + return postgres.NewDatabase(dataSourceName) + default: + return nil, errors.New("unknown schema") } - return d, nil -} - -// FetcherName implements KeyFetcher -func (d Database) FetcherName() string { - return "KeyDatabase" -} - -// FetchKeys implements gomatrixserverlib.KeyDatabase -func (d *Database) FetchKeys( - ctx context.Context, - requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, -) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) { - return d.statements.bulkSelectServerKeys(ctx, requests) -} - -// StoreKeys implements gomatrixserverlib.KeyDatabase -func (d *Database) StoreKeys( - ctx context.Context, - keyMap map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, -) error { - // TODO: Inserting all the keys within a single transaction may - // be more efficient since the transaction overhead can be quite - // high for a single insert statement. - var lastErr error - for request, keys := range keyMap { - if err := d.statements.upsertServerKeys(ctx, request, keys); err != nil { - // Rather than returning immediately on error we try to insert the - // remaining keys. - // Since we are inserting the keys outside of a transaction it is - // possible for some of the inserts to succeed even though some - // of the inserts have failed. - // Ensuring that we always insert all the keys we can means that - // this behaviour won't depend on the iteration order of the map. - lastErr = err - } - } - return lastErr } diff --git a/common/keydb/postgres/keydb.go b/common/keydb/postgres/keydb.go new file mode 100644 index 00000000..bf0ff69c --- /dev/null +++ b/common/keydb/postgres/keydb.go @@ -0,0 +1,83 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + + "github.com/matrix-org/gomatrixserverlib" +) + +// A Database implements gomatrixserverlib.KeyDatabase and is used to store +// the public keys for other matrix servers. +type Database struct { + statements serverKeyStatements +} + +// NewDatabase prepares a new key database. +// It creates the necessary tables if they don't already exist. +// It prepares all the SQL statements that it will use. +// Returns an error if there was a problem talking to the database. +func NewDatabase(dataSourceName string) (*Database, error) { + db, err := sql.Open("postgres", dataSourceName) + if err != nil { + return nil, err + } + d := &Database{} + err = d.statements.prepare(db) + if err != nil { + return nil, err + } + return d, nil +} + +// FetcherName implements KeyFetcher +func (d Database) FetcherName() string { + return "KeyDatabase" +} + +// FetchKeys implements gomatrixserverlib.KeyDatabase +func (d *Database) FetchKeys( + ctx context.Context, + requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, +) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) { + return d.statements.bulkSelectServerKeys(ctx, requests) +} + +// StoreKeys implements gomatrixserverlib.KeyDatabase +func (d *Database) StoreKeys( + ctx context.Context, + keyMap map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, +) error { + // TODO: Inserting all the keys within a single transaction may + // be more efficient since the transaction overhead can be quite + // high for a single insert statement. + var lastErr error + for request, keys := range keyMap { + if err := d.statements.upsertServerKeys(ctx, request, keys); err != nil { + // Rather than returning immediately on error we try to insert the + // remaining keys. + // Since we are inserting the keys outside of a transaction it is + // possible for some of the inserts to succeed even though some + // of the inserts have failed. + // Ensuring that we always insert all the keys we can means that + // this behaviour won't depend on the iteration order of the map. + lastErr = err + } + } + return lastErr +} diff --git a/common/keydb/server_key_table.go b/common/keydb/postgres/server_key_table.go similarity index 97% rename from common/keydb/server_key_table.go rename to common/keydb/postgres/server_key_table.go index c922fa98..8fb9a0ee 100644 --- a/common/keydb/server_key_table.go +++ b/common/keydb/postgres/server_key_table.go @@ -1,4 +1,5 @@ -// Copyright 2017 Vector Creations Ltd +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,7 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package keydb +package postgres import ( "context" diff --git a/federationsender/consumers/roomserver.go b/federationsender/consumers/roomserver.go index 3ba978b1..4568f44d 100644 --- a/federationsender/consumers/roomserver.go +++ b/federationsender/consumers/roomserver.go @@ -33,7 +33,7 @@ import ( // OutputRoomEventConsumer consumes events that originated in the room server. type OutputRoomEventConsumer struct { roomServerConsumer *common.ContinualConsumer - db *storage.Database + db storage.Database queues *queue.OutgoingQueues query api.RoomserverQueryAPI } @@ -43,7 +43,7 @@ func NewOutputRoomEventConsumer( cfg *config.Dendrite, kafkaConsumer sarama.Consumer, queues *queue.OutgoingQueues, - store *storage.Database, + store storage.Database, queryAPI api.RoomserverQueryAPI, ) *OutputRoomEventConsumer { consumer := common.ContinualConsumer{ diff --git a/federationsender/consumers/typingserver.go b/federationsender/consumers/typingserver.go index c4cd0e59..590fcb25 100644 --- a/federationsender/consumers/typingserver.go +++ b/federationsender/consumers/typingserver.go @@ -29,7 +29,7 @@ import ( // OutputTypingEventConsumer consumes events that originate in typing server. type OutputTypingEventConsumer struct { consumer *common.ContinualConsumer - db *storage.Database + db storage.Database queues *queue.OutgoingQueues ServerName gomatrixserverlib.ServerName } @@ -39,7 +39,7 @@ func NewOutputTypingEventConsumer( cfg *config.Dendrite, kafkaConsumer sarama.Consumer, queues *queue.OutgoingQueues, - store *storage.Database, + store storage.Database, ) *OutputTypingEventConsumer { consumer := common.ContinualConsumer{ Topic: string(cfg.Kafka.Topics.OutputTypingEvent), diff --git a/federationsender/storage/joined_hosts_table.go b/federationsender/storage/postgres/joined_hosts_table.go similarity index 97% rename from federationsender/storage/joined_hosts_table.go rename to federationsender/storage/postgres/joined_hosts_table.go index 5d652a1a..bd580e3b 100644 --- a/federationsender/storage/joined_hosts_table.go +++ b/federationsender/storage/postgres/joined_hosts_table.go @@ -1,4 +1,5 @@ -// Copyright 2017 Vector Creations Ltd +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,7 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package storage +package postgres import ( "context" diff --git a/federationsender/storage/room_table.go b/federationsender/storage/postgres/room_table.go similarity index 96% rename from federationsender/storage/room_table.go rename to federationsender/storage/postgres/room_table.go index bb52b707..a64424b4 100644 --- a/federationsender/storage/room_table.go +++ b/federationsender/storage/postgres/room_table.go @@ -1,4 +1,5 @@ -// Copyright 2017 Vector Creations Ltd +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,7 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package storage +package postgres import ( "context" diff --git a/federationsender/storage/postgres/storage.go b/federationsender/storage/postgres/storage.go new file mode 100644 index 00000000..c60f6dc5 --- /dev/null +++ b/federationsender/storage/postgres/storage.go @@ -0,0 +1,122 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/federationsender/types" +) + +// Database stores information needed by the federation sender +type Database struct { + joinedHostsStatements + roomStatements + common.PartitionOffsetStatements + db *sql.DB +} + +// NewDatabase opens a new database +func NewDatabase(dataSourceName string) (*Database, error) { + var result Database + var err error + if result.db, err = sql.Open("postgres", dataSourceName); err != nil { + return nil, err + } + if err = result.prepare(); err != nil { + return nil, err + } + return &result, nil +} + +func (d *Database) prepare() error { + var err error + + if err = d.joinedHostsStatements.prepare(d.db); err != nil { + return err + } + + if err = d.roomStatements.prepare(d.db); err != nil { + return err + } + + return d.PartitionOffsetStatements.Prepare(d.db, "federationsender") +} + +// UpdateRoom updates the joined hosts for a room and returns what the joined +// hosts were before the update, or nil if this was a duplicate message. +// This is called when we receive a message from kafka, so we pass in +// oldEventID and newEventID to check that we haven't missed any messages or +// this isn't a duplicate message. +func (d *Database) UpdateRoom( + ctx context.Context, + roomID, oldEventID, newEventID string, + addHosts []types.JoinedHost, + removeHosts []string, +) (joinedHosts []types.JoinedHost, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + err = d.insertRoom(ctx, txn, roomID) + if err != nil { + return err + } + + lastSentEventID, err := d.selectRoomForUpdate(ctx, txn, roomID) + if err != nil { + return err + } + + if lastSentEventID == newEventID { + // We've handled this message before, so let's just ignore it. + // We can only get a duplicate for the last message we processed, + // so its enough just to compare the newEventID with lastSentEventID + return nil + } + + if lastSentEventID != oldEventID { + return types.EventIDMismatchError{ + DatabaseID: lastSentEventID, RoomServerID: oldEventID, + } + } + + joinedHosts, err = d.selectJoinedHostsWithTx(ctx, txn, roomID) + if err != nil { + return err + } + + for _, add := range addHosts { + err = d.insertJoinedHosts(ctx, txn, roomID, add.MemberEventID, add.ServerName) + if err != nil { + return err + } + } + if err = d.deleteJoinedHosts(ctx, txn, removeHosts); err != nil { + return err + } + return d.updateRoom(ctx, txn, roomID, newEventID) + }) + return +} + +// GetJoinedHosts returns the currently joined hosts for room, +// as known to federationserver. +// Returns an error if something goes wrong. +func (d *Database) GetJoinedHosts( + ctx context.Context, roomID string, +) ([]types.JoinedHost, error) { + return d.selectJoinedHosts(ctx, roomID) +} diff --git a/federationsender/storage/storage.go b/federationsender/storage/storage.go index 3a0f8775..8cffdbf1 100644 --- a/federationsender/storage/storage.go +++ b/federationsender/storage/storage.go @@ -1,4 +1,4 @@ -// Copyright 2017 Vector Creations Ltd +// Copyright 2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -16,106 +16,30 @@ package storage import ( "context" - "database/sql" + "errors" + "net/url" "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/federationsender/storage/postgres" "github.com/matrix-org/dendrite/federationsender/types" ) -// Database stores information needed by the federation sender -type Database struct { - joinedHostsStatements - roomStatements - common.PartitionOffsetStatements - db *sql.DB +type Database interface { + common.PartitionStorer + UpdateRoom(ctx context.Context, roomID, oldEventID, newEventID string, addHosts []types.JoinedHost, removeHosts []string) (joinedHosts []types.JoinedHost, err error) + GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error) } // NewDatabase opens a new database -func NewDatabase(dataSourceName string) (*Database, error) { - var result Database - var err error - if result.db, err = sql.Open("postgres", dataSourceName); err != nil { +func NewDatabase(dataSourceName string) (Database, error) { + uri, err := url.Parse(dataSourceName) + if err != nil { return nil, err } - if err = result.prepare(); err != nil { - return nil, err + switch uri.Scheme { + case "postgres": + return postgres.NewDatabase(dataSourceName) + default: + return nil, errors.New("unknown schema") } - return &result, nil -} - -func (d *Database) prepare() error { - var err error - - if err = d.joinedHostsStatements.prepare(d.db); err != nil { - return err - } - - if err = d.roomStatements.prepare(d.db); err != nil { - return err - } - - return d.PartitionOffsetStatements.Prepare(d.db, "federationsender") -} - -// UpdateRoom updates the joined hosts for a room and returns what the joined -// hosts were before the update, or nil if this was a duplicate message. -// This is called when we receive a message from kafka, so we pass in -// oldEventID and newEventID to check that we haven't missed any messages or -// this isn't a duplicate message. -func (d *Database) UpdateRoom( - ctx context.Context, - roomID, oldEventID, newEventID string, - addHosts []types.JoinedHost, - removeHosts []string, -) (joinedHosts []types.JoinedHost, err error) { - err = common.WithTransaction(d.db, func(txn *sql.Tx) error { - err = d.insertRoom(ctx, txn, roomID) - if err != nil { - return err - } - - lastSentEventID, err := d.selectRoomForUpdate(ctx, txn, roomID) - if err != nil { - return err - } - - if lastSentEventID == newEventID { - // We've handled this message before, so let's just ignore it. - // We can only get a duplicate for the last message we processed, - // so its enough just to compare the newEventID with lastSentEventID - return nil - } - - if lastSentEventID != oldEventID { - return types.EventIDMismatchError{ - DatabaseID: lastSentEventID, RoomServerID: oldEventID, - } - } - - joinedHosts, err = d.selectJoinedHostsWithTx(ctx, txn, roomID) - if err != nil { - return err - } - - for _, add := range addHosts { - err = d.insertJoinedHosts(ctx, txn, roomID, add.MemberEventID, add.ServerName) - if err != nil { - return err - } - } - if err = d.deleteJoinedHosts(ctx, txn, removeHosts); err != nil { - return err - } - return d.updateRoom(ctx, txn, roomID, newEventID) - }) - return -} - -// GetJoinedHosts returns the currently joined hosts for room, -// as known to federationserver. -// Returns an error if something goes wrong. -func (d *Database) GetJoinedHosts( - ctx context.Context, roomID string, -) ([]types.JoinedHost, error) { - return d.selectJoinedHosts(ctx, roomID) } diff --git a/mediaapi/routing/download.go b/mediaapi/routing/download.go index 80ad8418..8544bd64 100644 --- a/mediaapi/routing/download.go +++ b/mediaapi/routing/download.go @@ -67,7 +67,7 @@ func Download( origin gomatrixserverlib.ServerName, mediaID types.MediaID, cfg *config.Dendrite, - db *storage.Database, + db storage.Database, client *gomatrixserverlib.Client, activeRemoteRequests *types.ActiveRemoteRequests, activeThumbnailGeneration *types.ActiveThumbnailGeneration, @@ -192,7 +192,7 @@ func (r *downloadRequest) doDownload( ctx context.Context, w http.ResponseWriter, cfg *config.Dendrite, - db *storage.Database, + db storage.Database, client *gomatrixserverlib.Client, activeRemoteRequests *types.ActiveRemoteRequests, activeThumbnailGeneration *types.ActiveThumbnailGeneration, @@ -235,7 +235,7 @@ func (r *downloadRequest) respondFromLocalFile( absBasePath config.Path, activeThumbnailGeneration *types.ActiveThumbnailGeneration, maxThumbnailGenerators int, - db *storage.Database, + db storage.Database, dynamicThumbnails bool, thumbnailSizes []config.ThumbnailSize, ) (*types.MediaMetadata, error) { @@ -325,7 +325,7 @@ func (r *downloadRequest) getThumbnailFile( filePath types.Path, activeThumbnailGeneration *types.ActiveThumbnailGeneration, maxThumbnailGenerators int, - db *storage.Database, + db storage.Database, dynamicThumbnails bool, thumbnailSizes []config.ThumbnailSize, ) (*os.File, *types.ThumbnailMetadata, error) { @@ -407,7 +407,7 @@ func (r *downloadRequest) generateThumbnail( thumbnailSize types.ThumbnailSize, activeThumbnailGeneration *types.ActiveThumbnailGeneration, maxThumbnailGenerators int, - db *storage.Database, + db storage.Database, ) (*types.ThumbnailMetadata, error) { r.Logger.WithFields(log.Fields{ "Width": thumbnailSize.Width, @@ -443,7 +443,7 @@ func (r *downloadRequest) getRemoteFile( ctx context.Context, client *gomatrixserverlib.Client, cfg *config.Dendrite, - db *storage.Database, + db storage.Database, activeRemoteRequests *types.ActiveRemoteRequests, activeThumbnailGeneration *types.ActiveThumbnailGeneration, ) (errorResponse error) { @@ -545,7 +545,7 @@ func (r *downloadRequest) fetchRemoteFileAndStoreMetadata( client *gomatrixserverlib.Client, absBasePath config.Path, maxFileSizeBytes config.FileSizeBytes, - db *storage.Database, + db storage.Database, thumbnailSizes []config.ThumbnailSize, activeThumbnailGeneration *types.ActiveThumbnailGeneration, maxThumbnailGenerators int, diff --git a/mediaapi/routing/routing.go b/mediaapi/routing/routing.go index 0f226664..dcc6ac06 100644 --- a/mediaapi/routing/routing.go +++ b/mediaapi/routing/routing.go @@ -43,7 +43,7 @@ const pathPrefixR0 = "/_matrix/media/r0" func Setup( apiMux *mux.Router, cfg *config.Dendrite, - db *storage.Database, + db storage.Database, deviceDB *devices.Database, client *gomatrixserverlib.Client, ) { @@ -80,7 +80,7 @@ func Setup( func makeDownloadAPI( name string, cfg *config.Dendrite, - db *storage.Database, + db storage.Database, client *gomatrixserverlib.Client, activeRemoteRequests *types.ActiveRemoteRequests, activeThumbnailGeneration *types.ActiveThumbnailGeneration, diff --git a/mediaapi/routing/upload.go b/mediaapi/routing/upload.go index 2cb0d875..91a45319 100644 --- a/mediaapi/routing/upload.go +++ b/mediaapi/routing/upload.go @@ -53,7 +53,7 @@ type uploadResponse struct { // This implementation supports a configurable maximum file size limit in bytes. If a user tries to upload more than this, they will receive an error that their upload is too large. // Uploaded files are processed piece-wise to avoid DoS attacks which would starve the server of memory. // TODO: We should time out requests if they have not received any data within a configured timeout period. -func Upload(req *http.Request, cfg *config.Dendrite, db *storage.Database, activeThumbnailGeneration *types.ActiveThumbnailGeneration) util.JSONResponse { +func Upload(req *http.Request, cfg *config.Dendrite, db storage.Database, activeThumbnailGeneration *types.ActiveThumbnailGeneration) util.JSONResponse { r, resErr := parseAndValidateRequest(req, cfg) if resErr != nil { return *resErr @@ -96,7 +96,7 @@ func (r *uploadRequest) doUpload( ctx context.Context, reqReader io.Reader, cfg *config.Dendrite, - db *storage.Database, + db storage.Database, activeThumbnailGeneration *types.ActiveThumbnailGeneration, ) *util.JSONResponse { r.Logger.WithFields(log.Fields{ @@ -214,7 +214,7 @@ func (r *uploadRequest) storeFileAndMetadata( ctx context.Context, tmpDir types.Path, absBasePath config.Path, - db *storage.Database, + db storage.Database, thumbnailSizes []config.ThumbnailSize, activeThumbnailGeneration *types.ActiveThumbnailGeneration, maxThumbnailGenerators int, diff --git a/mediaapi/storage/media_repository_table.go b/mediaapi/storage/postgres/media_repository_table.go similarity index 97% rename from mediaapi/storage/media_repository_table.go rename to mediaapi/storage/postgres/media_repository_table.go index addd47b4..e975530a 100644 --- a/mediaapi/storage/media_repository_table.go +++ b/mediaapi/storage/postgres/media_repository_table.go @@ -1,4 +1,5 @@ -// Copyright 2017 Vector Creations Ltd +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,7 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package storage +package postgres import ( "context" diff --git a/mediaapi/storage/prepare.go b/mediaapi/storage/postgres/prepare.go similarity index 90% rename from mediaapi/storage/prepare.go rename to mediaapi/storage/postgres/prepare.go index a30586de..090c3d17 100644 --- a/mediaapi/storage/prepare.go +++ b/mediaapi/storage/postgres/prepare.go @@ -1,4 +1,5 @@ -// Copyright 2017 Vector Creations Ltd +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,7 +15,7 @@ // FIXME: This should be made common! -package storage +package postgres import ( "database/sql" diff --git a/mediaapi/storage/sql.go b/mediaapi/storage/postgres/sql.go similarity index 88% rename from mediaapi/storage/sql.go rename to mediaapi/storage/postgres/sql.go index 1f8c7be3..181cd15f 100644 --- a/mediaapi/storage/sql.go +++ b/mediaapi/storage/postgres/sql.go @@ -1,4 +1,5 @@ -// Copyright 2017 Vector Creations Ltd +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,7 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package storage +package postgres import ( "database/sql" diff --git a/mediaapi/storage/postgres/storage.go b/mediaapi/storage/postgres/storage.go new file mode 100644 index 00000000..6259f4a1 --- /dev/null +++ b/mediaapi/storage/postgres/storage.go @@ -0,0 +1,106 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + + // Import the postgres database driver. + _ "github.com/lib/pq" + "github.com/matrix-org/dendrite/mediaapi/types" + "github.com/matrix-org/gomatrixserverlib" +) + +// Database is used to store metadata about a repository of media files. +type Database struct { + statements statements + db *sql.DB +} + +// Open opens a postgres database. +func Open(dataSourceName string) (*Database, error) { + var d Database + var err error + if d.db, err = sql.Open("postgres", dataSourceName); err != nil { + return nil, err + } + if err = d.statements.prepare(d.db); err != nil { + return nil, err + } + return &d, nil +} + +// StoreMediaMetadata inserts the metadata about the uploaded media into the database. +// Returns an error if the combination of MediaID and Origin are not unique in the table. +func (d *Database) StoreMediaMetadata( + ctx context.Context, mediaMetadata *types.MediaMetadata, +) error { + return d.statements.media.insertMedia(ctx, mediaMetadata) +} + +// GetMediaMetadata returns metadata about media stored on this server. +// The media could have been uploaded to this server or fetched from another server and cached here. +// Returns nil metadata if there is no metadata associated with this media. +func (d *Database) GetMediaMetadata( + ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, +) (*types.MediaMetadata, error) { + mediaMetadata, err := d.statements.media.selectMedia(ctx, mediaID, mediaOrigin) + if err != nil && err == sql.ErrNoRows { + return nil, nil + } + return mediaMetadata, err +} + +// StoreThumbnail inserts the metadata about the thumbnail into the database. +// Returns an error if the combination of MediaID and Origin are not unique in the table. +func (d *Database) StoreThumbnail( + ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata, +) error { + return d.statements.thumbnail.insertThumbnail(ctx, thumbnailMetadata) +} + +// GetThumbnail returns metadata about a specific thumbnail. +// The media could have been uploaded to this server or fetched from another server and cached here. +// Returns nil metadata if there is no metadata associated with this thumbnail. +func (d *Database) GetThumbnail( + ctx context.Context, + mediaID types.MediaID, + mediaOrigin gomatrixserverlib.ServerName, + width, height int, + resizeMethod string, +) (*types.ThumbnailMetadata, error) { + thumbnailMetadata, err := d.statements.thumbnail.selectThumbnail( + ctx, mediaID, mediaOrigin, width, height, resizeMethod, + ) + if err != nil && err == sql.ErrNoRows { + return nil, nil + } + return thumbnailMetadata, err +} + +// GetThumbnails returns metadata about all thumbnails for a specific media stored on this server. +// The media could have been uploaded to this server or fetched from another server and cached here. +// Returns nil metadata if there are no thumbnails associated with this media. +func (d *Database) GetThumbnails( + ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, +) ([]*types.ThumbnailMetadata, error) { + thumbnails, err := d.statements.thumbnail.selectThumbnails(ctx, mediaID, mediaOrigin) + if err != nil && err == sql.ErrNoRows { + return nil, nil + } + return thumbnails, err +} diff --git a/mediaapi/storage/thumbnail_table.go b/mediaapi/storage/postgres/thumbnail_table.go similarity index 98% rename from mediaapi/storage/thumbnail_table.go rename to mediaapi/storage/postgres/thumbnail_table.go index f100485f..167e3795 100644 --- a/mediaapi/storage/thumbnail_table.go +++ b/mediaapi/storage/postgres/thumbnail_table.go @@ -1,4 +1,5 @@ -// Copyright 2017 Vector Creations Ltd +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,7 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package storage +package postgres import ( "context" diff --git a/mediaapi/storage/storage.go b/mediaapi/storage/storage.go index bef134a9..0f39c1d0 100644 --- a/mediaapi/storage/storage.go +++ b/mediaapi/storage/storage.go @@ -1,4 +1,4 @@ -// Copyright 2017 Vector Creations Ltd +// Copyright 2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -16,90 +16,32 @@ package storage import ( "context" - "database/sql" + "errors" + "net/url" - // Import the postgres database driver. - _ "github.com/lib/pq" + "github.com/matrix-org/dendrite/mediaapi/storage/postgres" "github.com/matrix-org/dendrite/mediaapi/types" "github.com/matrix-org/gomatrixserverlib" ) -// Database is used to store metadata about a repository of media files. -type Database struct { - statements statements - db *sql.DB +type Database interface { + StoreMediaMetadata(ctx context.Context, mediaMetadata *types.MediaMetadata) error + GetMediaMetadata(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) (*types.MediaMetadata, error) + StoreThumbnail(ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata) error + GetThumbnail(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, width, height int, resizeMethod string) (*types.ThumbnailMetadata, error) + GetThumbnails(ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName) ([]*types.ThumbnailMetadata, error) } // Open opens a postgres database. -func Open(dataSourceName string) (*Database, error) { - var d Database - var err error - if d.db, err = sql.Open("postgres", dataSourceName); err != nil { +func Open(dataSourceName string) (Database, error) { + uri, err := url.Parse(dataSourceName) + if err != nil { return nil, err } - if err = d.statements.prepare(d.db); err != nil { - return nil, err + switch uri.Scheme { + case "postgres": + return postgres.Open(dataSourceName) + default: + return nil, errors.New("unknown schema") } - return &d, nil -} - -// StoreMediaMetadata inserts the metadata about the uploaded media into the database. -// Returns an error if the combination of MediaID and Origin are not unique in the table. -func (d *Database) StoreMediaMetadata( - ctx context.Context, mediaMetadata *types.MediaMetadata, -) error { - return d.statements.media.insertMedia(ctx, mediaMetadata) -} - -// GetMediaMetadata returns metadata about media stored on this server. -// The media could have been uploaded to this server or fetched from another server and cached here. -// Returns nil metadata if there is no metadata associated with this media. -func (d *Database) GetMediaMetadata( - ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, -) (*types.MediaMetadata, error) { - mediaMetadata, err := d.statements.media.selectMedia(ctx, mediaID, mediaOrigin) - if err != nil && err == sql.ErrNoRows { - return nil, nil - } - return mediaMetadata, err -} - -// StoreThumbnail inserts the metadata about the thumbnail into the database. -// Returns an error if the combination of MediaID and Origin are not unique in the table. -func (d *Database) StoreThumbnail( - ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata, -) error { - return d.statements.thumbnail.insertThumbnail(ctx, thumbnailMetadata) -} - -// GetThumbnail returns metadata about a specific thumbnail. -// The media could have been uploaded to this server or fetched from another server and cached here. -// Returns nil metadata if there is no metadata associated with this thumbnail. -func (d *Database) GetThumbnail( - ctx context.Context, - mediaID types.MediaID, - mediaOrigin gomatrixserverlib.ServerName, - width, height int, - resizeMethod string, -) (*types.ThumbnailMetadata, error) { - thumbnailMetadata, err := d.statements.thumbnail.selectThumbnail( - ctx, mediaID, mediaOrigin, width, height, resizeMethod, - ) - if err != nil && err == sql.ErrNoRows { - return nil, nil - } - return thumbnailMetadata, err -} - -// GetThumbnails returns metadata about all thumbnails for a specific media stored on this server. -// The media could have been uploaded to this server or fetched from another server and cached here. -// Returns nil metadata if there are no thumbnails associated with this media. -func (d *Database) GetThumbnails( - ctx context.Context, mediaID types.MediaID, mediaOrigin gomatrixserverlib.ServerName, -) ([]*types.ThumbnailMetadata, error) { - thumbnails, err := d.statements.thumbnail.selectThumbnails(ctx, mediaID, mediaOrigin) - if err != nil && err == sql.ErrNoRows { - return nil, nil - } - return thumbnails, err } diff --git a/mediaapi/thumbnailer/thumbnailer.go b/mediaapi/thumbnailer/thumbnailer.go index 61b66ebc..ebf5138c 100644 --- a/mediaapi/thumbnailer/thumbnailer.go +++ b/mediaapi/thumbnailer/thumbnailer.go @@ -136,7 +136,7 @@ func isThumbnailExists( dst types.Path, config types.ThumbnailSize, mediaMetadata *types.MediaMetadata, - db *storage.Database, + db storage.Database, logger *log.Entry, ) (bool, error) { thumbnailMetadata, err := db.GetThumbnail( diff --git a/mediaapi/thumbnailer/thumbnailer_nfnt.go b/mediaapi/thumbnailer/thumbnailer_nfnt.go index 5df6ce4b..4f1e98aa 100644 --- a/mediaapi/thumbnailer/thumbnailer_nfnt.go +++ b/mediaapi/thumbnailer/thumbnailer_nfnt.go @@ -45,7 +45,7 @@ func GenerateThumbnails( mediaMetadata *types.MediaMetadata, activeThumbnailGeneration *types.ActiveThumbnailGeneration, maxThumbnailGenerators int, - db *storage.Database, + db storage.Database, logger *log.Entry, ) (busy bool, errorReturn error) { img, err := readFile(string(src)) @@ -78,7 +78,7 @@ func GenerateThumbnail( mediaMetadata *types.MediaMetadata, activeThumbnailGeneration *types.ActiveThumbnailGeneration, maxThumbnailGenerators int, - db *storage.Database, + db storage.Database, logger *log.Entry, ) (busy bool, errorReturn error) { img, err := readFile(string(src)) @@ -142,7 +142,7 @@ func createThumbnail( mediaMetadata *types.MediaMetadata, activeThumbnailGeneration *types.ActiveThumbnailGeneration, maxThumbnailGenerators int, - db *storage.Database, + db storage.Database, logger *log.Entry, ) (busy bool, errorReturn error) { logger = logger.WithFields(log.Fields{ diff --git a/publicroomsapi/consumers/roomserver.go b/publicroomsapi/consumers/roomserver.go index b7d42b11..9a817735 100644 --- a/publicroomsapi/consumers/roomserver.go +++ b/publicroomsapi/consumers/roomserver.go @@ -29,7 +29,7 @@ import ( // OutputRoomEventConsumer consumes events that originated in the room server. type OutputRoomEventConsumer struct { roomServerConsumer *common.ContinualConsumer - db *storage.PublicRoomsServerDatabase + db storage.Database query api.RoomserverQueryAPI } @@ -37,7 +37,7 @@ type OutputRoomEventConsumer struct { func NewOutputRoomEventConsumer( cfg *config.Dendrite, kafkaConsumer sarama.Consumer, - store *storage.PublicRoomsServerDatabase, + store storage.Database, queryAPI api.RoomserverQueryAPI, ) *OutputRoomEventConsumer { consumer := common.ContinualConsumer{ diff --git a/publicroomsapi/directory/directory.go b/publicroomsapi/directory/directory.go index 626a1c15..88981549 100644 --- a/publicroomsapi/directory/directory.go +++ b/publicroomsapi/directory/directory.go @@ -30,7 +30,7 @@ type roomVisibility struct { // GetVisibility implements GET /directory/list/room/{roomID} func GetVisibility( - req *http.Request, publicRoomsDatabase *storage.PublicRoomsServerDatabase, + req *http.Request, publicRoomsDatabase storage.Database, roomID string, ) util.JSONResponse { isPublic, err := publicRoomsDatabase.GetRoomVisibility(req.Context(), roomID) @@ -54,7 +54,7 @@ func GetVisibility( // SetVisibility implements PUT /directory/list/room/{roomID} // TODO: Check if user has the power level to edit the room visibility func SetVisibility( - req *http.Request, publicRoomsDatabase *storage.PublicRoomsServerDatabase, + req *http.Request, publicRoomsDatabase storage.Database, roomID string, ) util.JSONResponse { var v roomVisibility diff --git a/publicroomsapi/directory/public_rooms.go b/publicroomsapi/directory/public_rooms.go index ef7b2662..10aaa070 100644 --- a/publicroomsapi/directory/public_rooms.go +++ b/publicroomsapi/directory/public_rooms.go @@ -44,7 +44,7 @@ type publicRoomRes struct { // GetPostPublicRooms implements GET and POST /publicRooms func GetPostPublicRooms( - req *http.Request, publicRoomDatabase *storage.PublicRoomsServerDatabase, + req *http.Request, publicRoomDatabase storage.Database, ) util.JSONResponse { var limit int16 var offset int64 diff --git a/publicroomsapi/routing/routing.go b/publicroomsapi/routing/routing.go index 422414bc..3d2d2ac0 100644 --- a/publicroomsapi/routing/routing.go +++ b/publicroomsapi/routing/routing.go @@ -34,7 +34,7 @@ const pathPrefixR0 = "/_matrix/client/r0" // Due to Setup being used to call many other functions, a gocyclo nolint is // applied: // nolint: gocyclo -func Setup(apiMux *mux.Router, deviceDB *devices.Database, publicRoomsDB *storage.PublicRoomsServerDatabase) { +func Setup(apiMux *mux.Router, deviceDB *devices.Database, publicRoomsDB storage.Database) { r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter() authData := auth.Data{ diff --git a/publicroomsapi/storage/prepare.go b/publicroomsapi/storage/postgres/prepare.go similarity index 90% rename from publicroomsapi/storage/prepare.go rename to publicroomsapi/storage/postgres/prepare.go index b1976599..70b6e516 100644 --- a/publicroomsapi/storage/prepare.go +++ b/publicroomsapi/storage/postgres/prepare.go @@ -1,4 +1,5 @@ -// Copyright 2017 Vector Creations Ltd +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,7 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package storage +package postgres import ( "database/sql" diff --git a/publicroomsapi/storage/public_rooms_table.go b/publicroomsapi/storage/postgres/public_rooms_table.go similarity index 98% rename from publicroomsapi/storage/public_rooms_table.go rename to publicroomsapi/storage/postgres/public_rooms_table.go index 5e1eb3e1..852afe77 100644 --- a/publicroomsapi/storage/public_rooms_table.go +++ b/publicroomsapi/storage/postgres/public_rooms_table.go @@ -1,4 +1,5 @@ -// Copyright 2017 Vector Creations Ltd +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,7 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package storage +package postgres import ( "context" diff --git a/publicroomsapi/storage/postgres/storage.go b/publicroomsapi/storage/postgres/storage.go new file mode 100644 index 00000000..67b5efc3 --- /dev/null +++ b/publicroomsapi/storage/postgres/storage.go @@ -0,0 +1,253 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "encoding/json" + + "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/publicroomsapi/types" + + "github.com/matrix-org/gomatrixserverlib" +) + +// PublicRoomsServerDatabase represents a public rooms server database. +type PublicRoomsServerDatabase struct { + db *sql.DB + common.PartitionOffsetStatements + statements publicRoomsStatements +} + +type attributeValue interface{} + +// NewPublicRoomsServerDatabase creates a new public rooms server database. +func NewPublicRoomsServerDatabase(dataSourceName string) (*PublicRoomsServerDatabase, error) { + var db *sql.DB + var err error + if db, err = sql.Open("postgres", dataSourceName); err != nil { + return nil, err + } + partitions := common.PartitionOffsetStatements{} + if err = partitions.Prepare(db, "publicroomsapi"); err != nil { + return nil, err + } + statements := publicRoomsStatements{} + if err = statements.prepare(db); err != nil { + return nil, err + } + return &PublicRoomsServerDatabase{db, partitions, statements}, nil +} + +// GetRoomVisibility returns the room visibility as a boolean: true if the room +// is publicly visible, false if not. +// Returns an error if the retrieval failed. +func (d *PublicRoomsServerDatabase) GetRoomVisibility( + ctx context.Context, roomID string, +) (bool, error) { + return d.statements.selectRoomVisibility(ctx, roomID) +} + +// SetRoomVisibility updates the visibility attribute of a room. This attribute +// must be set to true if the room is publicly visible, false if not. +// Returns an error if the update failed. +func (d *PublicRoomsServerDatabase) SetRoomVisibility( + ctx context.Context, visible bool, roomID string, +) error { + return d.statements.updateRoomAttribute(ctx, "visibility", visible, roomID) +} + +// CountPublicRooms returns the number of room set as publicly visible on the server. +// Returns an error if the retrieval failed. +func (d *PublicRoomsServerDatabase) CountPublicRooms(ctx context.Context) (int64, error) { + return d.statements.countPublicRooms(ctx) +} + +// GetPublicRooms returns an array containing the local rooms set as publicly visible, ordered by their number +// of joined members. This array can be limited by a given number of elements, and offset by a given value. +// If the limit is 0, doesn't limit the number of results. If the offset is 0 too, the array contains all +// the rooms set as publicly visible on the server. +// Returns an error if the retrieval failed. +func (d *PublicRoomsServerDatabase) GetPublicRooms( + ctx context.Context, offset int64, limit int16, filter string, +) ([]types.PublicRoom, error) { + return d.statements.selectPublicRooms(ctx, offset, limit, filter) +} + +// UpdateRoomFromEvents iterate over a slice of state events and call +// UpdateRoomFromEvent on each of them to update the database representation of +// the rooms updated by each event. +// The slice of events to remove is used to update the number of joined members +// for the room in the database. +// If the update triggered by one of the events failed, aborts the process and +// returns an error. +func (d *PublicRoomsServerDatabase) UpdateRoomFromEvents( + ctx context.Context, + eventsToAdd []gomatrixserverlib.Event, + eventsToRemove []gomatrixserverlib.Event, +) error { + for _, event := range eventsToAdd { + if err := d.UpdateRoomFromEvent(ctx, event); err != nil { + return err + } + } + + for _, event := range eventsToRemove { + if event.Type() == "m.room.member" { + if err := d.updateNumJoinedUsers(ctx, event, true); err != nil { + return err + } + } + } + + return nil +} + +// UpdateRoomFromEvent updates the database representation of a room from a Matrix event, by +// checking the event's type to know which attribute to change and using the event's content +// to define the new value of the attribute. +// If the event doesn't match with any property used to compute the public room directory, +// does nothing. +// If something went wrong during the process, returns an error. +func (d *PublicRoomsServerDatabase) UpdateRoomFromEvent( + ctx context.Context, event gomatrixserverlib.Event, +) error { + // Process the event according to its type + switch event.Type() { + case "m.room.create": + return d.statements.insertNewRoom(ctx, event.RoomID()) + case "m.room.member": + return d.updateNumJoinedUsers(ctx, event, false) + case "m.room.aliases": + return d.updateRoomAliases(ctx, event) + case "m.room.canonical_alias": + var content common.CanonicalAliasContent + field := &(content.Alias) + attrName := "canonical_alias" + return d.updateStringAttribute(ctx, attrName, event, &content, field) + case "m.room.name": + var content common.NameContent + field := &(content.Name) + attrName := "name" + return d.updateStringAttribute(ctx, attrName, event, &content, field) + case "m.room.topic": + var content common.TopicContent + field := &(content.Topic) + attrName := "topic" + return d.updateStringAttribute(ctx, attrName, event, &content, field) + case "m.room.avatar": + var content common.AvatarContent + field := &(content.URL) + attrName := "avatar_url" + return d.updateStringAttribute(ctx, attrName, event, &content, field) + case "m.room.history_visibility": + var content common.HistoryVisibilityContent + field := &(content.HistoryVisibility) + attrName := "world_readable" + strForTrue := "world_readable" + return d.updateBooleanAttribute(ctx, attrName, event, &content, field, strForTrue) + case "m.room.guest_access": + var content common.GuestAccessContent + field := &(content.GuestAccess) + attrName := "guest_can_join" + strForTrue := "can_join" + return d.updateBooleanAttribute(ctx, attrName, event, &content, field, strForTrue) + } + + // If the event type didn't match, return with no error + return nil +} + +// updateNumJoinedUsers updates the number of joined user in the database representation +// of a room using a given "m.room.member" Matrix event. +// If the membership property of the event isn't "join", ignores it and returs nil. +// If the remove parameter is set to false, increments the joined members counter in the +// database, if set to truem decrements it. +// Returns an error if the update failed. +func (d *PublicRoomsServerDatabase) updateNumJoinedUsers( + ctx context.Context, membershipEvent gomatrixserverlib.Event, remove bool, +) error { + membership, err := membershipEvent.Membership() + if err != nil { + return err + } + + if membership != gomatrixserverlib.Join { + return nil + } + + if remove { + return d.statements.decrementJoinedMembersInRoom(ctx, membershipEvent.RoomID()) + } + return d.statements.incrementJoinedMembersInRoom(ctx, membershipEvent.RoomID()) +} + +// updateStringAttribute updates a given string attribute in the database +// representation of a room using a given string data field from content of the +// Matrix event triggering the update. +// Returns an error if decoding the Matrix event's content or updating the attribute +// failed. +func (d *PublicRoomsServerDatabase) updateStringAttribute( + ctx context.Context, attrName string, event gomatrixserverlib.Event, + content interface{}, field *string, +) error { + if err := json.Unmarshal(event.Content(), content); err != nil { + return err + } + + return d.statements.updateRoomAttribute(ctx, attrName, *field, event.RoomID()) +} + +// updateBooleanAttribute updates a given boolean attribute in the database +// representation of a room using a given string data field from content of the +// Matrix event triggering the update. +// The attribute is set to true if the field matches a given string, false if not. +// Returns an error if decoding the Matrix event's content or updating the attribute +// failed. +func (d *PublicRoomsServerDatabase) updateBooleanAttribute( + ctx context.Context, attrName string, event gomatrixserverlib.Event, + content interface{}, field *string, strForTrue string, +) error { + if err := json.Unmarshal(event.Content(), content); err != nil { + return err + } + + var attrValue bool + if *field == strForTrue { + attrValue = true + } else { + attrValue = false + } + + return d.statements.updateRoomAttribute(ctx, attrName, attrValue, event.RoomID()) +} + +// updateRoomAliases decodes the content of a "m.room.aliases" Matrix event and update the list of aliases of +// a given room with it. +// Returns an error if decoding the Matrix event or updating the list failed. +func (d *PublicRoomsServerDatabase) updateRoomAliases( + ctx context.Context, aliasesEvent gomatrixserverlib.Event, +) error { + var content common.AliasesContent + if err := json.Unmarshal(aliasesEvent.Content(), &content); err != nil { + return err + } + + return d.statements.updateRoomAttribute( + ctx, "aliases", content.Aliases, aliasesEvent.RoomID(), + ) +} diff --git a/publicroomsapi/storage/storage.go b/publicroomsapi/storage/storage.go index aa980694..a6a39d52 100644 --- a/publicroomsapi/storage/storage.go +++ b/publicroomsapi/storage/storage.go @@ -1,4 +1,4 @@ -// Copyright 2017 Vector Creations Ltd +// Copyright 2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -16,237 +16,35 @@ package storage import ( "context" - "database/sql" - "encoding/json" + "errors" + "net/url" "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/publicroomsapi/storage/postgres" "github.com/matrix-org/dendrite/publicroomsapi/types" - "github.com/matrix-org/gomatrixserverlib" ) -// PublicRoomsServerDatabase represents a public rooms server database. -type PublicRoomsServerDatabase struct { - db *sql.DB - common.PartitionOffsetStatements - statements publicRoomsStatements +type Database interface { + common.PartitionStorer + GetRoomVisibility(ctx context.Context, roomID string) (bool, error) + SetRoomVisibility(ctx context.Context, visible bool, roomID string) error + CountPublicRooms(ctx context.Context) (int64, error) + GetPublicRooms(ctx context.Context, offset int64, limit int16, filter string) ([]types.PublicRoom, error) + UpdateRoomFromEvents(ctx context.Context, eventsToAdd []gomatrixserverlib.Event, eventsToRemove []gomatrixserverlib.Event) error + UpdateRoomFromEvent(ctx context.Context, event gomatrixserverlib.Event) error } -type attributeValue interface{} - -// NewPublicRoomsServerDatabase creates a new public rooms server database. -func NewPublicRoomsServerDatabase(dataSourceName string) (*PublicRoomsServerDatabase, error) { - var db *sql.DB - var err error - if db, err = sql.Open("postgres", dataSourceName); err != nil { - return nil, err - } - partitions := common.PartitionOffsetStatements{} - if err = partitions.Prepare(db, "publicroomsapi"); err != nil { - return nil, err - } - statements := publicRoomsStatements{} - if err = statements.prepare(db); err != nil { - return nil, err - } - return &PublicRoomsServerDatabase{db, partitions, statements}, nil -} - -// GetRoomVisibility returns the room visibility as a boolean: true if the room -// is publicly visible, false if not. -// Returns an error if the retrieval failed. -func (d *PublicRoomsServerDatabase) GetRoomVisibility( - ctx context.Context, roomID string, -) (bool, error) { - return d.statements.selectRoomVisibility(ctx, roomID) -} - -// SetRoomVisibility updates the visibility attribute of a room. This attribute -// must be set to true if the room is publicly visible, false if not. -// Returns an error if the update failed. -func (d *PublicRoomsServerDatabase) SetRoomVisibility( - ctx context.Context, visible bool, roomID string, -) error { - return d.statements.updateRoomAttribute(ctx, "visibility", visible, roomID) -} - -// CountPublicRooms returns the number of room set as publicly visible on the server. -// Returns an error if the retrieval failed. -func (d *PublicRoomsServerDatabase) CountPublicRooms(ctx context.Context) (int64, error) { - return d.statements.countPublicRooms(ctx) -} - -// GetPublicRooms returns an array containing the local rooms set as publicly visible, ordered by their number -// of joined members. This array can be limited by a given number of elements, and offset by a given value. -// If the limit is 0, doesn't limit the number of results. If the offset is 0 too, the array contains all -// the rooms set as publicly visible on the server. -// Returns an error if the retrieval failed. -func (d *PublicRoomsServerDatabase) GetPublicRooms( - ctx context.Context, offset int64, limit int16, filter string, -) ([]types.PublicRoom, error) { - return d.statements.selectPublicRooms(ctx, offset, limit, filter) -} - -// UpdateRoomFromEvents iterate over a slice of state events and call -// UpdateRoomFromEvent on each of them to update the database representation of -// the rooms updated by each event. -// The slice of events to remove is used to update the number of joined members -// for the room in the database. -// If the update triggered by one of the events failed, aborts the process and -// returns an error. -func (d *PublicRoomsServerDatabase) UpdateRoomFromEvents( - ctx context.Context, - eventsToAdd []gomatrixserverlib.Event, - eventsToRemove []gomatrixserverlib.Event, -) error { - for _, event := range eventsToAdd { - if err := d.UpdateRoomFromEvent(ctx, event); err != nil { - return err - } - } - - for _, event := range eventsToRemove { - if event.Type() == "m.room.member" { - if err := d.updateNumJoinedUsers(ctx, event, true); err != nil { - return err - } - } - } - - return nil -} - -// UpdateRoomFromEvent updates the database representation of a room from a Matrix event, by -// checking the event's type to know which attribute to change and using the event's content -// to define the new value of the attribute. -// If the event doesn't match with any property used to compute the public room directory, -// does nothing. -// If something went wrong during the process, returns an error. -func (d *PublicRoomsServerDatabase) UpdateRoomFromEvent( - ctx context.Context, event gomatrixserverlib.Event, -) error { - // Process the event according to its type - switch event.Type() { - case "m.room.create": - return d.statements.insertNewRoom(ctx, event.RoomID()) - case "m.room.member": - return d.updateNumJoinedUsers(ctx, event, false) - case "m.room.aliases": - return d.updateRoomAliases(ctx, event) - case "m.room.canonical_alias": - var content common.CanonicalAliasContent - field := &(content.Alias) - attrName := "canonical_alias" - return d.updateStringAttribute(ctx, attrName, event, &content, field) - case "m.room.name": - var content common.NameContent - field := &(content.Name) - attrName := "name" - return d.updateStringAttribute(ctx, attrName, event, &content, field) - case "m.room.topic": - var content common.TopicContent - field := &(content.Topic) - attrName := "topic" - return d.updateStringAttribute(ctx, attrName, event, &content, field) - case "m.room.avatar": - var content common.AvatarContent - field := &(content.URL) - attrName := "avatar_url" - return d.updateStringAttribute(ctx, attrName, event, &content, field) - case "m.room.history_visibility": - var content common.HistoryVisibilityContent - field := &(content.HistoryVisibility) - attrName := "world_readable" - strForTrue := "world_readable" - return d.updateBooleanAttribute(ctx, attrName, event, &content, field, strForTrue) - case "m.room.guest_access": - var content common.GuestAccessContent - field := &(content.GuestAccess) - attrName := "guest_can_join" - strForTrue := "can_join" - return d.updateBooleanAttribute(ctx, attrName, event, &content, field, strForTrue) - } - - // If the event type didn't match, return with no error - return nil -} - -// updateNumJoinedUsers updates the number of joined user in the database representation -// of a room using a given "m.room.member" Matrix event. -// If the membership property of the event isn't "join", ignores it and returs nil. -// If the remove parameter is set to false, increments the joined members counter in the -// database, if set to truem decrements it. -// Returns an error if the update failed. -func (d *PublicRoomsServerDatabase) updateNumJoinedUsers( - ctx context.Context, membershipEvent gomatrixserverlib.Event, remove bool, -) error { - membership, err := membershipEvent.Membership() +// NewPublicRoomsServerDatabase opens a database connection. +func NewPublicRoomsServerDatabase(dataSourceName string) (Database, error) { + uri, err := url.Parse(dataSourceName) if err != nil { - return err + return nil, err } - - if membership != gomatrixserverlib.Join { - return nil + switch uri.Scheme { + case "postgres": + return postgres.NewPublicRoomsServerDatabase(dataSourceName) + default: + return nil, errors.New("unknown schema") } - - if remove { - return d.statements.decrementJoinedMembersInRoom(ctx, membershipEvent.RoomID()) - } - return d.statements.incrementJoinedMembersInRoom(ctx, membershipEvent.RoomID()) -} - -// updateStringAttribute updates a given string attribute in the database -// representation of a room using a given string data field from content of the -// Matrix event triggering the update. -// Returns an error if decoding the Matrix event's content or updating the attribute -// failed. -func (d *PublicRoomsServerDatabase) updateStringAttribute( - ctx context.Context, attrName string, event gomatrixserverlib.Event, - content interface{}, field *string, -) error { - if err := json.Unmarshal(event.Content(), content); err != nil { - return err - } - - return d.statements.updateRoomAttribute(ctx, attrName, *field, event.RoomID()) -} - -// updateBooleanAttribute updates a given boolean attribute in the database -// representation of a room using a given string data field from content of the -// Matrix event triggering the update. -// The attribute is set to true if the field matches a given string, false if not. -// Returns an error if decoding the Matrix event's content or updating the attribute -// failed. -func (d *PublicRoomsServerDatabase) updateBooleanAttribute( - ctx context.Context, attrName string, event gomatrixserverlib.Event, - content interface{}, field *string, strForTrue string, -) error { - if err := json.Unmarshal(event.Content(), content); err != nil { - return err - } - - var attrValue bool - if *field == strForTrue { - attrValue = true - } else { - attrValue = false - } - - return d.statements.updateRoomAttribute(ctx, attrName, attrValue, event.RoomID()) -} - -// updateRoomAliases decodes the content of a "m.room.aliases" Matrix event and update the list of aliases of -// a given room with it. -// Returns an error if decoding the Matrix event or updating the list failed. -func (d *PublicRoomsServerDatabase) updateRoomAliases( - ctx context.Context, aliasesEvent gomatrixserverlib.Event, -) error { - var content common.AliasesContent - if err := json.Unmarshal(aliasesEvent.Content(), &content); err != nil { - return err - } - - return d.statements.updateRoomAttribute( - ctx, "aliases", content.Aliases, aliasesEvent.RoomID(), - ) } diff --git a/roomserver/storage/event_json_table.go b/roomserver/storage/postgres/event_json_table.go similarity index 96% rename from roomserver/storage/event_json_table.go rename to roomserver/storage/postgres/event_json_table.go index b81667d9..415fb84e 100644 --- a/roomserver/storage/event_json_table.go +++ b/roomserver/storage/postgres/event_json_table.go @@ -1,4 +1,5 @@ -// Copyright 2017 Vector Creations Ltd +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,7 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package storage +package postgres import ( "context" diff --git a/roomserver/storage/event_state_keys_table.go b/roomserver/storage/postgres/event_state_keys_table.go similarity index 98% rename from roomserver/storage/event_state_keys_table.go rename to roomserver/storage/postgres/event_state_keys_table.go index 1ef93370..c3aaa498 100644 --- a/roomserver/storage/event_state_keys_table.go +++ b/roomserver/storage/postgres/event_state_keys_table.go @@ -1,4 +1,5 @@ -// Copyright 2017 Vector Creations Ltd +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,7 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package storage +package postgres import ( "context" diff --git a/roomserver/storage/event_types_table.go b/roomserver/storage/postgres/event_types_table.go similarity index 98% rename from roomserver/storage/event_types_table.go rename to roomserver/storage/postgres/event_types_table.go index 7b8d53a5..1ec2e7cd 100644 --- a/roomserver/storage/event_types_table.go +++ b/roomserver/storage/postgres/event_types_table.go @@ -1,4 +1,5 @@ -// Copyright 2017 Vector Creations Ltd +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,7 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package storage +package postgres import ( "context" diff --git a/roomserver/storage/events_table.go b/roomserver/storage/postgres/events_table.go similarity index 99% rename from roomserver/storage/events_table.go rename to roomserver/storage/postgres/events_table.go index 5bad939f..1e8a5665 100644 --- a/roomserver/storage/events_table.go +++ b/roomserver/storage/postgres/events_table.go @@ -1,4 +1,5 @@ -// Copyright 2017 Vector Creations Ltd +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,7 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package storage +package postgres import ( "context" diff --git a/roomserver/storage/invite_table.go b/roomserver/storage/postgres/invite_table.go similarity index 97% rename from roomserver/storage/invite_table.go rename to roomserver/storage/postgres/invite_table.go index 4f9cdfb4..43cd5ba0 100644 --- a/roomserver/storage/invite_table.go +++ b/roomserver/storage/postgres/invite_table.go @@ -1,4 +1,5 @@ -// Copyright 2017 Vector Creations Ltd +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,7 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package storage +package postgres import ( "context" diff --git a/roomserver/storage/membership_table.go b/roomserver/storage/postgres/membership_table.go similarity index 98% rename from roomserver/storage/membership_table.go rename to roomserver/storage/postgres/membership_table.go index 88a9ed72..9f41fd67 100644 --- a/roomserver/storage/membership_table.go +++ b/roomserver/storage/postgres/membership_table.go @@ -1,4 +1,5 @@ -// Copyright 2017 Vector Creations Ltd +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,7 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package storage +package postgres import ( "context" diff --git a/roomserver/storage/prepare.go b/roomserver/storage/postgres/prepare.go similarity index 90% rename from roomserver/storage/prepare.go rename to roomserver/storage/postgres/prepare.go index b1976599..70b6e516 100644 --- a/roomserver/storage/prepare.go +++ b/roomserver/storage/postgres/prepare.go @@ -1,4 +1,5 @@ -// Copyright 2017 Vector Creations Ltd +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,7 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package storage +package postgres import ( "database/sql" diff --git a/roomserver/storage/previous_events_table.go b/roomserver/storage/postgres/previous_events_table.go similarity index 97% rename from roomserver/storage/previous_events_table.go rename to roomserver/storage/postgres/previous_events_table.go index 81d581a9..4c21b308 100644 --- a/roomserver/storage/previous_events_table.go +++ b/roomserver/storage/postgres/previous_events_table.go @@ -1,4 +1,5 @@ -// Copyright 2017 Vector Creations Ltd +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,7 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package storage +package postgres import ( "context" diff --git a/roomserver/storage/room_aliases_table.go b/roomserver/storage/postgres/room_aliases_table.go similarity index 96% rename from roomserver/storage/room_aliases_table.go rename to roomserver/storage/postgres/room_aliases_table.go index 3ed20e8e..ad1b560c 100644 --- a/roomserver/storage/room_aliases_table.go +++ b/roomserver/storage/postgres/room_aliases_table.go @@ -1,4 +1,5 @@ -// Copyright 2017 Vector Creations Ltd +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,7 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package storage +package postgres import ( "context" diff --git a/roomserver/storage/rooms_table.go b/roomserver/storage/postgres/rooms_table.go similarity index 97% rename from roomserver/storage/rooms_table.go rename to roomserver/storage/postgres/rooms_table.go index 64193ffe..ccc201b1 100644 --- a/roomserver/storage/rooms_table.go +++ b/roomserver/storage/postgres/rooms_table.go @@ -1,4 +1,5 @@ -// Copyright 2017 Vector Creations Ltd +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,7 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package storage +package postgres import ( "context" diff --git a/roomserver/storage/sql.go b/roomserver/storage/postgres/sql.go similarity index 92% rename from roomserver/storage/sql.go rename to roomserver/storage/postgres/sql.go index 05efa8dd..5956886c 100644 --- a/roomserver/storage/sql.go +++ b/roomserver/storage/postgres/sql.go @@ -1,4 +1,5 @@ -// Copyright 2017 Vector Creations Ltd +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,7 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package storage +package postgres import ( "database/sql" diff --git a/roomserver/storage/state_block_table.go b/roomserver/storage/postgres/state_block_table.go similarity index 98% rename from roomserver/storage/state_block_table.go rename to roomserver/storage/postgres/state_block_table.go index b2e8ef8a..15e69cc9 100644 --- a/roomserver/storage/state_block_table.go +++ b/roomserver/storage/postgres/state_block_table.go @@ -1,4 +1,5 @@ -// Copyright 2017 Vector Creations Ltd +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,7 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package storage +package postgres import ( "context" diff --git a/roomserver/storage/state_block_table_test.go b/roomserver/storage/postgres/state_block_table_test.go similarity index 95% rename from roomserver/storage/state_block_table_test.go rename to roomserver/storage/postgres/state_block_table_test.go index f891b5bc..a0e2ec95 100644 --- a/roomserver/storage/state_block_table_test.go +++ b/roomserver/storage/postgres/state_block_table_test.go @@ -1,4 +1,5 @@ -// Copyright 2017 Vector Creations Ltd +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,7 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package storage +package postgres import ( "sort" diff --git a/roomserver/storage/state_snapshot_table.go b/roomserver/storage/postgres/state_snapshot_table.go similarity index 97% rename from roomserver/storage/state_snapshot_table.go rename to roomserver/storage/postgres/state_snapshot_table.go index aa14daad..76f1d2b6 100644 --- a/roomserver/storage/state_snapshot_table.go +++ b/roomserver/storage/postgres/state_snapshot_table.go @@ -1,4 +1,5 @@ -// Copyright 2017 Vector Creations Ltd +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,7 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package storage +package postgres import ( "context" diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go new file mode 100644 index 00000000..93450e5a --- /dev/null +++ b/roomserver/storage/postgres/storage.go @@ -0,0 +1,713 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + + // Import the postgres database driver. + _ "github.com/lib/pq" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" +) + +// A Database is used to store room events and stream offsets. +type Database struct { + statements statements + db *sql.DB +} + +// Open a postgres database. +func Open(dataSourceName string) (*Database, error) { + var d Database + var err error + if d.db, err = sql.Open("postgres", dataSourceName); err != nil { + return nil, err + } + if err = d.statements.prepare(d.db); err != nil { + return nil, err + } + return &d, nil +} + +// StoreEvent implements input.EventDatabase +func (d *Database) StoreEvent( + ctx context.Context, event gomatrixserverlib.Event, + txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID, +) (types.RoomNID, types.StateAtEvent, error) { + var ( + roomNID types.RoomNID + eventTypeNID types.EventTypeNID + eventStateKeyNID types.EventStateKeyNID + eventNID types.EventNID + stateNID types.StateSnapshotNID + err error + ) + + if txnAndSessionID != nil { + if err = d.statements.insertTransaction( + ctx, txnAndSessionID.TransactionID, + txnAndSessionID.SessionID, event.Sender(), event.EventID(), + ); err != nil { + return 0, types.StateAtEvent{}, err + } + } + + if roomNID, err = d.assignRoomNID(ctx, nil, event.RoomID()); err != nil { + return 0, types.StateAtEvent{}, err + } + + if eventTypeNID, err = d.assignEventTypeNID(ctx, event.Type()); err != nil { + return 0, types.StateAtEvent{}, err + } + + eventStateKey := event.StateKey() + // Assigned a numeric ID for the state_key if there is one present. + // Otherwise set the numeric ID for the state_key to 0. + if eventStateKey != nil { + if eventStateKeyNID, err = d.assignStateKeyNID(ctx, nil, *eventStateKey); err != nil { + return 0, types.StateAtEvent{}, err + } + } + + if eventNID, stateNID, err = d.statements.insertEvent( + ctx, + roomNID, + eventTypeNID, + eventStateKeyNID, + event.EventID(), + event.EventReference().EventSHA256, + authEventNIDs, + event.Depth(), + ); err != nil { + if err == sql.ErrNoRows { + // We've already inserted the event so select the numeric event ID + eventNID, stateNID, err = d.statements.selectEvent(ctx, event.EventID()) + } + if err != nil { + return 0, types.StateAtEvent{}, err + } + } + + if err = d.statements.insertEventJSON(ctx, eventNID, event.JSON()); err != nil { + return 0, types.StateAtEvent{}, err + } + + return roomNID, types.StateAtEvent{ + BeforeStateSnapshotNID: stateNID, + StateEntry: types.StateEntry{ + StateKeyTuple: types.StateKeyTuple{ + EventTypeNID: eventTypeNID, + EventStateKeyNID: eventStateKeyNID, + }, + EventNID: eventNID, + }, + }, nil +} + +func (d *Database) assignRoomNID( + ctx context.Context, txn *sql.Tx, roomID string, +) (types.RoomNID, error) { + // Check if we already have a numeric ID in the database. + roomNID, err := d.statements.selectRoomNID(ctx, txn, roomID) + if err == sql.ErrNoRows { + // We don't have a numeric ID so insert one into the database. + roomNID, err = d.statements.insertRoomNID(ctx, txn, roomID) + if err == sql.ErrNoRows { + // We raced with another insert so run the select again. + roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID) + } + } + return roomNID, err +} + +func (d *Database) assignEventTypeNID( + ctx context.Context, eventType string, +) (types.EventTypeNID, error) { + // Check if we already have a numeric ID in the database. + eventTypeNID, err := d.statements.selectEventTypeNID(ctx, eventType) + if err == sql.ErrNoRows { + // We don't have a numeric ID so insert one into the database. + eventTypeNID, err = d.statements.insertEventTypeNID(ctx, eventType) + if err == sql.ErrNoRows { + // We raced with another insert so run the select again. + eventTypeNID, err = d.statements.selectEventTypeNID(ctx, eventType) + } + } + return eventTypeNID, err +} + +func (d *Database) assignStateKeyNID( + ctx context.Context, txn *sql.Tx, eventStateKey string, +) (types.EventStateKeyNID, error) { + // Check if we already have a numeric ID in the database. + eventStateKeyNID, err := d.statements.selectEventStateKeyNID(ctx, txn, eventStateKey) + if err == sql.ErrNoRows { + // We don't have a numeric ID so insert one into the database. + eventStateKeyNID, err = d.statements.insertEventStateKeyNID(ctx, txn, eventStateKey) + if err == sql.ErrNoRows { + // We raced with another insert so run the select again. + eventStateKeyNID, err = d.statements.selectEventStateKeyNID(ctx, txn, eventStateKey) + } + } + return eventStateKeyNID, err +} + +// StateEntriesForEventIDs implements input.EventDatabase +func (d *Database) StateEntriesForEventIDs( + ctx context.Context, eventIDs []string, +) ([]types.StateEntry, error) { + return d.statements.bulkSelectStateEventByID(ctx, eventIDs) +} + +// EventTypeNIDs implements state.RoomStateDatabase +func (d *Database) EventTypeNIDs( + ctx context.Context, eventTypes []string, +) (map[string]types.EventTypeNID, error) { + return d.statements.bulkSelectEventTypeNID(ctx, eventTypes) +} + +// EventStateKeyNIDs implements state.RoomStateDatabase +func (d *Database) EventStateKeyNIDs( + ctx context.Context, eventStateKeys []string, +) (map[string]types.EventStateKeyNID, error) { + return d.statements.bulkSelectEventStateKeyNID(ctx, eventStateKeys) +} + +// EventStateKeys implements query.RoomserverQueryAPIDatabase +func (d *Database) EventStateKeys( + ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, +) (map[types.EventStateKeyNID]string, error) { + return d.statements.bulkSelectEventStateKey(ctx, eventStateKeyNIDs) +} + +// EventNIDs implements query.RoomserverQueryAPIDatabase +func (d *Database) EventNIDs( + ctx context.Context, eventIDs []string, +) (map[string]types.EventNID, error) { + return d.statements.bulkSelectEventNID(ctx, eventIDs) +} + +// Events implements input.EventDatabase +func (d *Database) Events( + ctx context.Context, eventNIDs []types.EventNID, +) ([]types.Event, error) { + eventJSONs, err := d.statements.bulkSelectEventJSON(ctx, eventNIDs) + if err != nil { + return nil, err + } + results := make([]types.Event, len(eventJSONs)) + for i, eventJSON := range eventJSONs { + result := &results[i] + result.EventNID = eventJSON.EventNID + // TODO: Use NewEventFromTrustedJSON for efficiency + result.Event, err = gomatrixserverlib.NewEventFromUntrustedJSON(eventJSON.EventJSON) + if err != nil { + return nil, err + } + } + return results, nil +} + +// AddState implements input.EventDatabase +func (d *Database) AddState( + ctx context.Context, + roomNID types.RoomNID, + stateBlockNIDs []types.StateBlockNID, + state []types.StateEntry, +) (types.StateSnapshotNID, error) { + if len(state) > 0 { + stateBlockNID, err := d.statements.selectNextStateBlockNID(ctx) + if err != nil { + return 0, err + } + if err = d.statements.bulkInsertStateData(ctx, stateBlockNID, state); err != nil { + return 0, err + } + stateBlockNIDs = append(stateBlockNIDs[:len(stateBlockNIDs):len(stateBlockNIDs)], stateBlockNID) + } + + return d.statements.insertState(ctx, roomNID, stateBlockNIDs) +} + +// SetState implements input.EventDatabase +func (d *Database) SetState( + ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, +) error { + return d.statements.updateEventState(ctx, eventNID, stateNID) +} + +// StateAtEventIDs implements input.EventDatabase +func (d *Database) StateAtEventIDs( + ctx context.Context, eventIDs []string, +) ([]types.StateAtEvent, error) { + return d.statements.bulkSelectStateAtEventByID(ctx, eventIDs) +} + +// StateBlockNIDs implements state.RoomStateDatabase +func (d *Database) StateBlockNIDs( + ctx context.Context, stateNIDs []types.StateSnapshotNID, +) ([]types.StateBlockNIDList, error) { + return d.statements.bulkSelectStateBlockNIDs(ctx, stateNIDs) +} + +// StateEntries implements state.RoomStateDatabase +func (d *Database) StateEntries( + ctx context.Context, stateBlockNIDs []types.StateBlockNID, +) ([]types.StateEntryList, error) { + return d.statements.bulkSelectStateBlockEntries(ctx, stateBlockNIDs) +} + +// SnapshotNIDFromEventID implements state.RoomStateDatabase +func (d *Database) SnapshotNIDFromEventID( + ctx context.Context, eventID string, +) (types.StateSnapshotNID, error) { + _, stateNID, err := d.statements.selectEvent(ctx, eventID) + return stateNID, err +} + +// EventIDs implements input.RoomEventDatabase +func (d *Database) EventIDs( + ctx context.Context, eventNIDs []types.EventNID, +) (map[types.EventNID]string, error) { + return d.statements.bulkSelectEventID(ctx, eventNIDs) +} + +// GetLatestEventsForUpdate implements input.EventDatabase +func (d *Database) GetLatestEventsForUpdate( + ctx context.Context, roomNID types.RoomNID, +) (types.RoomRecentEventsUpdater, error) { + txn, err := d.db.Begin() + if err != nil { + return nil, err + } + eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err := + d.statements.selectLatestEventsNIDsForUpdate(ctx, txn, roomNID) + if err != nil { + txn.Rollback() // nolint: errcheck + return nil, err + } + stateAndRefs, err := d.statements.bulkSelectStateAtEventAndReference(ctx, txn, eventNIDs) + if err != nil { + txn.Rollback() // nolint: errcheck + return nil, err + } + var lastEventIDSent string + if lastEventNIDSent != 0 { + lastEventIDSent, err = d.statements.selectEventID(ctx, txn, lastEventNIDSent) + if err != nil { + txn.Rollback() // nolint: errcheck + return nil, err + } + } + return &roomRecentEventsUpdater{ + transaction{ctx, txn}, d, roomNID, stateAndRefs, lastEventIDSent, currentStateSnapshotNID, + }, nil +} + +// GetTransactionEventID implements input.EventDatabase +func (d *Database) GetTransactionEventID( + ctx context.Context, transactionID string, + sessionID int64, userID string, +) (string, error) { + eventID, err := d.statements.selectTransactionEventID(ctx, transactionID, sessionID, userID) + if err == sql.ErrNoRows { + return "", nil + } + return eventID, err +} + +type roomRecentEventsUpdater struct { + transaction + d *Database + roomNID types.RoomNID + latestEvents []types.StateAtEventAndReference + lastEventIDSent string + currentStateSnapshotNID types.StateSnapshotNID +} + +// LatestEvents implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) LatestEvents() []types.StateAtEventAndReference { + return u.latestEvents +} + +// LastEventIDSent implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) LastEventIDSent() string { + return u.lastEventIDSent +} + +// CurrentStateSnapshotNID implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID { + return u.currentStateSnapshotNID +} + +// StorePreviousEvents implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error { + for _, ref := range previousEventReferences { + if err := u.d.statements.insertPreviousEvent(u.ctx, u.txn, ref.EventID, ref.EventSHA256, eventNID); err != nil { + return err + } + } + return nil +} + +// IsReferenced implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) { + err := u.d.statements.selectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256) + if err == nil { + return true, nil + } + if err == sql.ErrNoRows { + return false, nil + } + return false, err +} + +// SetLatestEvents implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) SetLatestEvents( + roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID, + currentStateSnapshotNID types.StateSnapshotNID, +) error { + eventNIDs := make([]types.EventNID, len(latest)) + for i := range latest { + eventNIDs[i] = latest[i].EventNID + } + return u.d.statements.updateLatestEventNIDs(u.ctx, u.txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID) +} + +// HasEventBeenSent implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) { + return u.d.statements.selectEventSentToOutput(u.ctx, u.txn, eventNID) +} + +// MarkEventAsSent implements types.RoomRecentEventsUpdater +func (u *roomRecentEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error { + return u.d.statements.updateEventSentToOutput(u.ctx, u.txn, eventNID) +} + +func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID) (types.MembershipUpdater, error) { + return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomNID, targetUserNID) +} + +// RoomNID implements query.RoomserverQueryAPIDB +func (d *Database) RoomNID(ctx context.Context, roomID string) (types.RoomNID, error) { + roomNID, err := d.statements.selectRoomNID(ctx, nil, roomID) + if err == sql.ErrNoRows { + return 0, nil + } + return roomNID, err +} + +// LatestEventIDs implements query.RoomserverQueryAPIDatabase +func (d *Database) LatestEventIDs( + ctx context.Context, roomNID types.RoomNID, +) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error) { + eventNIDs, currentStateSnapshotNID, err := d.statements.selectLatestEventNIDs(ctx, roomNID) + if err != nil { + return nil, 0, 0, err + } + references, err := d.statements.bulkSelectEventReference(ctx, eventNIDs) + if err != nil { + return nil, 0, 0, err + } + depth, err := d.statements.selectMaxEventDepth(ctx, eventNIDs) + if err != nil { + return nil, 0, 0, err + } + return references, currentStateSnapshotNID, depth, nil +} + +// GetInvitesForUser implements query.RoomserverQueryAPIDatabase +func (d *Database) GetInvitesForUser( + ctx context.Context, + roomNID types.RoomNID, + targetUserNID types.EventStateKeyNID, +) (senderUserIDs []types.EventStateKeyNID, err error) { + return d.statements.selectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID) +} + +// SetRoomAlias implements alias.RoomserverAliasAPIDB +func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error { + return d.statements.insertRoomAlias(ctx, alias, roomID, creatorUserID) +} + +// GetRoomIDForAlias implements alias.RoomserverAliasAPIDB +func (d *Database) GetRoomIDForAlias(ctx context.Context, alias string) (string, error) { + return d.statements.selectRoomIDFromAlias(ctx, alias) +} + +// GetAliasesForRoomID implements alias.RoomserverAliasAPIDB +func (d *Database) GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) { + return d.statements.selectAliasesFromRoomID(ctx, roomID) +} + +// GetCreatorIDForAlias implements alias.RoomserverAliasAPIDB +func (d *Database) GetCreatorIDForAlias( + ctx context.Context, alias string, +) (string, error) { + return d.statements.selectCreatorIDFromAlias(ctx, alias) +} + +// RemoveRoomAlias implements alias.RoomserverAliasAPIDB +func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error { + return d.statements.deleteRoomAlias(ctx, alias) +} + +// StateEntriesForTuples implements state.RoomStateDatabase +func (d *Database) StateEntriesForTuples( + ctx context.Context, + stateBlockNIDs []types.StateBlockNID, + stateKeyTuples []types.StateKeyTuple, +) ([]types.StateEntryList, error) { + return d.statements.bulkSelectFilteredStateBlockEntries( + ctx, stateBlockNIDs, stateKeyTuples, + ) +} + +// MembershipUpdater implements input.RoomEventDatabase +func (d *Database) MembershipUpdater( + ctx context.Context, roomID, targetUserID string, +) (types.MembershipUpdater, error) { + txn, err := d.db.Begin() + if err != nil { + return nil, err + } + succeeded := false + defer func() { + if !succeeded { + txn.Rollback() // nolint: errcheck + } + }() + + roomNID, err := d.assignRoomNID(ctx, txn, roomID) + if err != nil { + return nil, err + } + + targetUserNID, err := d.assignStateKeyNID(ctx, txn, targetUserID) + if err != nil { + return nil, err + } + + updater, err := d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID) + if err != nil { + return nil, err + } + + succeeded = true + return updater, nil +} + +type membershipUpdater struct { + transaction + d *Database + roomNID types.RoomNID + targetUserNID types.EventStateKeyNID + membership membershipState +} + +func (d *Database) membershipUpdaterTxn( + ctx context.Context, + txn *sql.Tx, + roomNID types.RoomNID, + targetUserNID types.EventStateKeyNID, +) (types.MembershipUpdater, error) { + + if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID); err != nil { + return nil, err + } + + membership, err := d.statements.selectMembershipForUpdate(ctx, txn, roomNID, targetUserNID) + if err != nil { + return nil, err + } + + return &membershipUpdater{ + transaction{ctx, txn}, d, roomNID, targetUserNID, membership, + }, nil +} + +// IsInvite implements types.MembershipUpdater +func (u *membershipUpdater) IsInvite() bool { + return u.membership == membershipStateInvite +} + +// IsJoin implements types.MembershipUpdater +func (u *membershipUpdater) IsJoin() bool { + return u.membership == membershipStateJoin +} + +// IsLeave implements types.MembershipUpdater +func (u *membershipUpdater) IsLeave() bool { + return u.membership == membershipStateLeaveOrBan +} + +// SetToInvite implements types.MembershipUpdater +func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, error) { + senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.Sender()) + if err != nil { + return false, err + } + inserted, err := u.d.statements.insertInviteEvent( + u.ctx, u.txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(), + ) + if err != nil { + return false, err + } + if u.membership != membershipStateInvite { + if err = u.d.statements.updateMembership( + u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateInvite, 0, + ); err != nil { + return false, err + } + } + return inserted, nil +} + +// SetToJoin implements types.MembershipUpdater +func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpdate bool) ([]string, error) { + var inviteEventIDs []string + + senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID) + if err != nil { + return nil, err + } + + // If this is a join event update, there is no invite to update + if !isUpdate { + inviteEventIDs, err = u.d.statements.updateInviteRetired( + u.ctx, u.txn, u.roomNID, u.targetUserNID, + ) + if err != nil { + return nil, err + } + } + + // Look up the NID of the new join event + nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID}) + if err != nil { + return nil, err + } + + if u.membership != membershipStateJoin || isUpdate { + if err = u.d.statements.updateMembership( + u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, + membershipStateJoin, nIDs[eventID], + ); err != nil { + return nil, err + } + } + + return inviteEventIDs, nil +} + +// SetToLeave implements types.MembershipUpdater +func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) ([]string, error) { + senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID) + if err != nil { + return nil, err + } + inviteEventIDs, err := u.d.statements.updateInviteRetired( + u.ctx, u.txn, u.roomNID, u.targetUserNID, + ) + if err != nil { + return nil, err + } + + // Look up the NID of the new leave event + nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID}) + if err != nil { + return nil, err + } + + if u.membership != membershipStateLeaveOrBan { + if err = u.d.statements.updateMembership( + u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, + membershipStateLeaveOrBan, nIDs[eventID], + ); err != nil { + return nil, err + } + } + return inviteEventIDs, nil +} + +// GetMembership implements query.RoomserverQueryAPIDB +func (d *Database) GetMembership( + ctx context.Context, roomNID types.RoomNID, requestSenderUserID string, +) (membershipEventNID types.EventNID, stillInRoom bool, err error) { + requestSenderUserNID, err := d.assignStateKeyNID(ctx, nil, requestSenderUserID) + if err != nil { + return + } + + senderMembershipEventNID, senderMembership, err := + d.statements.selectMembershipFromRoomAndTarget( + ctx, roomNID, requestSenderUserNID, + ) + if err == sql.ErrNoRows { + // The user has never been a member of that room + return 0, false, nil + } else if err != nil { + return + } + + return senderMembershipEventNID, senderMembership == membershipStateJoin, nil +} + +// GetMembershipEventNIDsForRoom implements query.RoomserverQueryAPIDB +func (d *Database) GetMembershipEventNIDsForRoom( + ctx context.Context, roomNID types.RoomNID, joinOnly bool, +) ([]types.EventNID, error) { + if joinOnly { + return d.statements.selectMembershipsFromRoomAndMembership( + ctx, roomNID, membershipStateJoin, + ) + } + + return d.statements.selectMembershipsFromRoom(ctx, roomNID) +} + +// EventsFromIDs implements query.RoomserverQueryAPIEventDB +func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { + nidMap, err := d.EventNIDs(ctx, eventIDs) + if err != nil { + return nil, err + } + + var nids []types.EventNID + for _, nid := range nidMap { + nids = append(nids, nid) + } + + return d.Events(ctx, nids) +} + +type transaction struct { + ctx context.Context + txn *sql.Tx +} + +// Commit implements types.Transaction +func (t *transaction) Commit() error { + return t.txn.Commit() +} + +// Rollback implements types.Transaction +func (t *transaction) Rollback() error { + return t.txn.Rollback() +} diff --git a/roomserver/storage/transactions_table.go b/roomserver/storage/postgres/transactions_table.go similarity index 95% rename from roomserver/storage/transactions_table.go rename to roomserver/storage/postgres/transactions_table.go index b98ea3f3..87c1caca 100644 --- a/roomserver/storage/transactions_table.go +++ b/roomserver/storage/postgres/transactions_table.go @@ -1,3 +1,6 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -10,7 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package storage +package postgres import ( "context" diff --git a/roomserver/storage/storage.go b/roomserver/storage/storage.go index 7e8eb98c..325d96e9 100644 --- a/roomserver/storage/storage.go +++ b/roomserver/storage/storage.go @@ -1,4 +1,4 @@ -// Copyright 2017 Vector Creations Ltd +// Copyright 2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -16,697 +16,57 @@ package storage import ( "context" - "database/sql" + "errors" + "net/url" - // Import the postgres database driver. - _ "github.com/lib/pq" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/storage/postgres" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" ) -// A Database is used to store room events and stream offsets. -type Database struct { - statements statements - db *sql.DB +type Database interface { + StoreEvent(ctx context.Context, event gomatrixserverlib.Event, txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID) (types.RoomNID, types.StateAtEvent, error) + StateEntriesForEventIDs(ctx context.Context, eventIDs []string) ([]types.StateEntry, error) + EventTypeNIDs(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error) + EventStateKeyNIDs(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error) + EventStateKeys(ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error) + EventNIDs(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) + Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error) + AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) + SetState(ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID) error + StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error) + StateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) + StateEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) + SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error) + EventIDs(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) + GetLatestEventsForUpdate(ctx context.Context, roomNID types.RoomNID) (types.RoomRecentEventsUpdater, error) + GetTransactionEventID(ctx context.Context, transactionID string, sessionID int64, userID string) (string, error) + RoomNID(ctx context.Context, roomID string) (types.RoomNID, error) + LatestEventIDs(ctx context.Context, roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error) + GetInvitesForUser(ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (senderUserIDs []types.EventStateKeyNID, err error) + SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error + GetRoomIDForAlias(ctx context.Context, alias string) (string, error) + GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) + GetCreatorIDForAlias(ctx context.Context, alias string) (string, error) + RemoveRoomAlias(ctx context.Context, alias string) error + StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error) + MembershipUpdater(ctx context.Context, roomID, targetUserID string) (types.MembershipUpdater, error) + GetMembership(ctx context.Context, roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom bool, err error) + GetMembershipEventNIDsForRoom(ctx context.Context, roomNID types.RoomNID, joinOnly bool) ([]types.EventNID, error) + EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) } -// Open a postgres database. -func Open(dataSourceName string) (*Database, error) { - var d Database - var err error - if d.db, err = sql.Open("postgres", dataSourceName); err != nil { - return nil, err - } - if err = d.statements.prepare(d.db); err != nil { - return nil, err - } - return &d, nil -} - -// StoreEvent implements input.EventDatabase -func (d *Database) StoreEvent( - ctx context.Context, event gomatrixserverlib.Event, - txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID, -) (types.RoomNID, types.StateAtEvent, error) { - var ( - roomNID types.RoomNID - eventTypeNID types.EventTypeNID - eventStateKeyNID types.EventStateKeyNID - eventNID types.EventNID - stateNID types.StateSnapshotNID - err error - ) - - if txnAndSessionID != nil { - if err = d.statements.insertTransaction( - ctx, txnAndSessionID.TransactionID, - txnAndSessionID.SessionID, event.Sender(), event.EventID(), - ); err != nil { - return 0, types.StateAtEvent{}, err - } - } - - if roomNID, err = d.assignRoomNID(ctx, nil, event.RoomID()); err != nil { - return 0, types.StateAtEvent{}, err - } - - if eventTypeNID, err = d.assignEventTypeNID(ctx, event.Type()); err != nil { - return 0, types.StateAtEvent{}, err - } - - eventStateKey := event.StateKey() - // Assigned a numeric ID for the state_key if there is one present. - // Otherwise set the numeric ID for the state_key to 0. - if eventStateKey != nil { - if eventStateKeyNID, err = d.assignStateKeyNID(ctx, nil, *eventStateKey); err != nil { - return 0, types.StateAtEvent{}, err - } - } - - if eventNID, stateNID, err = d.statements.insertEvent( - ctx, - roomNID, - eventTypeNID, - eventStateKeyNID, - event.EventID(), - event.EventReference().EventSHA256, - authEventNIDs, - event.Depth(), - ); err != nil { - if err == sql.ErrNoRows { - // We've already inserted the event so select the numeric event ID - eventNID, stateNID, err = d.statements.selectEvent(ctx, event.EventID()) - } - if err != nil { - return 0, types.StateAtEvent{}, err - } - } - - if err = d.statements.insertEventJSON(ctx, eventNID, event.JSON()); err != nil { - return 0, types.StateAtEvent{}, err - } - - return roomNID, types.StateAtEvent{ - BeforeStateSnapshotNID: stateNID, - StateEntry: types.StateEntry{ - StateKeyTuple: types.StateKeyTuple{ - EventTypeNID: eventTypeNID, - EventStateKeyNID: eventStateKeyNID, - }, - EventNID: eventNID, - }, - }, nil -} - -func (d *Database) assignRoomNID( - ctx context.Context, txn *sql.Tx, roomID string, -) (types.RoomNID, error) { - // Check if we already have a numeric ID in the database. - roomNID, err := d.statements.selectRoomNID(ctx, txn, roomID) - if err == sql.ErrNoRows { - // We don't have a numeric ID so insert one into the database. - roomNID, err = d.statements.insertRoomNID(ctx, txn, roomID) - if err == sql.ErrNoRows { - // We raced with another insert so run the select again. - roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID) - } - } - return roomNID, err -} - -func (d *Database) assignEventTypeNID( - ctx context.Context, eventType string, -) (types.EventTypeNID, error) { - // Check if we already have a numeric ID in the database. - eventTypeNID, err := d.statements.selectEventTypeNID(ctx, eventType) - if err == sql.ErrNoRows { - // We don't have a numeric ID so insert one into the database. - eventTypeNID, err = d.statements.insertEventTypeNID(ctx, eventType) - if err == sql.ErrNoRows { - // We raced with another insert so run the select again. - eventTypeNID, err = d.statements.selectEventTypeNID(ctx, eventType) - } - } - return eventTypeNID, err -} - -func (d *Database) assignStateKeyNID( - ctx context.Context, txn *sql.Tx, eventStateKey string, -) (types.EventStateKeyNID, error) { - // Check if we already have a numeric ID in the database. - eventStateKeyNID, err := d.statements.selectEventStateKeyNID(ctx, txn, eventStateKey) - if err == sql.ErrNoRows { - // We don't have a numeric ID so insert one into the database. - eventStateKeyNID, err = d.statements.insertEventStateKeyNID(ctx, txn, eventStateKey) - if err == sql.ErrNoRows { - // We raced with another insert so run the select again. - eventStateKeyNID, err = d.statements.selectEventStateKeyNID(ctx, txn, eventStateKey) - } - } - return eventStateKeyNID, err -} - -// StateEntriesForEventIDs implements input.EventDatabase -func (d *Database) StateEntriesForEventIDs( - ctx context.Context, eventIDs []string, -) ([]types.StateEntry, error) { - return d.statements.bulkSelectStateEventByID(ctx, eventIDs) -} - -// EventTypeNIDs implements state.RoomStateDatabase -func (d *Database) EventTypeNIDs( - ctx context.Context, eventTypes []string, -) (map[string]types.EventTypeNID, error) { - return d.statements.bulkSelectEventTypeNID(ctx, eventTypes) -} - -// EventStateKeyNIDs implements state.RoomStateDatabase -func (d *Database) EventStateKeyNIDs( - ctx context.Context, eventStateKeys []string, -) (map[string]types.EventStateKeyNID, error) { - return d.statements.bulkSelectEventStateKeyNID(ctx, eventStateKeys) -} - -// EventStateKeys implements query.RoomserverQueryAPIDatabase -func (d *Database) EventStateKeys( - ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, -) (map[types.EventStateKeyNID]string, error) { - return d.statements.bulkSelectEventStateKey(ctx, eventStateKeyNIDs) -} - -// EventNIDs implements query.RoomserverQueryAPIDatabase -func (d *Database) EventNIDs( - ctx context.Context, eventIDs []string, -) (map[string]types.EventNID, error) { - return d.statements.bulkSelectEventNID(ctx, eventIDs) -} - -// Events implements input.EventDatabase -func (d *Database) Events( - ctx context.Context, eventNIDs []types.EventNID, -) ([]types.Event, error) { - eventJSONs, err := d.statements.bulkSelectEventJSON(ctx, eventNIDs) +// NewPublicRoomsServerDatabase opens a database connection. +func Open(dataSourceName string) (Database, error) { + uri, err := url.Parse(dataSourceName) if err != nil { return nil, err } - results := make([]types.Event, len(eventJSONs)) - for i, eventJSON := range eventJSONs { - result := &results[i] - result.EventNID = eventJSON.EventNID - // TODO: Use NewEventFromTrustedJSON for efficiency - result.Event, err = gomatrixserverlib.NewEventFromUntrustedJSON(eventJSON.EventJSON) - if err != nil { - return nil, err - } + switch uri.Scheme { + case "postgres": + return postgres.Open(dataSourceName) + default: + return nil, errors.New("unknown schema") } - return results, nil -} - -// AddState implements input.EventDatabase -func (d *Database) AddState( - ctx context.Context, - roomNID types.RoomNID, - stateBlockNIDs []types.StateBlockNID, - state []types.StateEntry, -) (types.StateSnapshotNID, error) { - if len(state) > 0 { - stateBlockNID, err := d.statements.selectNextStateBlockNID(ctx) - if err != nil { - return 0, err - } - if err = d.statements.bulkInsertStateData(ctx, stateBlockNID, state); err != nil { - return 0, err - } - stateBlockNIDs = append(stateBlockNIDs[:len(stateBlockNIDs):len(stateBlockNIDs)], stateBlockNID) - } - - return d.statements.insertState(ctx, roomNID, stateBlockNIDs) -} - -// SetState implements input.EventDatabase -func (d *Database) SetState( - ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, -) error { - return d.statements.updateEventState(ctx, eventNID, stateNID) -} - -// StateAtEventIDs implements input.EventDatabase -func (d *Database) StateAtEventIDs( - ctx context.Context, eventIDs []string, -) ([]types.StateAtEvent, error) { - return d.statements.bulkSelectStateAtEventByID(ctx, eventIDs) -} - -// StateBlockNIDs implements state.RoomStateDatabase -func (d *Database) StateBlockNIDs( - ctx context.Context, stateNIDs []types.StateSnapshotNID, -) ([]types.StateBlockNIDList, error) { - return d.statements.bulkSelectStateBlockNIDs(ctx, stateNIDs) -} - -// StateEntries implements state.RoomStateDatabase -func (d *Database) StateEntries( - ctx context.Context, stateBlockNIDs []types.StateBlockNID, -) ([]types.StateEntryList, error) { - return d.statements.bulkSelectStateBlockEntries(ctx, stateBlockNIDs) -} - -// SnapshotNIDFromEventID implements state.RoomStateDatabase -func (d *Database) SnapshotNIDFromEventID( - ctx context.Context, eventID string, -) (types.StateSnapshotNID, error) { - _, stateNID, err := d.statements.selectEvent(ctx, eventID) - return stateNID, err -} - -// EventIDs implements input.RoomEventDatabase -func (d *Database) EventIDs( - ctx context.Context, eventNIDs []types.EventNID, -) (map[types.EventNID]string, error) { - return d.statements.bulkSelectEventID(ctx, eventNIDs) -} - -// GetLatestEventsForUpdate implements input.EventDatabase -func (d *Database) GetLatestEventsForUpdate( - ctx context.Context, roomNID types.RoomNID, -) (types.RoomRecentEventsUpdater, error) { - txn, err := d.db.Begin() - if err != nil { - return nil, err - } - eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err := - d.statements.selectLatestEventsNIDsForUpdate(ctx, txn, roomNID) - if err != nil { - txn.Rollback() // nolint: errcheck - return nil, err - } - stateAndRefs, err := d.statements.bulkSelectStateAtEventAndReference(ctx, txn, eventNIDs) - if err != nil { - txn.Rollback() // nolint: errcheck - return nil, err - } - var lastEventIDSent string - if lastEventNIDSent != 0 { - lastEventIDSent, err = d.statements.selectEventID(ctx, txn, lastEventNIDSent) - if err != nil { - txn.Rollback() // nolint: errcheck - return nil, err - } - } - return &roomRecentEventsUpdater{ - transaction{ctx, txn}, d, roomNID, stateAndRefs, lastEventIDSent, currentStateSnapshotNID, - }, nil -} - -// GetTransactionEventID implements input.EventDatabase -func (d *Database) GetTransactionEventID( - ctx context.Context, transactionID string, - sessionID int64, userID string, -) (string, error) { - eventID, err := d.statements.selectTransactionEventID(ctx, transactionID, sessionID, userID) - if err == sql.ErrNoRows { - return "", nil - } - return eventID, err -} - -type roomRecentEventsUpdater struct { - transaction - d *Database - roomNID types.RoomNID - latestEvents []types.StateAtEventAndReference - lastEventIDSent string - currentStateSnapshotNID types.StateSnapshotNID -} - -// LatestEvents implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) LatestEvents() []types.StateAtEventAndReference { - return u.latestEvents -} - -// LastEventIDSent implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) LastEventIDSent() string { - return u.lastEventIDSent -} - -// CurrentStateSnapshotNID implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID { - return u.currentStateSnapshotNID -} - -// StorePreviousEvents implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error { - for _, ref := range previousEventReferences { - if err := u.d.statements.insertPreviousEvent(u.ctx, u.txn, ref.EventID, ref.EventSHA256, eventNID); err != nil { - return err - } - } - return nil -} - -// IsReferenced implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) { - err := u.d.statements.selectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256) - if err == nil { - return true, nil - } - if err == sql.ErrNoRows { - return false, nil - } - return false, err -} - -// SetLatestEvents implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) SetLatestEvents( - roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID, - currentStateSnapshotNID types.StateSnapshotNID, -) error { - eventNIDs := make([]types.EventNID, len(latest)) - for i := range latest { - eventNIDs[i] = latest[i].EventNID - } - return u.d.statements.updateLatestEventNIDs(u.ctx, u.txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID) -} - -// HasEventBeenSent implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) { - return u.d.statements.selectEventSentToOutput(u.ctx, u.txn, eventNID) -} - -// MarkEventAsSent implements types.RoomRecentEventsUpdater -func (u *roomRecentEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error { - return u.d.statements.updateEventSentToOutput(u.ctx, u.txn, eventNID) -} - -func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID) (types.MembershipUpdater, error) { - return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomNID, targetUserNID) -} - -// RoomNID implements query.RoomserverQueryAPIDB -func (d *Database) RoomNID(ctx context.Context, roomID string) (types.RoomNID, error) { - roomNID, err := d.statements.selectRoomNID(ctx, nil, roomID) - if err == sql.ErrNoRows { - return 0, nil - } - return roomNID, err -} - -// LatestEventIDs implements query.RoomserverQueryAPIDatabase -func (d *Database) LatestEventIDs( - ctx context.Context, roomNID types.RoomNID, -) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error) { - eventNIDs, currentStateSnapshotNID, err := d.statements.selectLatestEventNIDs(ctx, roomNID) - if err != nil { - return nil, 0, 0, err - } - references, err := d.statements.bulkSelectEventReference(ctx, eventNIDs) - if err != nil { - return nil, 0, 0, err - } - depth, err := d.statements.selectMaxEventDepth(ctx, eventNIDs) - if err != nil { - return nil, 0, 0, err - } - return references, currentStateSnapshotNID, depth, nil -} - -// GetInvitesForUser implements query.RoomserverQueryAPIDatabase -func (d *Database) GetInvitesForUser( - ctx context.Context, - roomNID types.RoomNID, - targetUserNID types.EventStateKeyNID, -) (senderUserIDs []types.EventStateKeyNID, err error) { - return d.statements.selectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID) -} - -// SetRoomAlias implements alias.RoomserverAliasAPIDB -func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error { - return d.statements.insertRoomAlias(ctx, alias, roomID, creatorUserID) -} - -// GetRoomIDForAlias implements alias.RoomserverAliasAPIDB -func (d *Database) GetRoomIDForAlias(ctx context.Context, alias string) (string, error) { - return d.statements.selectRoomIDFromAlias(ctx, alias) -} - -// GetAliasesForRoomID implements alias.RoomserverAliasAPIDB -func (d *Database) GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) { - return d.statements.selectAliasesFromRoomID(ctx, roomID) -} - -// GetCreatorIDForAlias implements alias.RoomserverAliasAPIDB -func (d *Database) GetCreatorIDForAlias( - ctx context.Context, alias string, -) (string, error) { - return d.statements.selectCreatorIDFromAlias(ctx, alias) -} - -// RemoveRoomAlias implements alias.RoomserverAliasAPIDB -func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error { - return d.statements.deleteRoomAlias(ctx, alias) -} - -// StateEntriesForTuples implements state.RoomStateDatabase -func (d *Database) StateEntriesForTuples( - ctx context.Context, - stateBlockNIDs []types.StateBlockNID, - stateKeyTuples []types.StateKeyTuple, -) ([]types.StateEntryList, error) { - return d.statements.bulkSelectFilteredStateBlockEntries( - ctx, stateBlockNIDs, stateKeyTuples, - ) -} - -// MembershipUpdater implements input.RoomEventDatabase -func (d *Database) MembershipUpdater( - ctx context.Context, roomID, targetUserID string, -) (types.MembershipUpdater, error) { - txn, err := d.db.Begin() - if err != nil { - return nil, err - } - succeeded := false - defer func() { - if !succeeded { - txn.Rollback() // nolint: errcheck - } - }() - - roomNID, err := d.assignRoomNID(ctx, txn, roomID) - if err != nil { - return nil, err - } - - targetUserNID, err := d.assignStateKeyNID(ctx, txn, targetUserID) - if err != nil { - return nil, err - } - - updater, err := d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID) - if err != nil { - return nil, err - } - - succeeded = true - return updater, nil -} - -type membershipUpdater struct { - transaction - d *Database - roomNID types.RoomNID - targetUserNID types.EventStateKeyNID - membership membershipState -} - -func (d *Database) membershipUpdaterTxn( - ctx context.Context, - txn *sql.Tx, - roomNID types.RoomNID, - targetUserNID types.EventStateKeyNID, -) (types.MembershipUpdater, error) { - - if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID); err != nil { - return nil, err - } - - membership, err := d.statements.selectMembershipForUpdate(ctx, txn, roomNID, targetUserNID) - if err != nil { - return nil, err - } - - return &membershipUpdater{ - transaction{ctx, txn}, d, roomNID, targetUserNID, membership, - }, nil -} - -// IsInvite implements types.MembershipUpdater -func (u *membershipUpdater) IsInvite() bool { - return u.membership == membershipStateInvite -} - -// IsJoin implements types.MembershipUpdater -func (u *membershipUpdater) IsJoin() bool { - return u.membership == membershipStateJoin -} - -// IsLeave implements types.MembershipUpdater -func (u *membershipUpdater) IsLeave() bool { - return u.membership == membershipStateLeaveOrBan -} - -// SetToInvite implements types.MembershipUpdater -func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, error) { - senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.Sender()) - if err != nil { - return false, err - } - inserted, err := u.d.statements.insertInviteEvent( - u.ctx, u.txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(), - ) - if err != nil { - return false, err - } - if u.membership != membershipStateInvite { - if err = u.d.statements.updateMembership( - u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateInvite, 0, - ); err != nil { - return false, err - } - } - return inserted, nil -} - -// SetToJoin implements types.MembershipUpdater -func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpdate bool) ([]string, error) { - var inviteEventIDs []string - - senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID) - if err != nil { - return nil, err - } - - // If this is a join event update, there is no invite to update - if !isUpdate { - inviteEventIDs, err = u.d.statements.updateInviteRetired( - u.ctx, u.txn, u.roomNID, u.targetUserNID, - ) - if err != nil { - return nil, err - } - } - - // Look up the NID of the new join event - nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID}) - if err != nil { - return nil, err - } - - if u.membership != membershipStateJoin || isUpdate { - if err = u.d.statements.updateMembership( - u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, - membershipStateJoin, nIDs[eventID], - ); err != nil { - return nil, err - } - } - - return inviteEventIDs, nil -} - -// SetToLeave implements types.MembershipUpdater -func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) ([]string, error) { - senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID) - if err != nil { - return nil, err - } - inviteEventIDs, err := u.d.statements.updateInviteRetired( - u.ctx, u.txn, u.roomNID, u.targetUserNID, - ) - if err != nil { - return nil, err - } - - // Look up the NID of the new leave event - nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID}) - if err != nil { - return nil, err - } - - if u.membership != membershipStateLeaveOrBan { - if err = u.d.statements.updateMembership( - u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, - membershipStateLeaveOrBan, nIDs[eventID], - ); err != nil { - return nil, err - } - } - return inviteEventIDs, nil -} - -// GetMembership implements query.RoomserverQueryAPIDB -func (d *Database) GetMembership( - ctx context.Context, roomNID types.RoomNID, requestSenderUserID string, -) (membershipEventNID types.EventNID, stillInRoom bool, err error) { - requestSenderUserNID, err := d.assignStateKeyNID(ctx, nil, requestSenderUserID) - if err != nil { - return - } - - senderMembershipEventNID, senderMembership, err := - d.statements.selectMembershipFromRoomAndTarget( - ctx, roomNID, requestSenderUserNID, - ) - if err == sql.ErrNoRows { - // The user has never been a member of that room - return 0, false, nil - } else if err != nil { - return - } - - return senderMembershipEventNID, senderMembership == membershipStateJoin, nil -} - -// GetMembershipEventNIDsForRoom implements query.RoomserverQueryAPIDB -func (d *Database) GetMembershipEventNIDsForRoom( - ctx context.Context, roomNID types.RoomNID, joinOnly bool, -) ([]types.EventNID, error) { - if joinOnly { - return d.statements.selectMembershipsFromRoomAndMembership( - ctx, roomNID, membershipStateJoin, - ) - } - - return d.statements.selectMembershipsFromRoom(ctx, roomNID) -} - -// EventsFromIDs implements query.RoomserverQueryAPIEventDB -func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) { - nidMap, err := d.EventNIDs(ctx, eventIDs) - if err != nil { - return nil, err - } - - var nids []types.EventNID - for _, nid := range nidMap { - nids = append(nids, nid) - } - - return d.Events(ctx, nids) -} - -type transaction struct { - ctx context.Context - txn *sql.Tx -} - -// Commit implements types.Transaction -func (t *transaction) Commit() error { - return t.txn.Commit() -} - -// Rollback implements types.Transaction -func (t *transaction) Rollback() error { - return t.txn.Rollback() } diff --git a/syncapi/consumers/clientapi.go b/syncapi/consumers/clientapi.go index f0db5642..ed39cd2d 100644 --- a/syncapi/consumers/clientapi.go +++ b/syncapi/consumers/clientapi.go @@ -30,7 +30,7 @@ import ( // OutputClientDataConsumer consumes events that originated in the client API server. type OutputClientDataConsumer struct { clientAPIConsumer *common.ContinualConsumer - db *storage.SyncServerDatasource + db storage.Database notifier *sync.Notifier } @@ -39,7 +39,7 @@ func NewOutputClientDataConsumer( cfg *config.Dendrite, kafkaConsumer sarama.Consumer, n *sync.Notifier, - store *storage.SyncServerDatasource, + store storage.Database, ) *OutputClientDataConsumer { consumer := common.ContinualConsumer{ diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index e4f1ab46..cde2f508 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -33,7 +33,7 @@ import ( // OutputRoomEventConsumer consumes events that originated in the room server. type OutputRoomEventConsumer struct { roomServerConsumer *common.ContinualConsumer - db *storage.SyncServerDatasource + db storage.Database notifier *sync.Notifier query api.RoomserverQueryAPI } @@ -43,7 +43,7 @@ func NewOutputRoomEventConsumer( cfg *config.Dendrite, kafkaConsumer sarama.Consumer, n *sync.Notifier, - store *storage.SyncServerDatasource, + store storage.Database, queryAPI api.RoomserverQueryAPI, ) *OutputRoomEventConsumer { diff --git a/syncapi/consumers/typingserver.go b/syncapi/consumers/typingserver.go index 5d998a18..392f7987 100644 --- a/syncapi/consumers/typingserver.go +++ b/syncapi/consumers/typingserver.go @@ -30,7 +30,7 @@ import ( // OutputTypingEventConsumer consumes events that originated in the typing server. type OutputTypingEventConsumer struct { typingConsumer *common.ContinualConsumer - db *storage.SyncServerDatasource + db storage.Database notifier *sync.Notifier } @@ -40,7 +40,7 @@ func NewOutputTypingEventConsumer( cfg *config.Dendrite, kafkaConsumer sarama.Consumer, n *sync.Notifier, - store *storage.SyncServerDatasource, + store storage.Database, ) *OutputTypingEventConsumer { consumer := common.ContinualConsumer{ diff --git a/syncapi/routing/routing.go b/syncapi/routing/routing.go index 0f5019fc..bd9389bd 100644 --- a/syncapi/routing/routing.go +++ b/syncapi/routing/routing.go @@ -34,7 +34,7 @@ const pathPrefixR0 = "/_matrix/client/r0" // Due to Setup being used to call many other functions, a gocyclo nolint is // applied: // nolint: gocyclo -func Setup(apiMux *mux.Router, srp *sync.RequestPool, syncDB *storage.SyncServerDatasource, deviceDB *devices.Database) { +func Setup(apiMux *mux.Router, srp *sync.RequestPool, syncDB storage.Database, deviceDB *devices.Database) { r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter() authData := auth.Data{ diff --git a/syncapi/routing/state.go b/syncapi/routing/state.go index 87a93d19..61eaf218 100644 --- a/syncapi/routing/state.go +++ b/syncapi/routing/state.go @@ -40,7 +40,7 @@ type stateEventInStateResp struct { // TODO: Check if the user is in the room. If not, check if the room's history // is publicly visible. Current behaviour is returning an empty array if the // user cannot see the room's history. -func OnIncomingStateRequest(req *http.Request, db *storage.SyncServerDatasource, roomID string) util.JSONResponse { +func OnIncomingStateRequest(req *http.Request, db storage.Database, roomID string) util.JSONResponse { // TODO(#287): Auth request and handle the case where the user has left (where // we should return the state at the poin they left) @@ -87,7 +87,7 @@ func OnIncomingStateRequest(req *http.Request, db *storage.SyncServerDatasource, // /rooms/{roomID}/state/{type}/{statekey} request. It will look in current // state to see if there is an event with that type and state key, if there // is then (by default) we return the content, otherwise a 404. -func OnIncomingStateTypeRequest(req *http.Request, db *storage.SyncServerDatasource, roomID string, evType, stateKey string) util.JSONResponse { +func OnIncomingStateTypeRequest(req *http.Request, db storage.Database, roomID string, evType, stateKey string) util.JSONResponse { // TODO(#287): Auth request and handle the case where the user has left (where // we should return the state at the poin they left) diff --git a/syncapi/storage/account_data_table.go b/syncapi/storage/postgres/account_data_table.go similarity index 97% rename from syncapi/storage/account_data_table.go rename to syncapi/storage/postgres/account_data_table.go index 7b4803e3..33cfffad 100644 --- a/syncapi/storage/account_data_table.go +++ b/syncapi/storage/postgres/account_data_table.go @@ -1,4 +1,5 @@ -// Copyright 2017 Vector Creations Ltd +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,7 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package storage +package postgres import ( "context" diff --git a/syncapi/storage/current_room_state_table.go b/syncapi/storage/postgres/current_room_state_table.go similarity index 98% rename from syncapi/storage/current_room_state_table.go rename to syncapi/storage/postgres/current_room_state_table.go index 1ab70879..dbfa111b 100644 --- a/syncapi/storage/current_room_state_table.go +++ b/syncapi/storage/postgres/current_room_state_table.go @@ -1,4 +1,5 @@ -// Copyright 2017 Vector Creations Ltd +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,7 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package storage +package postgres import ( "context" diff --git a/syncapi/storage/filtering.go b/syncapi/storage/postgres/filtering.go similarity index 98% rename from syncapi/storage/filtering.go rename to syncapi/storage/postgres/filtering.go index 27b0b888..dcc42136 100644 --- a/syncapi/storage/filtering.go +++ b/syncapi/storage/postgres/filtering.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package storage +package postgres import ( "strings" diff --git a/syncapi/storage/invites_table.go b/syncapi/storage/postgres/invites_table.go similarity index 84% rename from syncapi/storage/invites_table.go rename to syncapi/storage/postgres/invites_table.go index 9f52087f..ced4bfc4 100644 --- a/syncapi/storage/invites_table.go +++ b/syncapi/storage/postgres/invites_table.go @@ -1,4 +1,19 @@ -package storage +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres import ( "context" diff --git a/syncapi/storage/output_room_events_table.go b/syncapi/storage/postgres/output_room_events_table.go similarity index 99% rename from syncapi/storage/output_room_events_table.go rename to syncapi/storage/postgres/output_room_events_table.go index e1803a17..3927f0c3 100644 --- a/syncapi/storage/output_room_events_table.go +++ b/syncapi/storage/postgres/output_room_events_table.go @@ -1,4 +1,5 @@ -// Copyright 2017 Vector Creations Ltd +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,7 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package storage +package postgres import ( "context" diff --git a/syncapi/storage/syncserver.go b/syncapi/storage/postgres/syncserver.go similarity index 99% rename from syncapi/storage/syncserver.go rename to syncapi/storage/postgres/syncserver.go index cda44d2e..fc7b4e40 100644 --- a/syncapi/storage/syncserver.go +++ b/syncapi/storage/postgres/syncserver.go @@ -1,4 +1,5 @@ -// Copyright 2017 Vector Creations Ltd +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,7 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package storage +package postgres import ( "context" diff --git a/syncapi/storage/storage.go b/syncapi/storage/storage.go new file mode 100644 index 00000000..eedb42f0 --- /dev/null +++ b/syncapi/storage/storage.go @@ -0,0 +1,63 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package storage + +import ( + "context" + "errors" + "net/url" + "time" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/syncapi/storage/postgres" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/dendrite/typingserver/cache" + "github.com/matrix-org/gomatrixserverlib" +) + +type Database interface { + common.PartitionStorer + AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) + Events(ctx context.Context, eventIDs []string) ([]gomatrixserverlib.Event, error) + WriteEvent(ctx context.Context, ev *gomatrixserverlib.Event, addStateEvents []gomatrixserverlib.Event, addStateEventIDs, removeStateEventIDs []string, transactionID *api.TransactionID) (pduPosition int64, returnErr error) + GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.Event, error) + GetStateEventsForRoom(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.FilterPart) (stateEvents []gomatrixserverlib.Event, err error) + SyncPosition(ctx context.Context) (types.SyncPosition, error) + IncrementalSync(ctx context.Context, device authtypes.Device, fromPos, toPos types.SyncPosition, numRecentEventsPerRoom int, wantFullState bool) (*types.Response, error) + CompleteSync(ctx context.Context, userID string, numRecentEventsPerRoom int) (*types.Response, error) + GetAccountDataInRange(ctx context.Context, userID string, oldPos, newPos int64, accountDataFilterPart *gomatrixserverlib.FilterPart) (map[string][]string, error) + UpsertAccountData(ctx context.Context, userID, roomID, dataType string) (int64, error) + AddInviteEvent(ctx context.Context, inviteEvent gomatrixserverlib.Event) (int64, error) + RetireInviteEvent(ctx context.Context, inviteEventID string) error + SetTypingTimeoutCallback(fn cache.TimeoutCallbackFn) + AddTypingUser(userID, roomID string, expireTime *time.Time) int64 + RemoveTypingUser(userID, roomID string) int64 +} + +// NewPublicRoomsServerDatabase opens a database connection. +func NewSyncServerDatasource(dataSourceName string) (Database, error) { + uri, err := url.Parse(dataSourceName) + if err != nil { + return nil, err + } + switch uri.Scheme { + case "postgres": + return postgres.NewSyncServerDatasource(dataSourceName) + default: + return nil, errors.New("unknown schema") + } +} diff --git a/syncapi/sync/notifier.go b/syncapi/sync/notifier.go index 15d6b070..548a17ac 100644 --- a/syncapi/sync/notifier.go +++ b/syncapi/sync/notifier.go @@ -141,7 +141,7 @@ func (n *Notifier) GetListener(req syncRequest) UserStreamListener { } // Load the membership states required to notify users correctly. -func (n *Notifier) Load(ctx context.Context, db *storage.SyncServerDatasource) error { +func (n *Notifier) Load(ctx context.Context, db storage.Database) error { roomToUsers, err := db.AllJoinedUsersInRooms(ctx) if err != nil { return err diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index 94a36900..82505e68 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -31,13 +31,13 @@ import ( // RequestPool manages HTTP long-poll connections for /sync type RequestPool struct { - db *storage.SyncServerDatasource + db storage.Database accountDB *accounts.Database notifier *Notifier } // NewRequestPool makes a new RequestPool -func NewRequestPool(db *storage.SyncServerDatasource, n *Notifier, adb *accounts.Database) *RequestPool { +func NewRequestPool(db storage.Database, n *Notifier, adb *accounts.Database) *RequestPool { return &RequestPool{db, adb, n} }