Password changes (#1397)

* User API support for password changes

* Password changes in client API

* Update sytest-whitelist

* Remove debug logging

* Default logout_devices to true

* Fix deleting devices by local part
This commit is contained in:
Neil Alexander 2020-09-04 15:16:13 +01:00 committed by GitHub
parent ca8dcf46b7
commit 5076925c18
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
18 changed files with 268 additions and 26 deletions

View file

@ -26,6 +26,7 @@ import (
type UserInternalAPI interface {
InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error
PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error
PerformPasswordUpdate(ctx context.Context, req *PerformPasswordUpdateRequest, res *PerformPasswordUpdateResponse) error
PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error
PerformDeviceDeletion(ctx context.Context, req *PerformDeviceDeletionRequest, res *PerformDeviceDeletionResponse) error
PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error
@ -63,6 +64,10 @@ type PerformDeviceDeletionRequest struct {
UserID string
// The devices to delete. An empty slice means delete all devices.
DeviceIDs []string
// The requesting device ID to exclude from deletion. This is needed
// so that a password change doesn't cause that client to be logged
// out. Only specify when DeviceIDs is empty.
ExceptDeviceID string
}
type PerformDeviceDeletionResponse struct {
@ -165,6 +170,18 @@ type PerformAccountCreationResponse struct {
Account *Account
}
// PerformAccountCreationRequest is the request for PerformAccountCreation
type PerformPasswordUpdateRequest struct {
Localpart string // Required: The localpart for this account.
Password string // Required: The new password to set.
}
// PerformAccountCreationResponse is the response for PerformAccountCreation
type PerformPasswordUpdateResponse struct {
PasswordUpdated bool
Account *Account
}
// PerformDeviceCreationRequest is the request for PerformDeviceCreation
type PerformDeviceCreationRequest struct {
Localpart string

View file

@ -98,6 +98,15 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
res.Account = acc
return nil
}
func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error {
if err := a.AccountDB.SetPassword(ctx, req.Localpart, req.Password); err != nil {
return err
}
res.PasswordUpdated = true
return nil
}
func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.PerformDeviceCreationRequest, res *api.PerformDeviceCreationResponse) error {
util.GetLogger(ctx).WithFields(logrus.Fields{
"localpart": req.Localpart,
@ -126,7 +135,7 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe
deletedDeviceIDs := req.DeviceIDs
if len(req.DeviceIDs) == 0 {
var devices []api.Device
devices, err = a.DeviceDB.RemoveAllDevices(ctx, local)
devices, err = a.DeviceDB.RemoveAllDevices(ctx, local, req.ExceptDeviceID)
for _, d := range devices {
deletedDeviceIDs = append(deletedDeviceIDs, d.ID)
}

View file

@ -30,6 +30,7 @@ const (
PerformDeviceCreationPath = "/userapi/performDeviceCreation"
PerformAccountCreationPath = "/userapi/performAccountCreation"
PerformPasswordUpdatePath = "/userapi/performPasswordUpdate"
PerformDeviceDeletionPath = "/userapi/performDeviceDeletion"
PerformDeviceUpdatePath = "/userapi/performDeviceUpdate"
@ -81,6 +82,18 @@ func (h *httpUserInternalAPI) PerformAccountCreation(
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
}
func (h *httpUserInternalAPI) PerformPasswordUpdate(
ctx context.Context,
request *api.PerformPasswordUpdateRequest,
response *api.PerformPasswordUpdateResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPasswordUpdate")
defer span.Finish()
apiURL := h.apiURL + PerformPasswordUpdatePath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
}
func (h *httpUserInternalAPI) PerformDeviceCreation(
ctx context.Context,
request *api.PerformDeviceCreationRequest,

View file

@ -39,6 +39,19 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(PerformAccountCreationPath,
httputil.MakeInternalAPI("performPasswordUpdate", func(req *http.Request) util.JSONResponse {
request := api.PerformPasswordUpdateRequest{}
response := api.PerformPasswordUpdateResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformPasswordUpdate(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(PerformDeviceCreationPath,
httputil.MakeInternalAPI("performDeviceCreation", func(req *http.Request) util.JSONResponse {
request := api.PerformDeviceCreationRequest{}

View file

@ -28,6 +28,7 @@ type Database interface {
internal.PartitionStorer
GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error)
GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
SetPassword(ctx context.Context, localpart string, plaintextPassword string) error
SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error
SetDisplayName(ctx context.Context, localpart string, displayName string) error
// CreateAccount makes a new account with the given login name and password, and creates an empty profile

View file

@ -47,6 +47,9 @@ CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1;
const insertAccountSQL = "" +
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)"
const updatePasswordSQL = "" +
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
const selectAccountByLocalpartSQL = "" +
"SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1"
@ -56,10 +59,9 @@ const selectPasswordHashSQL = "" +
const selectNewNumericLocalpartSQL = "" +
"SELECT nextval('numeric_username_seq')"
// TODO: Update password
type accountsStatements struct {
insertAccountStmt *sql.Stmt
updatePasswordStmt *sql.Stmt
selectAccountByLocalpartStmt *sql.Stmt
selectPasswordHashStmt *sql.Stmt
selectNewNumericLocalpartStmt *sql.Stmt
@ -74,6 +76,9 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil {
return
}
if s.updatePasswordStmt, err = db.Prepare(updatePasswordSQL); err != nil {
return
}
if s.selectAccountByLocalpartStmt, err = db.Prepare(selectAccountByLocalpartSQL); err != nil {
return
}
@ -114,6 +119,13 @@ func (s *accountsStatements) insertAccount(
}, nil
}
func (s *accountsStatements) updatePassword(
ctx context.Context, localpart, passwordHash string,
) (err error) {
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart)
return
}
func (s *accountsStatements) selectPasswordHash(
ctx context.Context, localpart string,
) (hash string, err error) {

View file

@ -112,6 +112,17 @@ func (d *Database) SetDisplayName(
return d.profiles.setDisplayName(ctx, localpart, displayName)
}
// SetPassword sets the account password to the given hash.
func (d *Database) SetPassword(
ctx context.Context, localpart, plaintextPassword string,
) error {
hash, err := hashPassword(plaintextPassword)
if err != nil {
return err
}
return d.accounts.updatePassword(ctx, localpart, hash)
}
// CreateGuestAccount makes a new guest account and creates an empty profile
// for this account.
func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, err error) {

View file

@ -45,6 +45,9 @@ CREATE TABLE IF NOT EXISTS account_accounts (
const insertAccountSQL = "" +
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)"
const updatePasswordSQL = "" +
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
const selectAccountByLocalpartSQL = "" +
"SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1"
@ -54,11 +57,10 @@ const selectPasswordHashSQL = "" +
const selectNewNumericLocalpartSQL = "" +
"SELECT COUNT(localpart) FROM account_accounts"
// TODO: Update password
type accountsStatements struct {
db *sql.DB
insertAccountStmt *sql.Stmt
updatePasswordStmt *sql.Stmt
selectAccountByLocalpartStmt *sql.Stmt
selectPasswordHashStmt *sql.Stmt
selectNewNumericLocalpartStmt *sql.Stmt
@ -75,6 +77,9 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil {
return
}
if s.updatePasswordStmt, err = db.Prepare(updatePasswordSQL); err != nil {
return
}
if s.selectAccountByLocalpartStmt, err = db.Prepare(selectAccountByLocalpartSQL); err != nil {
return
}
@ -115,6 +120,13 @@ func (s *accountsStatements) insertAccount(
}, nil
}
func (s *accountsStatements) updatePassword(
ctx context.Context, localpart, passwordHash string,
) (err error) {
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart)
return
}
func (s *accountsStatements) selectPasswordHash(
ctx context.Context, localpart string,
) (hash string, err error) {

View file

@ -126,6 +126,18 @@ func (d *Database) SetDisplayName(
})
}
// SetPassword sets the account password to the given hash.
func (d *Database) SetPassword(
ctx context.Context, localpart, plaintextPassword string,
) error {
hash, err := hashPassword(plaintextPassword)
if err != nil {
return err
}
err = d.accounts.updatePassword(ctx, localpart, hash)
return err
}
// CreateGuestAccount makes a new guest account and creates an empty profile
// for this account.
func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, err error) {

View file

@ -36,5 +36,5 @@ type Database interface {
RemoveDevice(ctx context.Context, deviceID, localpart string) error
RemoveDevices(ctx context.Context, localpart string, devices []string) error
// RemoveAllDevices deleted all devices for this user. Returns the devices deleted.
RemoveAllDevices(ctx context.Context, localpart string) (devices []api.Device, err error)
RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error)
}

View file

@ -70,7 +70,7 @@ const selectDeviceByIDSQL = "" +
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
const selectDevicesByLocalpartSQL = "" +
"SELECT device_id, display_name FROM device_devices WHERE localpart = $1"
"SELECT device_id, display_name FROM device_devices WHERE localpart = $1 AND device_id != $2"
const updateDeviceNameSQL = "" +
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
@ -79,7 +79,7 @@ const deleteDeviceSQL = "" +
"DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2"
const deleteDevicesByLocalpartSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1"
"DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2"
const deleteDevicesSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)"
@ -179,10 +179,10 @@ func (s *devicesStatements) deleteDevices(
// deleteDevicesByLocalpart removes all devices for the
// given user localpart.
func (s *devicesStatements) deleteDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart string,
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
_, err := stmt.ExecContext(ctx, localpart)
_, err := stmt.ExecContext(ctx, localpart, exceptDeviceID)
return err
}
@ -251,10 +251,10 @@ func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []s
}
func (s *devicesStatements) selectDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart string,
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
) ([]api.Device, error) {
devices := []api.Device{}
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart)
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID)
if err != nil {
return devices, err

View file

@ -68,7 +68,7 @@ func (d *Database) GetDeviceByID(
func (d *Database) GetDevicesByLocalpart(
ctx context.Context, localpart string,
) ([]api.Device, error) {
return d.devices.selectDevicesByLocalpart(ctx, nil, localpart)
return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "")
}
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
@ -175,14 +175,14 @@ func (d *Database) RemoveDevices(
// database matching the given user ID localpart.
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveAllDevices(
ctx context.Context, localpart string,
ctx context.Context, localpart, exceptDeviceID string,
) (devices []api.Device, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart)
devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
if err != nil {
return err
}
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows {
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
return err
}
return nil

View file

@ -59,7 +59,7 @@ const selectDeviceByIDSQL = "" +
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
const selectDevicesByLocalpartSQL = "" +
"SELECT device_id, display_name FROM device_devices WHERE localpart = $1"
"SELECT device_id, display_name FROM device_devices WHERE localpart = $1 AND device_id != $2"
const updateDeviceNameSQL = "" +
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
@ -68,7 +68,7 @@ const deleteDeviceSQL = "" +
"DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2"
const deleteDevicesByLocalpartSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1"
"DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2"
const deleteDevicesSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)"
@ -182,10 +182,10 @@ func (s *devicesStatements) deleteDevices(
}
func (s *devicesStatements) deleteDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart string,
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
_, err := stmt.ExecContext(ctx, localpart)
_, err := stmt.ExecContext(ctx, localpart, exceptDeviceID)
return err
}
@ -231,10 +231,10 @@ func (s *devicesStatements) selectDeviceByID(
}
func (s *devicesStatements) selectDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart string,
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
) ([]api.Device, error) {
devices := []api.Device{}
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart)
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID)
if err != nil {
return devices, err

View file

@ -72,7 +72,7 @@ func (d *Database) GetDeviceByID(
func (d *Database) GetDevicesByLocalpart(
ctx context.Context, localpart string,
) ([]api.Device, error) {
return d.devices.selectDevicesByLocalpart(ctx, nil, localpart)
return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "")
}
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
@ -179,14 +179,14 @@ func (d *Database) RemoveDevices(
// database matching the given user ID localpart.
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveAllDevices(
ctx context.Context, localpart string,
ctx context.Context, localpart, exceptDeviceID string,
) (devices []api.Device, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart)
devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
if err != nil {
return err
}
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows {
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
return err
}
return nil