mirror of
https://github.com/hoernschen/dendrite.git
synced 2024-12-26 15:08:28 +00:00
Correctly create new device when device_id is passed to /login (#753)
Fixes https://github.com/matrix-org/dendrite/issues/401 Currently when passing a `device_id` parameter to `/login`, which is [supposed](https://matrix.org/docs/spec/client_server/unstable#post-matrix-client-r0-login) to return a device with that ID set, it instead just generates a random `device_id` and hands that back to you. The code was already there to do this correctly, it looks like it had just been broken during some change. Hopefully sytest will prevent this from becoming broken again.
This commit is contained in:
parent
bdd1a87d4d
commit
78032b3f4c
5 changed files with 34 additions and 25 deletions
|
@ -169,6 +169,8 @@ func (s *devicesStatements) selectDeviceByToken(
|
|||
return &dev, err
|
||||
}
|
||||
|
||||
// selectDeviceByID retrieves a device from the database with the given user
|
||||
// localpart and deviceID
|
||||
func (s *devicesStatements) selectDeviceByID(
|
||||
ctx context.Context, localpart, deviceID string,
|
||||
) (*authtypes.Device, error) {
|
||||
|
|
|
@ -84,7 +84,7 @@ func (d *Database) CreateDevice(
|
|||
if deviceID != nil {
|
||||
returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||
var err error
|
||||
// Revoke existing token for this device
|
||||
// Revoke existing tokens for this device
|
||||
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -18,7 +18,6 @@ import (
|
|||
"net/http"
|
||||
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth"
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
|
@ -42,10 +41,12 @@ type flow struct {
|
|||
}
|
||||
|
||||
type passwordRequest struct {
|
||||
User string `json:"user"`
|
||||
Password string `json:"password"`
|
||||
User string `json:"user"`
|
||||
Password string `json:"password"`
|
||||
// Both DeviceID and InitialDisplayName can be omitted, or empty strings ("")
|
||||
// Thus a pointer is needed to differentiate between the two
|
||||
InitialDisplayName *string `json:"initial_device_display_name"`
|
||||
DeviceID string `json:"device_id"`
|
||||
DeviceID *string `json:"device_id"`
|
||||
}
|
||||
|
||||
type loginResponse struct {
|
||||
|
@ -110,7 +111,7 @@ func Login(
|
|||
return httputil.LogThenError(req, err)
|
||||
}
|
||||
|
||||
dev, err := getDevice(req.Context(), r, deviceDB, acc, localpart, token)
|
||||
dev, err := getDevice(req.Context(), r, deviceDB, acc, token)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
|
@ -134,20 +135,16 @@ func Login(
|
|||
}
|
||||
}
|
||||
|
||||
// check if device exists else create one
|
||||
// getDevice returns a new or existing device
|
||||
func getDevice(
|
||||
ctx context.Context,
|
||||
r passwordRequest,
|
||||
deviceDB *devices.Database,
|
||||
acc *authtypes.Account,
|
||||
localpart, token string,
|
||||
token string,
|
||||
) (dev *authtypes.Device, err error) {
|
||||
dev, err = deviceDB.GetDeviceByID(ctx, localpart, r.DeviceID)
|
||||
if err == sql.ErrNoRows {
|
||||
// device doesn't exist, create one
|
||||
dev, err = deviceDB.CreateDevice(
|
||||
ctx, acc.Localpart, nil, token, r.InitialDisplayName,
|
||||
)
|
||||
}
|
||||
dev, err = deviceDB.CreateDevice(
|
||||
ctx, acc.Localpart, r.DeviceID, token, r.InitialDisplayName,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -121,7 +121,10 @@ type registerRequest struct {
|
|||
// user-interactive auth params
|
||||
Auth authDict `json:"auth"`
|
||||
|
||||
// Both DeviceID and InitialDisplayName can be omitted, or empty strings ("")
|
||||
// Thus a pointer is needed to differentiate between the two
|
||||
InitialDisplayName *string `json:"initial_device_display_name"`
|
||||
DeviceID *string `json:"device_id"`
|
||||
|
||||
// Prevent this user from logging in
|
||||
InhibitLogin common.WeakBoolean `json:"inhibit_login"`
|
||||
|
@ -626,7 +629,7 @@ func handleApplicationServiceRegistration(
|
|||
// application service registration is entirely separate.
|
||||
return completeRegistration(
|
||||
req.Context(), accountDB, deviceDB, r.Username, "", appserviceID,
|
||||
r.InhibitLogin, r.InitialDisplayName,
|
||||
r.InhibitLogin, r.InitialDisplayName, r.DeviceID,
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -646,7 +649,7 @@ func checkAndCompleteFlow(
|
|||
// This flow was completed, registration can continue
|
||||
return completeRegistration(
|
||||
req.Context(), accountDB, deviceDB, r.Username, r.Password, "",
|
||||
r.InhibitLogin, r.InitialDisplayName,
|
||||
r.InhibitLogin, r.InitialDisplayName, r.DeviceID,
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -697,10 +700,10 @@ func LegacyRegister(
|
|||
return util.MessageResponse(http.StatusForbidden, "HMAC incorrect")
|
||||
}
|
||||
|
||||
return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, "", false, nil)
|
||||
return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, "", false, nil, nil)
|
||||
case authtypes.LoginTypeDummy:
|
||||
// there is nothing to do
|
||||
return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, "", false, nil)
|
||||
return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, "", false, nil, nil)
|
||||
default:
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusNotImplemented,
|
||||
|
@ -738,13 +741,19 @@ func parseAndValidateLegacyLogin(req *http.Request, r *legacyRegisterRequest) *u
|
|||
return nil
|
||||
}
|
||||
|
||||
// completeRegistration runs some rudimentary checks against the submitted
|
||||
// input, then if successful creates an account and a newly associated device
|
||||
// We pass in each individual part of the request here instead of just passing a
|
||||
// registerRequest, as this function serves requests encoded as both
|
||||
// registerRequests and legacyRegisterRequests, which share some attributes but
|
||||
// not all
|
||||
func completeRegistration(
|
||||
ctx context.Context,
|
||||
accountDB *accounts.Database,
|
||||
deviceDB *devices.Database,
|
||||
username, password, appserviceID string,
|
||||
inhibitLogin common.WeakBoolean,
|
||||
displayName *string,
|
||||
displayName, deviceID *string,
|
||||
) util.JSONResponse {
|
||||
if username == "" {
|
||||
return util.JSONResponse{
|
||||
|
@ -773,6 +782,9 @@ func completeRegistration(
|
|||
}
|
||||
}
|
||||
|
||||
// Increment prometheus counter for created users
|
||||
amtRegUsers.Inc()
|
||||
|
||||
// Check whether inhibit_login option is set. If so, don't create an access
|
||||
// token or a device for this user
|
||||
if inhibitLogin {
|
||||
|
@ -793,8 +805,7 @@ func completeRegistration(
|
|||
}
|
||||
}
|
||||
|
||||
// TODO: Use the device ID in the request.
|
||||
dev, err := deviceDB.CreateDevice(ctx, username, nil, token, displayName)
|
||||
dev, err := deviceDB.CreateDevice(ctx, username, deviceID, token, displayName)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
|
@ -802,9 +813,6 @@ func completeRegistration(
|
|||
}
|
||||
}
|
||||
|
||||
// Increment prometheus counter for created users
|
||||
amtRegUsers.Inc()
|
||||
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusOK,
|
||||
JSON: registerResponse{
|
||||
|
|
2
testfile
2
testfile
|
@ -149,3 +149,5 @@ Typing events appear in incremental sync
|
|||
Typing events appear in gapped sync
|
||||
Inbound federation of state requires event_id as a mandatory paramater
|
||||
Inbound federation of state_ids requires event_id as a mandatory paramater
|
||||
POST /register returns the same device_id as that in the request
|
||||
POST /login returns the same device_id as that in the request
|
||||
|
|
Loading…
Reference in a new issue