From 9af2f5f1f253a821cec660ef477c274d5cd13953 Mon Sep 17 00:00:00 2001 From: Kegsay Date: Thu, 27 Aug 2020 18:53:40 +0100 Subject: [PATCH] Remove device DB from clientapi (#1352) * Remove device DB from clientapi * Remove device DB from startup configuration It's all an impl detail now in user API --- build/gobind/monolith.go | 4 +- clientapi/clientapi.go | 4 +- clientapi/routing/device.go | 50 +++++++++---------- clientapi/routing/logout.go | 33 ++++++------ clientapi/routing/routing.go | 10 ++-- cmd/dendrite-client-api-server/main.go | 3 +- cmd/dendrite-demo-libp2p/main.go | 4 +- cmd/dendrite-demo-yggdrasil/main.go | 4 +- cmd/dendrite-monolith-server/main.go | 4 +- cmd/dendrite-user-api-server/main.go | 3 +- cmd/dendritejs/main.go | 4 +- internal/setup/base.go | 12 ----- internal/setup/monolith.go | 4 +- userapi/api/api.go | 5 +- userapi/internal/api.go | 13 ++++- userapi/storage/devices/interface.go | 3 +- .../storage/devices/postgres/devices_table.go | 5 +- userapi/storage/devices/postgres/storage.go | 11 ++-- .../storage/devices/sqlite3/devices_table.go | 5 +- userapi/storage/devices/sqlite3/storage.go | 11 ++-- userapi/userapi.go | 14 ++++-- userapi/userapi_test.go | 21 ++++---- 22 files changed, 109 insertions(+), 118 deletions(-) diff --git a/build/gobind/monolith.go b/build/gobind/monolith.go index 2ea09f63..59535c7b 100644 --- a/build/gobind/monolith.go +++ b/build/gobind/monolith.go @@ -111,13 +111,12 @@ func (m *DendriteMonolith) Start() { defer base.Close() // nolint: errcheck accountDB := base.CreateAccountsDB() - deviceDB := base.CreateDeviceDB() federation := ygg.CreateFederationClient(base) serverKeyAPI := &signing.YggdrasilKeys{} keyRing := serverKeyAPI.KeyRing() keyAPI := keyserver.NewInternalAPI(&base.Cfg.KeyServer, federation, base.KafkaProducer) - userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Global.ServerName, cfg.Derived.ApplicationServices, keyAPI) + userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI) keyAPI.SetUserAPI(userAPI) rsAPI := roomserver.NewInternalAPI( @@ -153,7 +152,6 @@ func (m *DendriteMonolith) Start() { monolith := setup.Monolith{ Config: base.Cfg, AccountDB: accountDB, - DeviceDB: deviceDB, Client: ygg.CreateClient(base), FedClient: federation, KeyRing: keyRing, diff --git a/clientapi/clientapi.go b/clientapi/clientapi.go index 1a4307c1..fe6789fc 100644 --- a/clientapi/clientapi.go +++ b/clientapi/clientapi.go @@ -30,7 +30,6 @@ import ( roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/accounts" - "github.com/matrix-org/dendrite/userapi/storage/devices" "github.com/matrix-org/gomatrixserverlib" ) @@ -39,7 +38,6 @@ func AddPublicRoutes( router *mux.Router, cfg *config.ClientAPI, producer sarama.SyncProducer, - deviceDB devices.Database, accountsDB accounts.Database, federation *gomatrixserverlib.FederationClient, rsAPI roomserverAPI.RoomserverInternalAPI, @@ -59,7 +57,7 @@ func AddPublicRoutes( routing.Setup( router, cfg, eduInputAPI, rsAPI, asAPI, - accountsDB, deviceDB, userAPI, federation, + accountsDB, userAPI, federation, syncProducer, transactionsCache, fsAPI, stateAPI, keyAPI, extRoomsProvider, ) } diff --git a/clientapi/routing/device.go b/clientapi/routing/device.go index d0b3bdbe..56886d57 100644 --- a/clientapi/routing/device.go +++ b/clientapi/routing/device.go @@ -15,7 +15,6 @@ package routing import ( - "database/sql" "encoding/json" "io/ioutil" "net/http" @@ -23,7 +22,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/devices" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -50,57 +49,56 @@ type devicesDeleteJSON struct { // GetDeviceByID handles /devices/{deviceID} func GetDeviceByID( - req *http.Request, deviceDB devices.Database, device *api.Device, + req *http.Request, userAPI userapi.UserInternalAPI, device *api.Device, deviceID string, ) util.JSONResponse { - localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + var queryRes userapi.QueryDevicesResponse + err := userAPI.QueryDevices(req.Context(), &userapi.QueryDevicesRequest{ + UserID: device.UserID, + }, &queryRes) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") + util.GetLogger(req.Context()).WithError(err).Error("QueryDevices failed") return jsonerror.InternalServerError() } - - ctx := req.Context() - dev, err := deviceDB.GetDeviceByID(ctx, localpart, deviceID) - if err == sql.ErrNoRows { + var targetDevice *userapi.Device + for _, device := range queryRes.Devices { + if device.ID == deviceID { + targetDevice = &device + break + } + } + if targetDevice == nil { return util.JSONResponse{ Code: http.StatusNotFound, JSON: jsonerror.NotFound("Unknown device"), } - } else if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("deviceDB.GetDeviceByID failed") - return jsonerror.InternalServerError() } return util.JSONResponse{ Code: http.StatusOK, JSON: deviceJSON{ - DeviceID: dev.ID, - DisplayName: dev.DisplayName, + DeviceID: targetDevice.ID, + DisplayName: targetDevice.DisplayName, }, } } // GetDevicesByLocalpart handles /devices func GetDevicesByLocalpart( - req *http.Request, deviceDB devices.Database, device *api.Device, + req *http.Request, userAPI userapi.UserInternalAPI, device *api.Device, ) util.JSONResponse { - localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + var queryRes userapi.QueryDevicesResponse + err := userAPI.QueryDevices(req.Context(), &userapi.QueryDevicesRequest{ + UserID: device.UserID, + }, &queryRes) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return jsonerror.InternalServerError() - } - - ctx := req.Context() - deviceList, err := deviceDB.GetDevicesByLocalpart(ctx, localpart) - - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("deviceDB.GetDevicesByLocalpart failed") + util.GetLogger(req.Context()).WithError(err).Error("QueryDevices failed") return jsonerror.InternalServerError() } res := devicesJSON{} - for _, dev := range deviceList { + for _, dev := range queryRes.Devices { res.Devices = append(res.Devices, deviceJSON{ DeviceID: dev.ID, DisplayName: dev.DisplayName, diff --git a/clientapi/routing/logout.go b/clientapi/routing/logout.go index 3ce47169..cb300e9f 100644 --- a/clientapi/routing/logout.go +++ b/clientapi/routing/logout.go @@ -19,23 +19,21 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/dendrite/userapi/storage/devices" - "github.com/matrix-org/gomatrixserverlib" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/util" ) // Logout handles POST /logout func Logout( - req *http.Request, deviceDB devices.Database, device *api.Device, + req *http.Request, userAPI userapi.UserInternalAPI, device *api.Device, ) util.JSONResponse { - localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + var performRes userapi.PerformDeviceDeletionResponse + err := userAPI.PerformDeviceDeletion(req.Context(), &userapi.PerformDeviceDeletionRequest{ + UserID: device.UserID, + DeviceIDs: []string{device.ID}, + }, &performRes) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return jsonerror.InternalServerError() - } - - if err := deviceDB.RemoveDevice(req.Context(), device.ID, localpart); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("deviceDB.RemoveDevice failed") + util.GetLogger(req.Context()).WithError(err).Error("PerformDeviceDeletion failed") return jsonerror.InternalServerError() } @@ -47,16 +45,15 @@ func Logout( // LogoutAll handles POST /logout/all func LogoutAll( - req *http.Request, deviceDB devices.Database, device *api.Device, + req *http.Request, userAPI userapi.UserInternalAPI, device *api.Device, ) util.JSONResponse { - localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + var performRes userapi.PerformDeviceDeletionResponse + err := userAPI.PerformDeviceDeletion(req.Context(), &userapi.PerformDeviceDeletionRequest{ + UserID: device.UserID, + DeviceIDs: nil, + }, &performRes) if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return jsonerror.InternalServerError() - } - - if err := deviceDB.RemoveAllDevices(req.Context(), localpart); err != nil { - util.GetLogger(req.Context()).WithError(err).Error("deviceDB.RemoveAllDevices failed") + util.GetLogger(req.Context()).WithError(err).Error("PerformDeviceDeletion failed") return jsonerror.InternalServerError() } diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index c259e529..f2494dc7 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -35,7 +35,6 @@ import ( roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/accounts" - "github.com/matrix-org/dendrite/userapi/storage/devices" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -52,7 +51,6 @@ func Setup( rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI, accountDB accounts.Database, - deviceDB devices.Database, userAPI userapi.UserInternalAPI, federation *gomatrixserverlib.FederationClient, syncProducer *producers.SyncAPIProducer, @@ -322,13 +320,13 @@ func Setup( r0mux.Handle("/logout", httputil.MakeAuthAPI("logout", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return Logout(req, deviceDB, device) + return Logout(req, userAPI, device) }), ).Methods(http.MethodPost, http.MethodOptions) r0mux.Handle("/logout/all", httputil.MakeAuthAPI("logout", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return LogoutAll(req, deviceDB, device) + return LogoutAll(req, userAPI, device) }), ).Methods(http.MethodPost, http.MethodOptions) @@ -632,7 +630,7 @@ func Setup( r0mux.Handle("/devices", httputil.MakeAuthAPI("get_devices", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return GetDevicesByLocalpart(req, deviceDB, device) + return GetDevicesByLocalpart(req, userAPI, device) }), ).Methods(http.MethodGet, http.MethodOptions) @@ -642,7 +640,7 @@ func Setup( if err != nil { return util.ErrorResponse(err) } - return GetDeviceByID(req, deviceDB, device, vars["deviceID"]) + return GetDeviceByID(req, userAPI, device, vars["deviceID"]) }), ).Methods(http.MethodGet, http.MethodOptions) diff --git a/cmd/dendrite-client-api-server/main.go b/cmd/dendrite-client-api-server/main.go index 4961b34e..35dbb774 100644 --- a/cmd/dendrite-client-api-server/main.go +++ b/cmd/dendrite-client-api-server/main.go @@ -27,7 +27,6 @@ func main() { defer base.Close() // nolint: errcheck accountDB := base.CreateAccountsDB() - deviceDB := base.CreateDeviceDB() federation := base.CreateFederationClient() asQuery := base.AppserviceHTTPClient() @@ -39,7 +38,7 @@ func main() { keyAPI := base.KeyServerHTTPClient() clientapi.AddPublicRoutes( - base.PublicClientAPIMux, &base.Cfg.ClientAPI, base.KafkaProducer, deviceDB, accountDB, federation, + base.PublicClientAPIMux, &base.Cfg.ClientAPI, base.KafkaProducer, accountDB, federation, rsAPI, eduInputAPI, asQuery, stateAPI, transactions.New(), fsAPI, userAPI, keyAPI, nil, ) diff --git a/cmd/dendrite-demo-libp2p/main.go b/cmd/dendrite-demo-libp2p/main.go index 4934fe5f..e2d23e89 100644 --- a/cmd/dendrite-demo-libp2p/main.go +++ b/cmd/dendrite-demo-libp2p/main.go @@ -140,10 +140,9 @@ func main() { defer base.Base.Close() // nolint: errcheck accountDB := base.Base.CreateAccountsDB() - deviceDB := base.Base.CreateDeviceDB() federation := createFederationClient(base) keyAPI := keyserver.NewInternalAPI(&base.Base.Cfg.KeyServer, federation, base.Base.KafkaProducer) - userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Global.ServerName, nil, keyAPI) + userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, nil, keyAPI) keyAPI.SetUserAPI(userAPI) serverKeyAPI := serverkeyapi.NewInternalAPI( @@ -175,7 +174,6 @@ func main() { monolith := setup.Monolith{ Config: base.Base.Cfg, AccountDB: accountDB, - DeviceDB: deviceDB, Client: createClient(base), FedClient: federation, KeyRing: keyRing, diff --git a/cmd/dendrite-demo-yggdrasil/main.go b/cmd/dendrite-demo-yggdrasil/main.go index e8745b3e..26999ebe 100644 --- a/cmd/dendrite-demo-yggdrasil/main.go +++ b/cmd/dendrite-demo-yggdrasil/main.go @@ -94,14 +94,13 @@ func main() { defer base.Close() // nolint: errcheck accountDB := base.CreateAccountsDB() - deviceDB := base.CreateDeviceDB() federation := ygg.CreateFederationClient(base) serverKeyAPI := &signing.YggdrasilKeys{} keyRing := serverKeyAPI.KeyRing() keyAPI := keyserver.NewInternalAPI(&base.Cfg.KeyServer, federation, base.KafkaProducer) - userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Global.ServerName, nil, keyAPI) + userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, nil, keyAPI) keyAPI.SetUserAPI(userAPI) rsComponent := roomserver.NewInternalAPI( @@ -136,7 +135,6 @@ func main() { monolith := setup.Monolith{ Config: base.Cfg, AccountDB: accountDB, - DeviceDB: deviceDB, Client: ygg.CreateClient(base), FedClient: federation, KeyRing: keyRing, diff --git a/cmd/dendrite-monolith-server/main.go b/cmd/dendrite-monolith-server/main.go index e2d2de48..81511746 100644 --- a/cmd/dendrite-monolith-server/main.go +++ b/cmd/dendrite-monolith-server/main.go @@ -69,7 +69,6 @@ func main() { defer base.Close() // nolint: errcheck accountDB := base.CreateAccountsDB() - deviceDB := base.CreateDeviceDB() federation := base.CreateFederationClient() serverKeyAPI := serverkeyapi.NewInternalAPI( @@ -110,7 +109,7 @@ func main() { rsImpl.SetFederationSenderAPI(fsAPI) keyAPI := keyserver.NewInternalAPI(&base.Cfg.KeyServer, fsAPI, base.KafkaProducer) - userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Global.ServerName, cfg.Derived.ApplicationServices, keyAPI) + userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI) keyAPI.SetUserAPI(userAPI) eduInputAPI := eduserver.NewInternalAPI( @@ -130,7 +129,6 @@ func main() { monolith := setup.Monolith{ Config: base.Cfg, AccountDB: accountDB, - DeviceDB: deviceDB, Client: gomatrixserverlib.NewClient(cfg.FederationSender.DisableTLSValidation), FedClient: federation, KeyRing: keyRing, diff --git a/cmd/dendrite-user-api-server/main.go b/cmd/dendrite-user-api-server/main.go index c21525e6..c8e2e2a3 100644 --- a/cmd/dendrite-user-api-server/main.go +++ b/cmd/dendrite-user-api-server/main.go @@ -25,9 +25,8 @@ func main() { defer base.Close() // nolint: errcheck accountDB := base.CreateAccountsDB() - deviceDB := base.CreateDeviceDB() - userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Global.ServerName, cfg.Derived.ApplicationServices, base.KeyServerHTTPClient()) + userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, base.KeyServerHTTPClient()) userapi.AddInternalRoutes(base.InternalAPIMux, userAPI) diff --git a/cmd/dendritejs/main.go b/cmd/dendritejs/main.go index ceb252d8..c95eb3fc 100644 --- a/cmd/dendritejs/main.go +++ b/cmd/dendritejs/main.go @@ -191,10 +191,9 @@ func main() { defer base.Close() // nolint: errcheck accountDB := base.CreateAccountsDB() - deviceDB := base.CreateDeviceDB() federation := createFederationClient(cfg, node) keyAPI := keyserver.NewInternalAPI(&base.Cfg.KeyServer, federation, base.KafkaProducer) - userAPI := userapi.NewInternalAPI(accountDB, deviceDB, cfg.Global.ServerName, nil, keyAPI) + userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, nil, keyAPI) keyAPI.SetUserAPI(userAPI) fetcher := &libp2pKeyFetcher{} @@ -218,7 +217,6 @@ func main() { monolith := setup.Monolith{ Config: base.Cfg, AccountDB: accountDB, - DeviceDB: deviceDB, Client: createClient(node), FedClient: federation, KeyRing: &keyRing, diff --git a/internal/setup/base.go b/internal/setup/base.go index fc408311..7bf06e74 100644 --- a/internal/setup/base.go +++ b/internal/setup/base.go @@ -32,7 +32,6 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/userapi/storage/accounts" - "github.com/matrix-org/dendrite/userapi/storage/devices" "github.com/Shopify/sarama" "github.com/gorilla/mux" @@ -237,17 +236,6 @@ func (b *BaseDendrite) KeyServerHTTPClient() keyserverAPI.KeyInternalAPI { return f } -// CreateDeviceDB creates a new instance of the device database. Should only be -// called once per component. -func (b *BaseDendrite) CreateDeviceDB() devices.Database { - db, err := devices.NewDatabase(&b.Cfg.UserAPI.DeviceDatabase, b.Cfg.Global.ServerName) - if err != nil { - logrus.WithError(err).Panicf("failed to connect to devices db") - } - - return db -} - // CreateAccountsDB creates a new instance of the accounts database. Should only // be called once per component. func (b *BaseDendrite) CreateAccountsDB() accounts.Database { diff --git a/internal/setup/monolith.go b/internal/setup/monolith.go index 5e6c8fcf..f79ebae4 100644 --- a/internal/setup/monolith.go +++ b/internal/setup/monolith.go @@ -33,7 +33,6 @@ import ( "github.com/matrix-org/dendrite/syncapi" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/accounts" - "github.com/matrix-org/dendrite/userapi/storage/devices" "github.com/matrix-org/gomatrixserverlib" ) @@ -41,7 +40,6 @@ import ( // all components of Dendrite, for use in monolith mode. type Monolith struct { Config *config.Dendrite - DeviceDB devices.Database AccountDB accounts.Database KeyRing *gomatrixserverlib.KeyRing Client *gomatrixserverlib.Client @@ -65,7 +63,7 @@ type Monolith struct { // AddAllPublicRoutes attaches all public paths to the given router func (m *Monolith) AddAllPublicRoutes(csMux, ssMux, keyMux, mediaMux *mux.Router) { clientapi.AddPublicRoutes( - csMux, &m.Config.ClientAPI, m.KafkaProducer, m.DeviceDB, m.AccountDB, + csMux, &m.Config.ClientAPI, m.KafkaProducer, m.AccountDB, m.FedClient, m.RoomserverAPI, m.EDUInternalAPI, m.AppserviceAPI, m.StateAPI, transactions.New(), m.FederationSenderAPI, m.UserAPI, m.KeyAPI, m.ExtPublicRoomsProvider, diff --git a/userapi/api/api.go b/userapi/api/api.go index 84338dbf..e6d05c33 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -61,7 +61,7 @@ type PerformDeviceUpdateResponse struct { type PerformDeviceDeletionRequest struct { UserID string - // The devices to delete + // The devices to delete. An empty slice means delete all devices. DeviceIDs []string } @@ -192,8 +192,7 @@ type Device struct { // The unique ID of the session identified by the access token. // Can be used as a secure substitution in places where data needs to be // associated with access tokens. - SessionID int64 - // TODO: display name, last used timestamp, keys, etc + SessionID int64 DisplayName string } diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 05cecc1b..b97f148e 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -123,12 +123,21 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe if domain != a.ServerName { return fmt.Errorf("cannot PerformDeviceDeletion of remote users: got %s want %s", domain, a.ServerName) } - err = a.DeviceDB.RemoveDevices(ctx, local, req.DeviceIDs) + deletedDeviceIDs := req.DeviceIDs + if len(req.DeviceIDs) == 0 { + var devices []api.Device + devices, err = a.DeviceDB.RemoveAllDevices(ctx, local) + for _, d := range devices { + deletedDeviceIDs = append(deletedDeviceIDs, d.ID) + } + } else { + err = a.DeviceDB.RemoveDevices(ctx, local, req.DeviceIDs) + } if err != nil { return err } // create empty device keys and upload them to delete what was once there and trigger device list changes - return a.deviceListUpdate(req.UserID, req.DeviceIDs) + return a.deviceListUpdate(req.UserID, deletedDeviceIDs) } func (a *UserInternalAPI) deviceListUpdate(userID string, deviceIDs []string) error { diff --git a/userapi/storage/devices/interface.go b/userapi/storage/devices/interface.go index 3c9ec934..9b4261c9 100644 --- a/userapi/storage/devices/interface.go +++ b/userapi/storage/devices/interface.go @@ -35,5 +35,6 @@ type Database interface { UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error RemoveDevice(ctx context.Context, deviceID, localpart string) error RemoveDevices(ctx context.Context, localpart string, devices []string) error - RemoveAllDevices(ctx context.Context, localpart string) error + // RemoveAllDevices deleted all devices for this user. Returns the devices deleted. + RemoveAllDevices(ctx context.Context, localpart string) (devices []api.Device, err error) } diff --git a/userapi/storage/devices/postgres/devices_table.go b/userapi/storage/devices/postgres/devices_table.go index 03bf7c72..282466f8 100644 --- a/userapi/storage/devices/postgres/devices_table.go +++ b/userapi/storage/devices/postgres/devices_table.go @@ -251,11 +251,10 @@ func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []s } func (s *devicesStatements) selectDevicesByLocalpart( - ctx context.Context, localpart string, + ctx context.Context, txn *sql.Tx, localpart string, ) ([]api.Device, error) { devices := []api.Device{} - - rows, err := s.selectDevicesByLocalpartStmt.QueryContext(ctx, localpart) + rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart) if err != nil { return devices, err diff --git a/userapi/storage/devices/postgres/storage.go b/userapi/storage/devices/postgres/storage.go index 4a7c7f97..04dae986 100644 --- a/userapi/storage/devices/postgres/storage.go +++ b/userapi/storage/devices/postgres/storage.go @@ -68,7 +68,7 @@ func (d *Database) GetDeviceByID( func (d *Database) GetDevicesByLocalpart( ctx context.Context, localpart string, ) ([]api.Device, error) { - return d.devices.selectDevicesByLocalpart(ctx, localpart) + return d.devices.selectDevicesByLocalpart(ctx, nil, localpart) } func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { @@ -176,11 +176,16 @@ func (d *Database) RemoveDevices( // If something went wrong during the deletion, it will return the SQL error. func (d *Database) RemoveAllDevices( ctx context.Context, localpart string, -) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { +) (devices []api.Device, err error) { + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart) + if err != nil { + return err + } if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows { return err } return nil }) + return } diff --git a/userapi/storage/devices/sqlite3/devices_table.go b/userapi/storage/devices/sqlite3/devices_table.go index c93e8b77..ecf43524 100644 --- a/userapi/storage/devices/sqlite3/devices_table.go +++ b/userapi/storage/devices/sqlite3/devices_table.go @@ -231,11 +231,10 @@ func (s *devicesStatements) selectDeviceByID( } func (s *devicesStatements) selectDevicesByLocalpart( - ctx context.Context, localpart string, + ctx context.Context, txn *sql.Tx, localpart string, ) ([]api.Device, error) { devices := []api.Device{} - - rows, err := s.selectDevicesByLocalpartStmt.QueryContext(ctx, localpart) + rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart) if err != nil { return devices, err diff --git a/userapi/storage/devices/sqlite3/storage.go b/userapi/storage/devices/sqlite3/storage.go index 4f426c6e..f775fb66 100644 --- a/userapi/storage/devices/sqlite3/storage.go +++ b/userapi/storage/devices/sqlite3/storage.go @@ -72,7 +72,7 @@ func (d *Database) GetDeviceByID( func (d *Database) GetDevicesByLocalpart( ctx context.Context, localpart string, ) ([]api.Device, error) { - return d.devices.selectDevicesByLocalpart(ctx, localpart) + return d.devices.selectDevicesByLocalpart(ctx, nil, localpart) } func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { @@ -180,11 +180,16 @@ func (d *Database) RemoveDevices( // If something went wrong during the deletion, it will return the SQL error. func (d *Database) RemoveAllDevices( ctx context.Context, localpart string, -) error { - return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { +) (devices []api.Device, err error) { + err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart) + if err != nil { + return err + } if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows { return err } return nil }) + return } diff --git a/userapi/userapi.go b/userapi/userapi.go index c4ab90ba..13249142 100644 --- a/userapi/userapi.go +++ b/userapi/userapi.go @@ -23,7 +23,7 @@ import ( "github.com/matrix-org/dendrite/userapi/inthttp" "github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/dendrite/userapi/storage/devices" - "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" ) // AddInternalRoutes registers HTTP handlers for the internal API. Invokes functions @@ -34,13 +34,19 @@ func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) { // NewInternalAPI returns a concerete implementation of the internal API. Callers // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. -func NewInternalAPI(accountDB accounts.Database, deviceDB devices.Database, - serverName gomatrixserverlib.ServerName, appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI) api.UserInternalAPI { +func NewInternalAPI( + accountDB accounts.Database, cfg *config.UserAPI, appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI, +) api.UserInternalAPI { + + deviceDB, err := devices.NewDatabase(&cfg.DeviceDatabase, cfg.Matrix.ServerName) + if err != nil { + logrus.WithError(err).Panicf("failed to connect to device db") + } return &internal.UserInternalAPI{ AccountDB: accountDB, DeviceDB: deviceDB, - ServerName: serverName, + ServerName: cfg.Matrix.ServerName, AppServices: appServices, KeyAPI: keyAPI, } diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 548148f2..3fc97d06 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -15,7 +15,6 @@ import ( "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/inthttp" "github.com/matrix-org/dendrite/userapi/storage/accounts" - "github.com/matrix-org/dendrite/userapi/storage/devices" "github.com/matrix-org/gomatrixserverlib" ) @@ -23,27 +22,31 @@ const ( serverName = gomatrixserverlib.ServerName("example.com") ) -func MustMakeInternalAPI(t *testing.T) (api.UserInternalAPI, accounts.Database, devices.Database) { +func MustMakeInternalAPI(t *testing.T) (api.UserInternalAPI, accounts.Database) { accountDB, err := accounts.NewDatabase(&config.DatabaseOptions{ ConnectionString: "file::memory:", }, serverName) if err != nil { t.Fatalf("failed to create account DB: %s", err) } - deviceDB, err := devices.NewDatabase(&config.DatabaseOptions{ - ConnectionString: "file::memory:", - }, serverName) - if err != nil { - t.Fatalf("failed to create device DB: %s", err) + cfg := &config.UserAPI{ + DeviceDatabase: config.DatabaseOptions{ + ConnectionString: "file::memory:", + MaxOpenConnections: 1, + MaxIdleConnections: 1, + }, + Matrix: &config.Global{ + ServerName: serverName, + }, } - return userapi.NewInternalAPI(accountDB, deviceDB, serverName, nil, nil), accountDB, deviceDB + return userapi.NewInternalAPI(accountDB, cfg, nil, nil), accountDB } func TestQueryProfile(t *testing.T) { aliceAvatarURL := "mxc://example.com/alice" aliceDisplayName := "Alice" - userAPI, accountDB, _ := MustMakeInternalAPI(t) + userAPI, accountDB := MustMakeInternalAPI(t) _, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "") if err != nil { t.Fatalf("failed to make account: %s", err)