This commit is contained in:
Neil Alexander 2020-06-17 17:22:28 +01:00
parent ca8683eb5d
commit 88227a699e
2 changed files with 43 additions and 39 deletions

View file

@ -82,6 +82,7 @@ func Login(
} else if req.Method == http.MethodPost { } else if req.Method == http.MethodPost {
var r passwordRequest var r passwordRequest
var acc *api.Account var acc *api.Account
var errJSON *util.JSONResponse
resErr := httputil.UnmarshalJSONRequest(req, &r) resErr := httputil.UnmarshalJSONRequest(req, &r)
if resErr != nil { if resErr != nil {
return *resErr return *resErr
@ -94,47 +95,16 @@ func Login(
JSON: jsonerror.BadJSON("'user' must be supplied."), JSON: jsonerror.BadJSON("'user' must be supplied."),
} }
} }
acc, errJSON = r.processUsernamePasswordLoginRequest(req, accountDB, cfg, r.Identifier.User)
util.GetLogger(req.Context()).WithField("user", r.Identifier.User).Info("Processing login request") if errJSON != nil {
return *errJSON
localpart, err := userutil.ParseUsernameParam(r.Identifier.User, &cfg.Matrix.ServerName)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername(err.Error()),
}
}
acc, err = accountDB.GetAccountByPassword(req.Context(), localpart, r.Password)
if err != nil {
// Technically we could tell them if the user does not exist by checking if err == sql.ErrNoRows
// but that would leak the existence of the user.
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("username or password was incorrect, or the account does not exist"),
}
} }
default: default:
// TODO: The below behaviour is deprecated but without it Riot iOS won't log in // TODO: The below behaviour is deprecated but without it Riot iOS won't log in
if r.User != "" { if r.User != "" {
util.GetLogger(req.Context()).WithField("user", r.User).Info("Processing login request") acc, errJSON = r.processUsernamePasswordLoginRequest(req, accountDB, cfg, r.User)
if errJSON != nil {
localpart, err := userutil.ParseUsernameParam(r.User, &cfg.Matrix.ServerName) return *errJSON
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername(err.Error()),
}
}
acc, err = accountDB.GetAccountByPassword(req.Context(), localpart, r.Password)
if err != nil {
// Technically we could tell them if the user does not exist by checking if err == sql.ErrNoRows
// but that would leak the existence of the user.
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("username or password was incorrect, or the account does not exist"),
}
} }
} else { } else {
return util.JSONResponse{ return util.JSONResponse{
@ -187,3 +157,32 @@ func getDevice(
) )
return return
} }
func (r *passwordRequest) processUsernamePasswordLoginRequest(
req *http.Request, accountDB accounts.Database,
cfg *config.Dendrite, username string,
) (acc *api.Account, errJSON *util.JSONResponse) {
util.GetLogger(req.Context()).WithField("user", username).Info("Processing login request")
localpart, err := userutil.ParseUsernameParam(username, &cfg.Matrix.ServerName)
if err != nil {
errJSON = &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername(err.Error()),
}
return
}
acc, err = accountDB.GetAccountByPassword(req.Context(), localpart, r.Password)
if err != nil {
// Technically we could tell them if the user does not exist by checking if err == sql.ErrNoRows
// but that would leak the existence of the user.
errJSON = &util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("username or password was incorrect, or the account does not exist"),
}
return
}
return
}

View file

@ -247,6 +247,12 @@ func (r *messagesReq) retrieveEvents() (
// change the way topological positions are defined (as depth isn't the most // change the way topological positions are defined (as depth isn't the most
// reliable way to define it), it would be easier and less troublesome to // reliable way to define it), it would be easier and less troublesome to
// only have to change it in one place, i.e. the database. // only have to change it in one place, i.e. the database.
start, end, err = r.getStartEnd(events)
return clientEvents, start, end, err
}
func (r *messagesReq) getStartEnd(events []gomatrixserverlib.HeaderedEvent) (start, end types.TopologyToken, err error) {
start, err = r.db.EventPositionInTopology( start, err = r.db.EventPositionInTopology(
r.ctx, events[0].EventID(), r.ctx, events[0].EventID(),
) )
@ -275,8 +281,7 @@ func (r *messagesReq) retrieveEvents() (
end.Decrement() end.Decrement()
} }
} }
return
return clientEvents, start, end, err
} }
// handleEmptyEventsSlice handles the case where the initial request to the // handleEmptyEventsSlice handles the case where the initial request to the