Remember parameters on registration (#2225)

* Remember parameters for sessions
Cleanup sessions on successfully registering or after a while

* Add flakey test

* Update to use time.AfterFunc, add more tests

* Try to drain the channel, if possible
This commit is contained in:
S7evinK 2022-02-25 14:33:02 +01:00 committed by GitHub
parent 4c07374c42
commit cf27e26712
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 149 additions and 29 deletions

View file

@ -162,7 +162,7 @@ func AuthFallback(
} }
// Success. Add recaptcha as a completed login flow // Success. Add recaptcha as a completed login flow
AddCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha) sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha)
serveSuccess() serveSuccess()
return nil return nil

View file

@ -70,7 +70,7 @@ func UploadCrossSigningDeviceKeys(
if _, authErr := typePassword.Login(req.Context(), &uploadReq.Auth.PasswordRequest); authErr != nil { if _, authErr := typePassword.Login(req.Context(), &uploadReq.Auth.PasswordRequest); authErr != nil {
return *authErr return *authErr
} }
AddCompletedSessionStage(sessionID, authtypes.LoginTypePassword) sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
uploadReq.UserID = device.UserID uploadReq.UserID = device.UserID
keyserverAPI.PerformUploadDeviceKeys(req.Context(), &uploadReq.PerformUploadDeviceKeysRequest, uploadRes) keyserverAPI.PerformUploadDeviceKeys(req.Context(), &uploadReq.PerformUploadDeviceKeysRequest, uploadRes)

View file

@ -74,7 +74,7 @@ func Password(
if _, authErr := typePassword.Login(req.Context(), &r.Auth.PasswordRequest); authErr != nil { if _, authErr := typePassword.Login(req.Context(), &r.Auth.PasswordRequest); authErr != nil {
return *authErr return *authErr
} }
AddCompletedSessionStage(sessionID, authtypes.LoginTypePassword) sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
// Check the new password strength. // Check the new password strength.
if resErr = validatePassword(r.NewPassword); resErr != nil { if resErr = validatePassword(r.NewPassword); resErr != nil {

View file

@ -72,14 +72,19 @@ func init() {
// sessionsDict keeps track of completed auth stages for each session. // sessionsDict keeps track of completed auth stages for each session.
// It shouldn't be passed by value because it contains a mutex. // It shouldn't be passed by value because it contains a mutex.
type sessionsDict struct { type sessionsDict struct {
sync.Mutex sync.RWMutex
sessions map[string][]authtypes.LoginType sessions map[string][]authtypes.LoginType
params map[string]registerRequest
timer map[string]*time.Timer
} }
// GetCompletedStages returns the completed stages for a session. // defaultTimeout is the timeout used to clean up sessions
func (d *sessionsDict) GetCompletedStages(sessionID string) []authtypes.LoginType { const defaultTimeOut = time.Minute * 5
d.Lock()
defer d.Unlock() // getCompletedStages returns the completed stages for a session.
func (d *sessionsDict) getCompletedStages(sessionID string) []authtypes.LoginType {
d.RLock()
defer d.RUnlock()
if completedStages, ok := d.sessions[sessionID]; ok { if completedStages, ok := d.sessions[sessionID]; ok {
return completedStages return completedStages
@ -88,28 +93,79 @@ func (d *sessionsDict) GetCompletedStages(sessionID string) []authtypes.LoginTyp
return make([]authtypes.LoginType, 0) return make([]authtypes.LoginType, 0)
} }
// addParams adds a registerRequest to a sessionID and starts a timer to delete that registerRequest
func (d *sessionsDict) addParams(sessionID string, r registerRequest) {
d.startTimer(defaultTimeOut, sessionID)
d.Lock()
defer d.Unlock()
d.params[sessionID] = r
}
func (d *sessionsDict) getParams(sessionID string) (registerRequest, bool) {
d.RLock()
defer d.RUnlock()
r, ok := d.params[sessionID]
return r, ok
}
// deleteSession cleans up a given session, either because the registration completed
// successfully, or because a given timeout (default: 5min) was reached.
func (d *sessionsDict) deleteSession(sessionID string) {
d.Lock()
defer d.Unlock()
delete(d.params, sessionID)
delete(d.sessions, sessionID)
// stop the timer, e.g. because the registration was completed
if t, ok := d.timer[sessionID]; ok {
if !t.Stop() {
select {
case <-t.C:
default:
}
}
delete(d.timer, sessionID)
}
}
func newSessionsDict() *sessionsDict { func newSessionsDict() *sessionsDict {
return &sessionsDict{ return &sessionsDict{
sessions: make(map[string][]authtypes.LoginType), sessions: make(map[string][]authtypes.LoginType),
params: make(map[string]registerRequest),
timer: make(map[string]*time.Timer),
} }
} }
// AddCompletedSessionStage records that a session has completed an auth stage. func (d *sessionsDict) startTimer(duration time.Duration, sessionID string) {
func AddCompletedSessionStage(sessionID string, stage authtypes.LoginType) { d.Lock()
sessions.Lock() defer d.Unlock()
defer sessions.Unlock() t, ok := d.timer[sessionID]
if ok {
if !t.Stop() {
<-t.C
}
t.Reset(duration)
return
}
d.timer[sessionID] = time.AfterFunc(duration, func() {
d.deleteSession(sessionID)
})
}
for _, completedStage := range sessions.sessions[sessionID] { // addCompletedSessionStage records that a session has completed an auth stage
// also starts a timer to delete the session once done.
func (d *sessionsDict) addCompletedSessionStage(sessionID string, stage authtypes.LoginType) {
d.startTimer(defaultTimeOut, sessionID)
d.Lock()
defer d.Unlock()
for _, completedStage := range d.sessions[sessionID] {
if completedStage == stage { if completedStage == stage {
return return
} }
} }
sessions.sessions[sessionID] = append(sessions.sessions[sessionID], stage) d.sessions[sessionID] = append(sessions.sessions[sessionID], stage)
} }
var ( var (
// TODO: Remove old sessions. Need to do so on a session-specific timeout.
// sessions stores the completed flow stages for all sessions. Referenced using their sessionID.
sessions = newSessionsDict() sessions = newSessionsDict()
validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`) validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`)
) )
@ -167,7 +223,7 @@ func newUserInteractiveResponse(
params map[string]interface{}, params map[string]interface{},
) userInteractiveResponse { ) userInteractiveResponse {
return userInteractiveResponse{ return userInteractiveResponse{
fs, sessions.GetCompletedStages(sessionID), params, sessionID, fs, sessions.getCompletedStages(sessionID), params, sessionID,
} }
} }
@ -645,12 +701,12 @@ func handleRegistrationFlow(
} }
// Add Recaptcha to the list of completed registration stages // Add Recaptcha to the list of completed registration stages
AddCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha) sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha)
case authtypes.LoginTypeDummy: case authtypes.LoginTypeDummy:
// there is nothing to do // there is nothing to do
// Add Dummy to the list of completed registration stages // Add Dummy to the list of completed registration stages
AddCompletedSessionStage(sessionID, authtypes.LoginTypeDummy) sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeDummy)
case "": case "":
// An empty auth type means that we want to fetch the available // An empty auth type means that we want to fetch the available
@ -666,7 +722,7 @@ func handleRegistrationFlow(
// Check if the user's registration flow has been completed successfully // Check if the user's registration flow has been completed successfully
// A response with current registration flow and remaining available methods // A response with current registration flow and remaining available methods
// will be returned if a flow has not been successfully completed yet // will be returned if a flow has not been successfully completed yet
return checkAndCompleteFlow(sessions.GetCompletedStages(sessionID), return checkAndCompleteFlow(sessions.getCompletedStages(sessionID),
req, r, sessionID, cfg, userAPI) req, r, sessionID, cfg, userAPI)
} }
@ -708,7 +764,7 @@ func handleApplicationServiceRegistration(
// Don't need to worry about appending to registration stages as // Don't need to worry about appending to registration stages as
// application service registration is entirely separate. // application service registration is entirely separate.
return completeRegistration( return completeRegistration(
req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(), req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(), r.Auth.Session,
r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeAppService, r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeAppService,
) )
} }
@ -727,11 +783,11 @@ func checkAndCompleteFlow(
if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) { if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) {
// This flow was completed, registration can continue // This flow was completed, registration can continue
return completeRegistration( return completeRegistration(
req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(), req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(), sessionID,
r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeUser, r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeUser,
) )
} }
sessions.addParams(sessionID, r)
// There are still more stages to complete. // There are still more stages to complete.
// Return the flows and those that have been completed. // Return the flows and those that have been completed.
return util.JSONResponse{ return util.JSONResponse{
@ -750,11 +806,25 @@ func checkAndCompleteFlow(
func completeRegistration( func completeRegistration(
ctx context.Context, ctx context.Context,
userAPI userapi.UserInternalAPI, userAPI userapi.UserInternalAPI,
username, password, appserviceID, ipAddr, userAgent string, username, password, appserviceID, ipAddr, userAgent, sessionID string,
inhibitLogin eventutil.WeakBoolean, inhibitLogin eventutil.WeakBoolean,
displayName, deviceID *string, displayName, deviceID *string,
accType userapi.AccountType, accType userapi.AccountType,
) util.JSONResponse { ) util.JSONResponse {
var registrationOK bool
defer func() {
if registrationOK {
sessions.deleteSession(sessionID)
}
}()
if data, ok := sessions.getParams(sessionID); ok {
username = data.Username
password = data.Password
deviceID = data.DeviceID
displayName = data.InitialDisplayName
inhibitLogin = data.InhibitLogin
}
if username == "" { if username == "" {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
@ -795,6 +865,7 @@ func completeRegistration(
// Check whether inhibit_login option is set. If so, don't create an access // Check whether inhibit_login option is set. If so, don't create an access
// token or a device for this user // token or a device for this user
if inhibitLogin { if inhibitLogin {
registrationOK = true
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: registerResponse{ JSON: registerResponse{
@ -828,6 +899,7 @@ func completeRegistration(
} }
} }
registrationOK = true
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: registerResponse{ JSON: registerResponse{
@ -976,5 +1048,5 @@ func handleSharedSecretRegistration(userAPI userapi.UserInternalAPI, sr *SharedS
if ssrr.Admin { if ssrr.Admin {
accType = userapi.AccountTypeAdmin accType = userapi.AccountTypeAdmin
} }
return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), false, &ssrr.User, &deviceID, accType) return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", false, &ssrr.User, &deviceID, accType)
} }

View file

@ -17,6 +17,7 @@ package routing
import ( import (
"regexp" "regexp"
"testing" "testing"
"time"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
@ -140,7 +141,7 @@ func TestFlowCheckingExtraneousIncorrectInput(t *testing.T) {
func TestEmptyCompletedFlows(t *testing.T) { func TestEmptyCompletedFlows(t *testing.T) {
fakeEmptySessions := newSessionsDict() fakeEmptySessions := newSessionsDict()
fakeSessionID := "aRandomSessionIDWhichDoesNotExist" fakeSessionID := "aRandomSessionIDWhichDoesNotExist"
ret := fakeEmptySessions.GetCompletedStages(fakeSessionID) ret := fakeEmptySessions.getCompletedStages(fakeSessionID)
// check for [] // check for []
if ret == nil || len(ret) != 0 { if ret == nil || len(ret) != 0 {
@ -208,3 +209,45 @@ func TestValidationOfApplicationServices(t *testing.T) {
t.Errorf("user_id should not have been valid: @_something_else:localhost") t.Errorf("user_id should not have been valid: @_something_else:localhost")
} }
} }
func TestSessionCleanUp(t *testing.T) {
s := newSessionsDict()
t.Run("session is cleaned up after a while", func(t *testing.T) {
t.Parallel()
dummySession := "helloWorld"
// manually added, as s.addParams() would start the timer with the default timeout
s.params[dummySession] = registerRequest{Username: "Testing"}
s.startTimer(time.Millisecond, dummySession)
time.Sleep(time.Millisecond * 2)
if data, ok := s.getParams(dummySession); ok {
t.Errorf("expected session to be deleted: %+v", data)
}
})
t.Run("session is deleted, once the registration completed", func(t *testing.T) {
t.Parallel()
dummySession := "helloWorld2"
s.startTimer(time.Minute, dummySession)
s.deleteSession(dummySession)
if data, ok := s.getParams(dummySession); ok {
t.Errorf("expected session to be deleted: %+v", data)
}
})
t.Run("session timer is restarted after second call", func(t *testing.T) {
t.Parallel()
dummySession := "helloWorld3"
// the following will start a timer with the default timeout of 5min
s.addParams(dummySession, registerRequest{Username: "Testing"})
s.addCompletedSessionStage(dummySession, authtypes.LoginTypeRecaptcha)
s.addCompletedSessionStage(dummySession, authtypes.LoginTypeDummy)
s.getCompletedStages(dummySession)
// reset the timer with a lower timeout
s.startTimer(time.Millisecond, dummySession)
time.Sleep(time.Millisecond * 2)
if data, ok := s.getParams(dummySession); ok {
t.Errorf("expected session to be deleted: %+v", data)
}
})
}

View file

@ -24,6 +24,7 @@ Local device key changes get to remote servers with correct prev_id
# Flakey # Flakey
Local device key changes appear in /keys/changes Local device key changes appear in /keys/changes
/context/ with lazy_load_members filter works
# we don't support groups # we don't support groups
Remove group category Remove group category

View file

@ -601,3 +601,7 @@ Can query remote device keys using POST after notification
Device deletion propagates over federation Device deletion propagates over federation
Get left notifs in sync and /keys/changes when other user leaves Get left notifs in sync and /keys/changes when other user leaves
Remote banned user is kicked and may not rejoin until unbanned Remote banned user is kicked and may not rejoin until unbanned
registration remembers parameters
registration accepts non-ascii passwords
registration with inhibit_login inhibits login