From 5076925c184998414c3691e97fc21b554abf4a55 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Fri, 4 Sep 2020 15:16:13 +0100 Subject: [PATCH 1/2] Password changes (#1397) * User API support for password changes * Password changes in client API * Update sytest-whitelist * Remove debug logging * Default logout_devices to true * Fix deleting devices by local part --- clientapi/auth/authtypes/logintypes.go | 1 + clientapi/routing/password.go | 127 ++++++++++++++++++ clientapi/routing/routing.go | 9 ++ sytest-whitelist | 5 + userapi/api/api.go | 17 +++ userapi/internal/api.go | 11 +- userapi/inthttp/client.go | 13 ++ userapi/inthttp/server.go | 13 ++ userapi/storage/accounts/interface.go | 1 + .../accounts/postgres/accounts_table.go | 16 ++- userapi/storage/accounts/postgres/storage.go | 11 ++ .../accounts/sqlite3/accounts_table.go | 16 ++- userapi/storage/accounts/sqlite3/storage.go | 12 ++ userapi/storage/devices/interface.go | 2 +- .../storage/devices/postgres/devices_table.go | 12 +- userapi/storage/devices/postgres/storage.go | 8 +- .../storage/devices/sqlite3/devices_table.go | 12 +- userapi/storage/devices/sqlite3/storage.go | 8 +- 18 files changed, 268 insertions(+), 26 deletions(-) create mode 100644 clientapi/routing/password.go diff --git a/clientapi/auth/authtypes/logintypes.go b/clientapi/auth/authtypes/logintypes.go index 087e4504..da032425 100644 --- a/clientapi/auth/authtypes/logintypes.go +++ b/clientapi/auth/authtypes/logintypes.go @@ -5,6 +5,7 @@ type LoginType string // The relevant login types implemented in Dendrite const ( + LoginTypePassword = "m.login.password" LoginTypeDummy = "m.login.dummy" LoginTypeSharedSecret = "org.matrix.login.shared_secret" LoginTypeRecaptcha = "m.login.recaptcha" diff --git a/clientapi/routing/password.go b/clientapi/routing/password.go new file mode 100644 index 00000000..8b81b9f0 --- /dev/null +++ b/clientapi/routing/password.go @@ -0,0 +1,127 @@ +package routing + +import ( + "net/http" + + "github.com/matrix-org/dendrite/clientapi/auth" + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/clientapi/httputil" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/internal/config" + "github.com/matrix-org/dendrite/userapi/api" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/accounts" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" +) + +type newPasswordRequest struct { + NewPassword string `json:"new_password"` + LogoutDevices bool `json:"logout_devices"` + Auth newPasswordAuth `json:"auth"` +} + +type newPasswordAuth struct { + Type string `json:"type"` + Session string `json:"session"` + auth.PasswordRequest +} + +func Password( + req *http.Request, + userAPI userapi.UserInternalAPI, + accountDB accounts.Database, + device *api.Device, + cfg *config.ClientAPI, +) util.JSONResponse { + // Check that the existing password is right. + var r newPasswordRequest + r.LogoutDevices = true + + // Unmarshal the request. + resErr := httputil.UnmarshalJSONRequest(req, &r) + if resErr != nil { + return *resErr + } + + // Retrieve or generate the sessionID + sessionID := r.Auth.Session + if sessionID == "" { + // Generate a new, random session ID + sessionID = util.RandomString(sessionIDLength) + } + + // Require password auth to change the password. + if r.Auth.Type != authtypes.LoginTypePassword { + return util.JSONResponse{ + Code: http.StatusUnauthorized, + JSON: newUserInteractiveResponse( + sessionID, + []authtypes.Flow{ + { + Stages: []authtypes.LoginType{authtypes.LoginTypePassword}, + }, + }, + nil, + ), + } + } + + // Check if the existing password is correct. + typePassword := auth.LoginTypePassword{ + GetAccountByPassword: accountDB.GetAccountByPassword, + Config: cfg, + } + if _, authErr := typePassword.Login(req.Context(), &r.Auth.PasswordRequest); authErr != nil { + return *authErr + } + AddCompletedSessionStage(sessionID, authtypes.LoginTypePassword) + + // Check the new password strength. + if resErr = validatePassword(r.NewPassword); resErr != nil { + return *resErr + } + + // Get the local part. + localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") + return jsonerror.InternalServerError() + } + + // Ask the user API to perform the password change. + passwordReq := &userapi.PerformPasswordUpdateRequest{ + Localpart: localpart, + Password: r.NewPassword, + } + passwordRes := &userapi.PerformPasswordUpdateResponse{} + if err := userAPI.PerformPasswordUpdate(req.Context(), passwordReq, passwordRes); err != nil { + util.GetLogger(req.Context()).WithError(err).Error("PerformPasswordUpdate failed") + return jsonerror.InternalServerError() + } + if !passwordRes.PasswordUpdated { + util.GetLogger(req.Context()).Error("Expected password to have been updated but wasn't") + return jsonerror.InternalServerError() + } + + // If the request asks us to log out all other devices then + // ask the user API to do that. + if r.LogoutDevices { + logoutReq := &userapi.PerformDeviceDeletionRequest{ + UserID: device.UserID, + DeviceIDs: nil, + ExceptDeviceID: device.ID, + } + logoutRes := &userapi.PerformDeviceDeletionResponse{} + if err := userAPI.PerformDeviceDeletion(req.Context(), logoutReq, logoutRes); err != nil { + util.GetLogger(req.Context()).WithError(err).Error("PerformDeviceDeletion failed") + return jsonerror.InternalServerError() + } + } + + // Return a success code. + return util.JSONResponse{ + Code: http.StatusOK, + JSON: struct{}{}, + } +} diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 708f6fee..b29fccf2 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -417,6 +417,15 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) + r0mux.Handle("/account/password", + httputil.MakeAuthAPI("password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + if r := rateLimits.rateLimit(req); r != nil { + return *r + } + return Password(req, userAPI, accountDB, device, cfg) + }), + ).Methods(http.MethodPost, http.MethodOptions) + // Stub endpoints required by Riot r0mux.Handle("/login", diff --git a/sytest-whitelist b/sytest-whitelist index 7ce59fef..93d2de59 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -460,3 +460,8 @@ If user leaves room, remote user changes device and rejoins we see update in /sy Can search public room list Can get remote public room list Asking for a remote rooms list, but supplying the local server's name, returns the local rooms list +After changing password, can't log in with old password +After changing password, can log in with new password +After changing password, existing session still works +After changing password, different sessions can optionally be kept +After changing password, a different session no longer works by default diff --git a/userapi/api/api.go b/userapi/api/api.go index e6d05c33..3baaa100 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -26,6 +26,7 @@ import ( type UserInternalAPI interface { InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error + PerformPasswordUpdate(ctx context.Context, req *PerformPasswordUpdateRequest, res *PerformPasswordUpdateResponse) error PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error PerformDeviceDeletion(ctx context.Context, req *PerformDeviceDeletionRequest, res *PerformDeviceDeletionResponse) error PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error @@ -63,6 +64,10 @@ type PerformDeviceDeletionRequest struct { UserID string // The devices to delete. An empty slice means delete all devices. DeviceIDs []string + // The requesting device ID to exclude from deletion. This is needed + // so that a password change doesn't cause that client to be logged + // out. Only specify when DeviceIDs is empty. + ExceptDeviceID string } type PerformDeviceDeletionResponse struct { @@ -165,6 +170,18 @@ type PerformAccountCreationResponse struct { Account *Account } +// PerformAccountCreationRequest is the request for PerformAccountCreation +type PerformPasswordUpdateRequest struct { + Localpart string // Required: The localpart for this account. + Password string // Required: The new password to set. +} + +// PerformAccountCreationResponse is the response for PerformAccountCreation +type PerformPasswordUpdateResponse struct { + PasswordUpdated bool + Account *Account +} + // PerformDeviceCreationRequest is the request for PerformDeviceCreation type PerformDeviceCreationRequest struct { Localpart string diff --git a/userapi/internal/api.go b/userapi/internal/api.go index b97f148e..461c548c 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -98,6 +98,15 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P res.Account = acc return nil } + +func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error { + if err := a.AccountDB.SetPassword(ctx, req.Localpart, req.Password); err != nil { + return err + } + res.PasswordUpdated = true + return nil +} + func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.PerformDeviceCreationRequest, res *api.PerformDeviceCreationResponse) error { util.GetLogger(ctx).WithFields(logrus.Fields{ "localpart": req.Localpart, @@ -126,7 +135,7 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe deletedDeviceIDs := req.DeviceIDs if len(req.DeviceIDs) == 0 { var devices []api.Device - devices, err = a.DeviceDB.RemoveAllDevices(ctx, local) + devices, err = a.DeviceDB.RemoveAllDevices(ctx, local, req.ExceptDeviceID) for _, d := range devices { deletedDeviceIDs = append(deletedDeviceIDs, d.ID) } diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go index 5f4df0eb..6dcaf756 100644 --- a/userapi/inthttp/client.go +++ b/userapi/inthttp/client.go @@ -30,6 +30,7 @@ const ( PerformDeviceCreationPath = "/userapi/performDeviceCreation" PerformAccountCreationPath = "/userapi/performAccountCreation" + PerformPasswordUpdatePath = "/userapi/performPasswordUpdate" PerformDeviceDeletionPath = "/userapi/performDeviceDeletion" PerformDeviceUpdatePath = "/userapi/performDeviceUpdate" @@ -81,6 +82,18 @@ func (h *httpUserInternalAPI) PerformAccountCreation( return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } +func (h *httpUserInternalAPI) PerformPasswordUpdate( + ctx context.Context, + request *api.PerformPasswordUpdateRequest, + response *api.PerformPasswordUpdateResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPasswordUpdate") + defer span.Finish() + + apiURL := h.apiURL + PerformPasswordUpdatePath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + func (h *httpUserInternalAPI) PerformDeviceCreation( ctx context.Context, request *api.PerformDeviceCreationRequest, diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go index 47d68ff2..d2674678 100644 --- a/userapi/inthttp/server.go +++ b/userapi/inthttp/server.go @@ -39,6 +39,19 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(PerformAccountCreationPath, + httputil.MakeInternalAPI("performPasswordUpdate", func(req *http.Request) util.JSONResponse { + request := api.PerformPasswordUpdateRequest{} + response := api.PerformPasswordUpdateResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.PerformPasswordUpdate(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) internalAPIMux.Handle(PerformDeviceCreationPath, httputil.MakeInternalAPI("performDeviceCreation", func(req *http.Request) util.JSONResponse { request := api.PerformDeviceCreationRequest{} diff --git a/userapi/storage/accounts/interface.go b/userapi/storage/accounts/interface.go index 86b91e60..49446f11 100644 --- a/userapi/storage/accounts/interface.go +++ b/userapi/storage/accounts/interface.go @@ -28,6 +28,7 @@ type Database interface { internal.PartitionStorer GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error) GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error) + SetPassword(ctx context.Context, localpart string, plaintextPassword string) error SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error SetDisplayName(ctx context.Context, localpart string, displayName string) error // CreateAccount makes a new account with the given login name and password, and creates an empty profile diff --git a/userapi/storage/accounts/postgres/accounts_table.go b/userapi/storage/accounts/postgres/accounts_table.go index 931ffb73..8c8d32cf 100644 --- a/userapi/storage/accounts/postgres/accounts_table.go +++ b/userapi/storage/accounts/postgres/accounts_table.go @@ -47,6 +47,9 @@ CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1; const insertAccountSQL = "" + "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)" +const updatePasswordSQL = "" + + "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2" + const selectAccountByLocalpartSQL = "" + "SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1" @@ -56,10 +59,9 @@ const selectPasswordHashSQL = "" + const selectNewNumericLocalpartSQL = "" + "SELECT nextval('numeric_username_seq')" -// TODO: Update password - type accountsStatements struct { insertAccountStmt *sql.Stmt + updatePasswordStmt *sql.Stmt selectAccountByLocalpartStmt *sql.Stmt selectPasswordHashStmt *sql.Stmt selectNewNumericLocalpartStmt *sql.Stmt @@ -74,6 +76,9 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil { return } + if s.updatePasswordStmt, err = db.Prepare(updatePasswordSQL); err != nil { + return + } if s.selectAccountByLocalpartStmt, err = db.Prepare(selectAccountByLocalpartSQL); err != nil { return } @@ -114,6 +119,13 @@ func (s *accountsStatements) insertAccount( }, nil } +func (s *accountsStatements) updatePassword( + ctx context.Context, localpart, passwordHash string, +) (err error) { + _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart) + return +} + func (s *accountsStatements) selectPasswordHash( ctx context.Context, localpart string, ) (hash string, err error) { diff --git a/userapi/storage/accounts/postgres/storage.go b/userapi/storage/accounts/postgres/storage.go index b36264dd..8b9ebef8 100644 --- a/userapi/storage/accounts/postgres/storage.go +++ b/userapi/storage/accounts/postgres/storage.go @@ -112,6 +112,17 @@ func (d *Database) SetDisplayName( return d.profiles.setDisplayName(ctx, localpart, displayName) } +// SetPassword sets the account password to the given hash. +func (d *Database) SetPassword( + ctx context.Context, localpart, plaintextPassword string, +) error { + hash, err := hashPassword(plaintextPassword) + if err != nil { + return err + } + return d.accounts.updatePassword(ctx, localpart, hash) +} + // CreateGuestAccount makes a new guest account and creates an empty profile // for this account. func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, err error) { diff --git a/userapi/storage/accounts/sqlite3/accounts_table.go b/userapi/storage/accounts/sqlite3/accounts_table.go index 798a6de9..fbbdc337 100644 --- a/userapi/storage/accounts/sqlite3/accounts_table.go +++ b/userapi/storage/accounts/sqlite3/accounts_table.go @@ -45,6 +45,9 @@ CREATE TABLE IF NOT EXISTS account_accounts ( const insertAccountSQL = "" + "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)" +const updatePasswordSQL = "" + + "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2" + const selectAccountByLocalpartSQL = "" + "SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1" @@ -54,11 +57,10 @@ const selectPasswordHashSQL = "" + const selectNewNumericLocalpartSQL = "" + "SELECT COUNT(localpart) FROM account_accounts" -// TODO: Update password - type accountsStatements struct { db *sql.DB insertAccountStmt *sql.Stmt + updatePasswordStmt *sql.Stmt selectAccountByLocalpartStmt *sql.Stmt selectPasswordHashStmt *sql.Stmt selectNewNumericLocalpartStmt *sql.Stmt @@ -75,6 +77,9 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil { return } + if s.updatePasswordStmt, err = db.Prepare(updatePasswordSQL); err != nil { + return + } if s.selectAccountByLocalpartStmt, err = db.Prepare(selectAccountByLocalpartSQL); err != nil { return } @@ -115,6 +120,13 @@ func (s *accountsStatements) insertAccount( }, nil } +func (s *accountsStatements) updatePassword( + ctx context.Context, localpart, passwordHash string, +) (err error) { + _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart) + return +} + func (s *accountsStatements) selectPasswordHash( ctx context.Context, localpart string, ) (hash string, err error) { diff --git a/userapi/storage/accounts/sqlite3/storage.go b/userapi/storage/accounts/sqlite3/storage.go index 46106297..4b66304c 100644 --- a/userapi/storage/accounts/sqlite3/storage.go +++ b/userapi/storage/accounts/sqlite3/storage.go @@ -126,6 +126,18 @@ func (d *Database) SetDisplayName( }) } +// SetPassword sets the account password to the given hash. +func (d *Database) SetPassword( + ctx context.Context, localpart, plaintextPassword string, +) error { + hash, err := hashPassword(plaintextPassword) + if err != nil { + return err + } + err = d.accounts.updatePassword(ctx, localpart, hash) + return err +} + // CreateGuestAccount makes a new guest account and creates an empty profile // for this account. func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, err error) { diff --git a/userapi/storage/devices/interface.go b/userapi/storage/devices/interface.go index 9b4261c9..168c84c5 100644 --- a/userapi/storage/devices/interface.go +++ b/userapi/storage/devices/interface.go @@ -36,5 +36,5 @@ type Database interface { RemoveDevice(ctx context.Context, deviceID, localpart string) error RemoveDevices(ctx context.Context, localpart string, devices []string) error // RemoveAllDevices deleted all devices for this user. Returns the devices deleted. - RemoveAllDevices(ctx context.Context, localpart string) (devices []api.Device, err error) + RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error) } diff --git a/userapi/storage/devices/postgres/devices_table.go b/userapi/storage/devices/postgres/devices_table.go index 282466f8..c06af754 100644 --- a/userapi/storage/devices/postgres/devices_table.go +++ b/userapi/storage/devices/postgres/devices_table.go @@ -70,7 +70,7 @@ const selectDeviceByIDSQL = "" + "SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2" const selectDevicesByLocalpartSQL = "" + - "SELECT device_id, display_name FROM device_devices WHERE localpart = $1" + "SELECT device_id, display_name FROM device_devices WHERE localpart = $1 AND device_id != $2" const updateDeviceNameSQL = "" + "UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" @@ -79,7 +79,7 @@ const deleteDeviceSQL = "" + "DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2" const deleteDevicesByLocalpartSQL = "" + - "DELETE FROM device_devices WHERE localpart = $1" + "DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2" const deleteDevicesSQL = "" + "DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)" @@ -179,10 +179,10 @@ func (s *devicesStatements) deleteDevices( // deleteDevicesByLocalpart removes all devices for the // given user localpart. func (s *devicesStatements) deleteDevicesByLocalpart( - ctx context.Context, txn *sql.Tx, localpart string, + ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) - _, err := stmt.ExecContext(ctx, localpart) + _, err := stmt.ExecContext(ctx, localpart, exceptDeviceID) return err } @@ -251,10 +251,10 @@ func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []s } func (s *devicesStatements) selectDevicesByLocalpart( - ctx context.Context, txn *sql.Tx, localpart string, + ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, ) ([]api.Device, error) { devices := []api.Device{} - rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart) + rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID) if err != nil { return devices, err diff --git a/userapi/storage/devices/postgres/storage.go b/userapi/storage/devices/postgres/storage.go index 04dae986..c5bd5b6c 100644 --- a/userapi/storage/devices/postgres/storage.go +++ b/userapi/storage/devices/postgres/storage.go @@ -68,7 +68,7 @@ func (d *Database) GetDeviceByID( func (d *Database) GetDevicesByLocalpart( ctx context.Context, localpart string, ) ([]api.Device, error) { - return d.devices.selectDevicesByLocalpart(ctx, nil, localpart) + return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "") } func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { @@ -175,14 +175,14 @@ func (d *Database) RemoveDevices( // database matching the given user ID localpart. // If something went wrong during the deletion, it will return the SQL error. func (d *Database) RemoveAllDevices( - ctx context.Context, localpart string, + ctx context.Context, localpart, exceptDeviceID string, ) (devices []api.Device, err error) { err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart) + devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID) if err != nil { return err } - if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows { + if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows { return err } return nil diff --git a/userapi/storage/devices/sqlite3/devices_table.go b/userapi/storage/devices/sqlite3/devices_table.go index ecf43524..c75e1982 100644 --- a/userapi/storage/devices/sqlite3/devices_table.go +++ b/userapi/storage/devices/sqlite3/devices_table.go @@ -59,7 +59,7 @@ const selectDeviceByIDSQL = "" + "SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2" const selectDevicesByLocalpartSQL = "" + - "SELECT device_id, display_name FROM device_devices WHERE localpart = $1" + "SELECT device_id, display_name FROM device_devices WHERE localpart = $1 AND device_id != $2" const updateDeviceNameSQL = "" + "UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" @@ -68,7 +68,7 @@ const deleteDeviceSQL = "" + "DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2" const deleteDevicesByLocalpartSQL = "" + - "DELETE FROM device_devices WHERE localpart = $1" + "DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2" const deleteDevicesSQL = "" + "DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)" @@ -182,10 +182,10 @@ func (s *devicesStatements) deleteDevices( } func (s *devicesStatements) deleteDevicesByLocalpart( - ctx context.Context, txn *sql.Tx, localpart string, + ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) - _, err := stmt.ExecContext(ctx, localpart) + _, err := stmt.ExecContext(ctx, localpart, exceptDeviceID) return err } @@ -231,10 +231,10 @@ func (s *devicesStatements) selectDeviceByID( } func (s *devicesStatements) selectDevicesByLocalpart( - ctx context.Context, txn *sql.Tx, localpart string, + ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, ) ([]api.Device, error) { devices := []api.Device{} - rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart) + rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID) if err != nil { return devices, err diff --git a/userapi/storage/devices/sqlite3/storage.go b/userapi/storage/devices/sqlite3/storage.go index f775fb66..7c6645dd 100644 --- a/userapi/storage/devices/sqlite3/storage.go +++ b/userapi/storage/devices/sqlite3/storage.go @@ -72,7 +72,7 @@ func (d *Database) GetDeviceByID( func (d *Database) GetDevicesByLocalpart( ctx context.Context, localpart string, ) ([]api.Device, error) { - return d.devices.selectDevicesByLocalpart(ctx, nil, localpart) + return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "") } func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { @@ -179,14 +179,14 @@ func (d *Database) RemoveDevices( // database matching the given user ID localpart. // If something went wrong during the deletion, it will return the SQL error. func (d *Database) RemoveAllDevices( - ctx context.Context, localpart string, + ctx context.Context, localpart, exceptDeviceID string, ) (devices []api.Device, err error) { err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart) + devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID) if err != nil { return err } - if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows { + if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows { return err } return nil From 088294ee655282a32f7a418307408291f4299565 Mon Sep 17 00:00:00 2001 From: Kegsay Date: Fri, 4 Sep 2020 15:58:30 +0100 Subject: [PATCH 2/2] Remove QueryRoomsForUser from current state server (#1398) --- build/gobind/monolith.go | 2 +- clientapi/routing/memberships.go | 7 +++---- clientapi/routing/profile.go | 8 ++++---- clientapi/routing/routing.go | 2 +- cmd/dendrite-demo-libp2p/main.go | 2 +- cmd/dendrite-demo-yggdrasil/main.go | 2 +- cmd/dendrite-federation-sender-server/main.go | 2 +- cmd/dendrite-monolith-server/main.go | 2 +- cmd/dendritejs/main.go | 2 +- currentstateserver/api/api.go | 12 ------------ currentstateserver/internal/api.go | 9 --------- currentstateserver/inthttp/client.go | 12 ------------ currentstateserver/inthttp/server.go | 13 ------------- federationsender/consumers/keychange.go | 12 ++++++------ federationsender/federationsender.go | 4 +--- syncapi/internal/keychange_test.go | 5 ----- 16 files changed, 21 insertions(+), 75 deletions(-) diff --git a/build/gobind/monolith.go b/build/gobind/monolith.go index 725c9c07..e924ec44 100644 --- a/build/gobind/monolith.go +++ b/build/gobind/monolith.go @@ -130,7 +130,7 @@ func (m *DendriteMonolith) Start() { asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI) stateAPI := currentstateserver.NewInternalAPI(&base.Cfg.CurrentStateServer, base.KafkaConsumer) fsAPI := federationsender.NewInternalAPI( - base, federation, rsAPI, stateAPI, keyRing, + base, federation, rsAPI, keyRing, ) ygg.SetSessionFunc(func(address string) { diff --git a/clientapi/routing/memberships.go b/clientapi/routing/memberships.go index 56059350..61348487 100644 --- a/clientapi/routing/memberships.go +++ b/clientapi/routing/memberships.go @@ -19,7 +19,6 @@ import ( "net/http" "github.com/matrix-org/dendrite/clientapi/jsonerror" - currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api" "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/roomserver/api" userapi "github.com/matrix-org/dendrite/userapi/api" @@ -94,10 +93,10 @@ func GetMemberships( func GetJoinedRooms( req *http.Request, device *userapi.Device, - stateAPI currentstateAPI.CurrentStateInternalAPI, + rsAPI api.RoomserverInternalAPI, ) util.JSONResponse { - var res currentstateAPI.QueryRoomsForUserResponse - err := stateAPI.QueryRoomsForUser(req.Context(), ¤tstateAPI.QueryRoomsForUserRequest{ + var res api.QueryRoomsForUserResponse + err := rsAPI.QueryRoomsForUser(req.Context(), &api.QueryRoomsForUserRequest{ UserID: device.UserID, WantMembership: "join", }, &res) diff --git a/clientapi/routing/profile.go b/clientapi/routing/profile.go index bc51b0b5..5dd44ca2 100644 --- a/clientapi/routing/profile.go +++ b/clientapi/routing/profile.go @@ -140,8 +140,8 @@ func SetAvatarURL( return jsonerror.InternalServerError() } - var res currentstateAPI.QueryRoomsForUserResponse - err = stateAPI.QueryRoomsForUser(req.Context(), ¤tstateAPI.QueryRoomsForUserRequest{ + var res api.QueryRoomsForUserResponse + err = rsAPI.QueryRoomsForUser(req.Context(), &api.QueryRoomsForUserRequest{ UserID: device.UserID, WantMembership: "join", }, &res) @@ -258,8 +258,8 @@ func SetDisplayName( return jsonerror.InternalServerError() } - var res currentstateAPI.QueryRoomsForUserResponse - err = stateAPI.QueryRoomsForUser(req.Context(), ¤tstateAPI.QueryRoomsForUserRequest{ + var res api.QueryRoomsForUserResponse + err = rsAPI.QueryRoomsForUser(req.Context(), &api.QueryRoomsForUserRequest{ UserID: device.UserID, WantMembership: "join", }, &res) diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index b29fccf2..0445852d 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -107,7 +107,7 @@ func Setup( ).Methods(http.MethodPost, http.MethodOptions) r0mux.Handle("/joined_rooms", httputil.MakeAuthAPI("joined_rooms", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { - return GetJoinedRooms(req, device, stateAPI) + return GetJoinedRooms(req, device, rsAPI) }), ).Methods(http.MethodGet, http.MethodOptions) r0mux.Handle("/rooms/{roomID}/join", diff --git a/cmd/dendrite-demo-libp2p/main.go b/cmd/dendrite-demo-libp2p/main.go index d4f0cee0..2441e487 100644 --- a/cmd/dendrite-demo-libp2p/main.go +++ b/cmd/dendrite-demo-libp2p/main.go @@ -162,7 +162,7 @@ func main() { ) asAPI := appservice.NewInternalAPI(&base.Base, userAPI, rsAPI) fsAPI := federationsender.NewInternalAPI( - &base.Base, federation, rsAPI, stateAPI, keyRing, + &base.Base, federation, rsAPI, keyRing, ) rsAPI.SetFederationSenderAPI(fsAPI) provider := newPublicRoomsProvider(base.LibP2PPubsub, rsAPI, stateAPI) diff --git a/cmd/dendrite-demo-yggdrasil/main.go b/cmd/dendrite-demo-yggdrasil/main.go index fcf3d4c5..92a11199 100644 --- a/cmd/dendrite-demo-yggdrasil/main.go +++ b/cmd/dendrite-demo-yggdrasil/main.go @@ -115,7 +115,7 @@ func main() { asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI) stateAPI := currentstateserver.NewInternalAPI(&base.Cfg.CurrentStateServer, base.KafkaConsumer) fsAPI := federationsender.NewInternalAPI( - base, federation, rsAPI, stateAPI, keyRing, + base, federation, rsAPI, keyRing, ) ygg.SetSessionFunc(func(address string) { diff --git a/cmd/dendrite-federation-sender-server/main.go b/cmd/dendrite-federation-sender-server/main.go index 36906019..4d918f6b 100644 --- a/cmd/dendrite-federation-sender-server/main.go +++ b/cmd/dendrite-federation-sender-server/main.go @@ -31,7 +31,7 @@ func main() { rsAPI := base.RoomserverHTTPClient() fsAPI := federationsender.NewInternalAPI( - base, federation, rsAPI, base.CurrentStateAPIClient(), keyRing, + base, federation, rsAPI, keyRing, ) federationsender.AddInternalRoutes(base.InternalAPIMux, fsAPI) diff --git a/cmd/dendrite-monolith-server/main.go b/cmd/dendrite-monolith-server/main.go index 717b21a9..643e2069 100644 --- a/cmd/dendrite-monolith-server/main.go +++ b/cmd/dendrite-monolith-server/main.go @@ -98,7 +98,7 @@ func main() { stateAPI := currentstateserver.NewInternalAPI(&base.Cfg.CurrentStateServer, base.KafkaConsumer) fsAPI := federationsender.NewInternalAPI( - base, federation, rsAPI, stateAPI, keyRing, + base, federation, rsAPI, keyRing, ) if base.UseHTTPAPIs { federationsender.AddInternalRoutes(base.InternalAPIMux, fsAPI) diff --git a/cmd/dendritejs/main.go b/cmd/dendritejs/main.go index aeca7094..839bf459 100644 --- a/cmd/dendritejs/main.go +++ b/cmd/dendritejs/main.go @@ -210,7 +210,7 @@ func main() { asQuery := appservice.NewInternalAPI( base, userAPI, rsAPI, ) - fedSenderAPI := federationsender.NewInternalAPI(base, federation, rsAPI, stateAPI, &keyRing) + fedSenderAPI := federationsender.NewInternalAPI(base, federation, rsAPI, &keyRing) rsAPI.SetFederationSenderAPI(fedSenderAPI) p2pPublicRoomProvider := NewLibP2PPublicRoomsProvider(node, fedSenderAPI, federation) diff --git a/currentstateserver/api/api.go b/currentstateserver/api/api.go index 2e068d95..f11422ac 100644 --- a/currentstateserver/api/api.go +++ b/currentstateserver/api/api.go @@ -21,22 +21,10 @@ import ( ) type CurrentStateInternalAPI interface { - // QueryRoomsForUser retrieves a list of room IDs matching the given query. - QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error // QueryBulkStateContent does a bulk query for state event content in the given rooms. QueryBulkStateContent(ctx context.Context, req *QueryBulkStateContentRequest, res *QueryBulkStateContentResponse) error } -type QueryRoomsForUserRequest struct { - UserID string - // The desired membership of the user. If this is the empty string then no rooms are returned. - WantMembership string -} - -type QueryRoomsForUserResponse struct { - RoomIDs []string -} - type QueryBulkStateContentRequest struct { // Returns state events in these rooms RoomIDs []string diff --git a/currentstateserver/internal/api.go b/currentstateserver/internal/api.go index 22e711d6..2d6da1e6 100644 --- a/currentstateserver/internal/api.go +++ b/currentstateserver/internal/api.go @@ -26,15 +26,6 @@ type CurrentStateInternalAPI struct { DB storage.Database } -func (a *CurrentStateInternalAPI) QueryRoomsForUser(ctx context.Context, req *api.QueryRoomsForUserRequest, res *api.QueryRoomsForUserResponse) error { - roomIDs, err := a.DB.GetRoomsByMembership(ctx, req.UserID, req.WantMembership) - if err != nil { - return err - } - res.RoomIDs = roomIDs - return nil -} - func (a *CurrentStateInternalAPI) QueryBulkStateContent(ctx context.Context, req *api.QueryBulkStateContentRequest, res *api.QueryBulkStateContentResponse) error { events, err := a.DB.GetBulkStateContent(ctx, req.RoomIDs, req.StateTuples, req.AllowWildcards) if err != nil { diff --git a/currentstateserver/inthttp/client.go b/currentstateserver/inthttp/client.go index bb853839..13130716 100644 --- a/currentstateserver/inthttp/client.go +++ b/currentstateserver/inthttp/client.go @@ -50,18 +50,6 @@ type httpCurrentStateInternalAPI struct { httpClient *http.Client } -func (h *httpCurrentStateInternalAPI) QueryRoomsForUser( - ctx context.Context, - request *api.QueryRoomsForUserRequest, - response *api.QueryRoomsForUserResponse, -) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryRoomsForUser") - defer span.Finish() - - apiURL := h.apiURL + QueryRoomsForUserPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) -} - func (h *httpCurrentStateInternalAPI) QueryBulkStateContent( ctx context.Context, request *api.QueryBulkStateContentRequest, diff --git a/currentstateserver/inthttp/server.go b/currentstateserver/inthttp/server.go index 9b1f511a..70d6ecfd 100644 --- a/currentstateserver/inthttp/server.go +++ b/currentstateserver/inthttp/server.go @@ -25,19 +25,6 @@ import ( ) func AddRoutes(internalAPIMux *mux.Router, intAPI api.CurrentStateInternalAPI) { - internalAPIMux.Handle(QueryRoomsForUserPath, - httputil.MakeInternalAPI("queryRoomsForUser", func(req *http.Request) util.JSONResponse { - request := api.QueryRoomsForUserRequest{} - response := api.QueryRoomsForUserResponse{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.MessageResponse(http.StatusBadRequest, err.Error()) - } - if err := intAPI.QueryRoomsForUser(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) internalAPIMux.Handle(QueryBulkStateContentPath, httputil.MakeInternalAPI("queryBulkStateContent", func(req *http.Request) util.JSONResponse { request := api.QueryBulkStateContentRequest{} diff --git a/federationsender/consumers/keychange.go b/federationsender/consumers/keychange.go index 4f206f5f..28244e92 100644 --- a/federationsender/consumers/keychange.go +++ b/federationsender/consumers/keychange.go @@ -20,12 +20,12 @@ import ( "fmt" "github.com/Shopify/sarama" - stateapi "github.com/matrix-org/dendrite/currentstateserver/api" "github.com/matrix-org/dendrite/federationsender/queue" "github.com/matrix-org/dendrite/federationsender/storage" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/keyserver/api" + roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib" log "github.com/sirupsen/logrus" ) @@ -36,7 +36,7 @@ type KeyChangeConsumer struct { db storage.Database queues *queue.OutgoingQueues serverName gomatrixserverlib.ServerName - stateAPI stateapi.CurrentStateInternalAPI + rsAPI roomserverAPI.RoomserverInternalAPI } // NewKeyChangeConsumer creates a new KeyChangeConsumer. Call Start() to begin consuming from key servers. @@ -45,7 +45,7 @@ func NewKeyChangeConsumer( kafkaConsumer sarama.Consumer, queues *queue.OutgoingQueues, store storage.Database, - stateAPI stateapi.CurrentStateInternalAPI, + rsAPI roomserverAPI.RoomserverInternalAPI, ) *KeyChangeConsumer { c := &KeyChangeConsumer{ consumer: &internal.ContinualConsumer{ @@ -57,7 +57,7 @@ func NewKeyChangeConsumer( queues: queues, db: store, serverName: cfg.Matrix.ServerName, - stateAPI: stateAPI, + rsAPI: rsAPI, } c.consumer.ProcessMessage = c.onMessage @@ -92,8 +92,8 @@ func (t *KeyChangeConsumer) onMessage(msg *sarama.ConsumerMessage) error { return nil } - var queryRes stateapi.QueryRoomsForUserResponse - err = t.stateAPI.QueryRoomsForUser(context.Background(), &stateapi.QueryRoomsForUserRequest{ + var queryRes roomserverAPI.QueryRoomsForUserResponse + err = t.rsAPI.QueryRoomsForUser(context.Background(), &roomserverAPI.QueryRoomsForUserRequest{ UserID: m.UserID, WantMembership: "join", }, &queryRes) diff --git a/federationsender/federationsender.go b/federationsender/federationsender.go index 0e2213da..2f122328 100644 --- a/federationsender/federationsender.go +++ b/federationsender/federationsender.go @@ -16,7 +16,6 @@ package federationsender import ( "github.com/gorilla/mux" - stateapi "github.com/matrix-org/dendrite/currentstateserver/api" "github.com/matrix-org/dendrite/federationsender/api" "github.com/matrix-org/dendrite/federationsender/consumers" "github.com/matrix-org/dendrite/federationsender/internal" @@ -42,7 +41,6 @@ func NewInternalAPI( base *setup.BaseDendrite, federation *gomatrixserverlib.FederationClient, rsAPI roomserverAPI.RoomserverInternalAPI, - stateAPI stateapi.CurrentStateInternalAPI, keyRing *gomatrixserverlib.KeyRing, ) api.FederationSenderInternalAPI { cfg := &base.Cfg.FederationSender @@ -82,7 +80,7 @@ func NewInternalAPI( logrus.WithError(err).Panic("failed to start typing server consumer") } keyConsumer := consumers.NewKeyChangeConsumer( - &base.Cfg.KeyServer, base.KafkaConsumer, queues, federationSenderDB, stateAPI, + &base.Cfg.KeyServer, base.KafkaConsumer, queues, federationSenderDB, rsAPI, ) if err := keyConsumer.Start(); err != nil { logrus.WithError(err).Panic("failed to start key server consumer") diff --git a/syncapi/internal/keychange_test.go b/syncapi/internal/keychange_test.go index ee6ca1f4..7d739430 100644 --- a/syncapi/internal/keychange_test.go +++ b/syncapi/internal/keychange_test.go @@ -112,11 +112,6 @@ type mockStateAPI struct { rsAPI *mockRoomserverAPI } -// QueryRoomsForUser retrieves a list of room IDs matching the given query. -func (s *mockStateAPI) QueryRoomsForUser(ctx context.Context, req *stateapi.QueryRoomsForUserRequest, res *stateapi.QueryRoomsForUserResponse) error { - return nil -} - // QueryBulkStateContent does a bulk query for state event content in the given rooms. func (s *mockStateAPI) QueryBulkStateContent(ctx context.Context, req *stateapi.QueryBulkStateContentRequest, res *stateapi.QueryBulkStateContentResponse) error { var res2 api.QueryBulkStateContentResponse