Add clientapi tests (#2916)

This PR
- adds several tests for the clientapi, mostly around `/register` and
auth fallback.
- removes the now deprecated `homeserver` field from responses to
`/register` and `/login`
- slightly refactors auth fallback handling
This commit is contained in:
Till 2022-12-23 14:11:11 +01:00 committed by GitHub
parent f47515e38b
commit f762ce1050
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 838 additions and 220 deletions

View file

@ -137,7 +137,7 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap
request := struct {
Password string `json:"password"`
}{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
if err = json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.Unknown("Failed to decode request body: " + err.Error()),
@ -150,8 +150,8 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap
}
}
if resErr := internal.ValidatePassword(request.Password); resErr != nil {
return *resErr
if err = internal.ValidatePassword(request.Password); err != nil {
return *internal.PasswordResponse(err)
}
updateReq := &userapi.PerformPasswordUpdateRequest{

View file

@ -15,11 +15,11 @@
package routing
import (
"fmt"
"html/template"
"net/http"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/util"
)
@ -101,14 +101,28 @@ func serveTemplate(w http.ResponseWriter, templateHTML string, data map[string]s
func AuthFallback(
w http.ResponseWriter, req *http.Request, authType string,
cfg *config.ClientAPI,
) *util.JSONResponse {
sessionID := req.URL.Query().Get("session")
) {
// We currently only support "m.login.recaptcha", so fail early if that's not requested
if authType == authtypes.LoginTypeRecaptcha {
if !cfg.RecaptchaEnabled {
writeHTTPMessage(w, req,
"Recaptcha login is disabled on this Homeserver",
http.StatusBadRequest,
)
return
}
} else {
writeHTTPMessage(w, req, fmt.Sprintf("Unknown authtype %q", authType), http.StatusNotImplemented)
return
}
sessionID := req.URL.Query().Get("session")
if sessionID == "" {
return writeHTTPMessage(w, req,
writeHTTPMessage(w, req,
"Session ID not provided",
http.StatusBadRequest,
)
return
}
serveRecaptcha := func() {
@ -130,70 +144,44 @@ func AuthFallback(
if req.Method == http.MethodGet {
// Handle Recaptcha
if authType == authtypes.LoginTypeRecaptcha {
if err := checkRecaptchaEnabled(cfg, w, req); err != nil {
return err
}
serveRecaptcha()
return nil
}
return &util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.NotFound("Unknown auth stage type"),
}
serveRecaptcha()
return
} else if req.Method == http.MethodPost {
// Handle Recaptcha
if authType == authtypes.LoginTypeRecaptcha {
if err := checkRecaptchaEnabled(cfg, w, req); err != nil {
return err
}
clientIP := req.RemoteAddr
err := req.ParseForm()
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("req.ParseForm failed")
res := jsonerror.InternalServerError()
return &res
}
response := req.Form.Get(cfg.RecaptchaFormField)
if err := validateRecaptcha(cfg, response, clientIP); err != nil {
util.GetLogger(req.Context()).Error(err)
return err
}
// Success. Add recaptcha as a completed login flow
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha)
serveSuccess()
return nil
clientIP := req.RemoteAddr
err := req.ParseForm()
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("req.ParseForm failed")
w.WriteHeader(http.StatusBadRequest)
serveRecaptcha()
return
}
return &util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.NotFound("Unknown auth stage type"),
response := req.Form.Get(cfg.RecaptchaFormField)
err = validateRecaptcha(cfg, response, clientIP)
switch err {
case ErrMissingResponse:
w.WriteHeader(http.StatusBadRequest)
serveRecaptcha() // serve the initial page again, instead of nothing
return
case ErrInvalidCaptcha:
w.WriteHeader(http.StatusUnauthorized)
serveRecaptcha()
return
case nil:
default: // something else failed
util.GetLogger(req.Context()).WithError(err).Error("failed to validate recaptcha")
serveRecaptcha()
return
}
}
return &util.JSONResponse{
Code: http.StatusMethodNotAllowed,
JSON: jsonerror.NotFound("Bad method"),
}
}
// checkRecaptchaEnabled creates an error response if recaptcha is not usable on homeserver.
func checkRecaptchaEnabled(
cfg *config.ClientAPI,
w http.ResponseWriter,
req *http.Request,
) *util.JSONResponse {
if !cfg.RecaptchaEnabled {
return writeHTTPMessage(w, req,
"Recaptcha login is disabled on this Homeserver",
http.StatusBadRequest,
)
// Success. Add recaptcha as a completed login flow
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha)
serveSuccess()
return
}
return nil
writeHTTPMessage(w, req, "Bad method", http.StatusMethodNotAllowed)
}
// writeHTTPMessage writes the given header and message to the HTTP response writer.
@ -201,13 +189,10 @@ func checkRecaptchaEnabled(
func writeHTTPMessage(
w http.ResponseWriter, req *http.Request,
message string, header int,
) *util.JSONResponse {
) {
w.WriteHeader(header)
_, err := w.Write([]byte(message))
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("w.Write failed")
res := jsonerror.InternalServerError()
return &res
}
return nil
}

View file

@ -0,0 +1,149 @@
package routing
import (
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test/testrig"
)
func Test_AuthFallback(t *testing.T) {
base, _, _ := testrig.Base(nil)
defer base.Close()
for _, useHCaptcha := range []bool{false, true} {
for _, recaptchaEnabled := range []bool{false, true} {
for _, wantErr := range []bool{false, true} {
t.Run(fmt.Sprintf("useHCaptcha(%v) - recaptchaEnabled(%v) - wantErr(%v)", useHCaptcha, recaptchaEnabled, wantErr), func(t *testing.T) {
// Set the defaults for each test
base.Cfg.ClientAPI.Defaults(config.DefaultOpts{Generate: true, Monolithic: true})
base.Cfg.ClientAPI.RecaptchaEnabled = recaptchaEnabled
base.Cfg.ClientAPI.RecaptchaPublicKey = "pub"
base.Cfg.ClientAPI.RecaptchaPrivateKey = "priv"
if useHCaptcha {
base.Cfg.ClientAPI.RecaptchaSiteVerifyAPI = "https://hcaptcha.com/siteverify"
base.Cfg.ClientAPI.RecaptchaApiJsUrl = "https://js.hcaptcha.com/1/api.js"
base.Cfg.ClientAPI.RecaptchaFormField = "h-captcha-response"
base.Cfg.ClientAPI.RecaptchaSitekeyClass = "h-captcha"
}
cfgErrs := &config.ConfigErrors{}
base.Cfg.ClientAPI.Verify(cfgErrs, true)
if len(*cfgErrs) > 0 {
t.Fatalf("(hCaptcha=%v) unexpected config errors: %s", useHCaptcha, cfgErrs.Error())
}
req := httptest.NewRequest(http.MethodGet, "/?session=1337", nil)
rec := httptest.NewRecorder()
AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI)
if !recaptchaEnabled {
if rec.Code != http.StatusBadRequest {
t.Fatalf("unexpected response code: %d, want %d", rec.Code, http.StatusBadRequest)
}
if rec.Body.String() != "Recaptcha login is disabled on this Homeserver" {
t.Fatalf("unexpected response body: %s", rec.Body.String())
}
} else {
if !strings.Contains(rec.Body.String(), base.Cfg.ClientAPI.RecaptchaSitekeyClass) {
t.Fatalf("body does not contain %s: %s", base.Cfg.ClientAPI.RecaptchaSitekeyClass, rec.Body.String())
}
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if wantErr {
_, _ = w.Write([]byte(`{"success":false}`))
return
}
_, _ = w.Write([]byte(`{"success":true}`))
}))
defer srv.Close() // nolint: errcheck
base.Cfg.ClientAPI.RecaptchaSiteVerifyAPI = srv.URL
// check the result after sending the captcha
req = httptest.NewRequest(http.MethodPost, "/?session=1337", nil)
req.Form = url.Values{}
req.Form.Add(base.Cfg.ClientAPI.RecaptchaFormField, "someRandomValue")
rec = httptest.NewRecorder()
AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI)
if recaptchaEnabled {
if !wantErr {
if rec.Code != http.StatusOK {
t.Fatalf("unexpected response code: %d, want %d", rec.Code, http.StatusOK)
}
if rec.Body.String() != successTemplate {
t.Fatalf("unexpected response: %s, want %s", rec.Body.String(), successTemplate)
}
} else {
if rec.Code != http.StatusUnauthorized {
t.Fatalf("unexpected response code: %d, want %d", rec.Code, http.StatusUnauthorized)
}
wantString := "Authentication"
if !strings.Contains(rec.Body.String(), wantString) {
t.Fatalf("expected response to contain '%s', but didn't: %s", wantString, rec.Body.String())
}
}
} else {
if rec.Code != http.StatusBadRequest {
t.Fatalf("unexpected response code: %d, want %d", rec.Code, http.StatusBadRequest)
}
if rec.Body.String() != "Recaptcha login is disabled on this Homeserver" {
t.Fatalf("unexpected response: %s, want %s", rec.Body.String(), "successTemplate")
}
}
})
}
}
}
t.Run("unknown fallbacks are handled correctly", func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "/?session=1337", nil)
rec := httptest.NewRecorder()
AuthFallback(rec, req, "DoesNotExist", &base.Cfg.ClientAPI)
if rec.Code != http.StatusNotImplemented {
t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusNotImplemented)
}
})
t.Run("unknown methods are handled correctly", func(t *testing.T) {
req := httptest.NewRequest(http.MethodDelete, "/?session=1337", nil)
rec := httptest.NewRecorder()
AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI)
if rec.Code != http.StatusMethodNotAllowed {
t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusMethodNotAllowed)
}
})
t.Run("missing session parameter is handled correctly", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI)
if rec.Code != http.StatusBadRequest {
t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusBadRequest)
}
})
t.Run("missing session parameter is handled correctly", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI)
if rec.Code != http.StatusBadRequest {
t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusBadRequest)
}
})
t.Run("missing 'response' is handled correctly", func(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "/?session=1337", nil)
rec := httptest.NewRecorder()
AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI)
if rec.Code != http.StatusBadRequest {
t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusBadRequest)
}
})
}

View file

@ -23,15 +23,13 @@ import (
"github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
type loginResponse struct {
UserID string `json:"user_id"`
AccessToken string `json:"access_token"`
HomeServer gomatrixserverlib.ServerName `json:"home_server"`
DeviceID string `json:"device_id"`
UserID string `json:"user_id"`
AccessToken string `json:"access_token"`
DeviceID string `json:"device_id"`
}
type flows struct {
@ -116,7 +114,6 @@ func completeAuth(
JSON: loginResponse{
UserID: performRes.Device.UserID,
AccessToken: performRes.Device.AccessToken,
HomeServer: serverName,
DeviceID: performRes.Device.ID,
},
}

View file

@ -82,8 +82,8 @@ func Password(
sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
// Check the new password strength.
if resErr = internal.ValidatePassword(r.NewPassword); resErr != nil {
return *resErr
if err := internal.ValidatePassword(r.NewPassword); err != nil {
return *internal.PasswordResponse(err)
}
// Get the local part.

View file

@ -18,12 +18,12 @@ package routing
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"regexp"
"sort"
"strconv"
"strings"
@ -60,10 +60,7 @@ var (
)
)
const (
maxUsernameLength = 254 // http://matrix.org/speculator/spec/HEAD/intro.html#user-identifiers TODO account for domain
sessionIDLength = 24
)
const sessionIDLength = 24
// sessionsDict keeps track of completed auth stages for each session.
// It shouldn't be passed by value because it contains a mutex.
@ -198,8 +195,7 @@ func (d *sessionsDict) getDeviceToDelete(sessionID string) (string, bool) {
}
var (
sessions = newSessionsDict()
validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`)
sessions = newSessionsDict()
)
// registerRequest represents the submitted registration request.
@ -262,10 +258,9 @@ func newUserInteractiveResponse(
// http://matrix.org/speculator/spec/HEAD/client_server/unstable.html#post-matrix-client-unstable-register
type registerResponse struct {
UserID string `json:"user_id"`
AccessToken string `json:"access_token,omitempty"`
HomeServer gomatrixserverlib.ServerName `json:"home_server"`
DeviceID string `json:"device_id,omitempty"`
UserID string `json:"user_id"`
AccessToken string `json:"access_token,omitempty"`
DeviceID string `json:"device_id,omitempty"`
}
// recaptchaResponse represents the HTTP response from a Google Recaptcha server
@ -276,66 +271,28 @@ type recaptchaResponse struct {
ErrorCodes []int `json:"error-codes"`
}
// validateUsername returns an error response if the username is invalid
func validateUsername(localpart string, domain gomatrixserverlib.ServerName) *util.JSONResponse {
// https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161
if id := fmt.Sprintf("@%s:%s", localpart, domain); len(id) > maxUsernameLength {
return &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON(fmt.Sprintf("%q exceeds the maximum length of %d characters", id, maxUsernameLength)),
}
} else if !validUsernameRegex.MatchString(localpart) {
return &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername("Username can only contain characters a-z, 0-9, or '_-./='"),
}
} else if localpart[0] == '_' { // Regex checks its not a zero length string
return &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername("Username cannot start with a '_'"),
}
}
return nil
}
// validateApplicationServiceUsername returns an error response if the username is invalid for an application service
func validateApplicationServiceUsername(localpart string, domain gomatrixserverlib.ServerName) *util.JSONResponse {
if id := fmt.Sprintf("@%s:%s", localpart, domain); len(id) > maxUsernameLength {
return &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON(fmt.Sprintf("%q exceeds the maximum length of %d characters", id, maxUsernameLength)),
}
} else if !validUsernameRegex.MatchString(localpart) {
return &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername("Username can only contain characters a-z, 0-9, or '_-./='"),
}
}
return nil
}
var (
ErrInvalidCaptcha = errors.New("invalid captcha response")
ErrMissingResponse = errors.New("captcha response is required")
ErrCaptchaDisabled = errors.New("captcha registration is disabled")
)
// validateRecaptcha returns an error response if the captcha response is invalid
func validateRecaptcha(
cfg *config.ClientAPI,
response string,
clientip string,
) *util.JSONResponse {
) error {
ip, _, _ := net.SplitHostPort(clientip)
if !cfg.RecaptchaEnabled {
return &util.JSONResponse{
Code: http.StatusConflict,
JSON: jsonerror.Unknown("Captcha registration is disabled"),
}
return ErrCaptchaDisabled
}
if response == "" {
return &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("Captcha response is required"),
}
return ErrMissingResponse
}
// Make a POST request to Google's API to check the captcha response
// Make a POST request to the captcha provider API to check the captcha response
resp, err := http.PostForm(cfg.RecaptchaSiteVerifyAPI,
url.Values{
"secret": {cfg.RecaptchaPrivateKey},
@ -345,10 +302,7 @@ func validateRecaptcha(
)
if err != nil {
return &util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: jsonerror.BadJSON("Error in requesting validation of captcha response"),
}
return err
}
// Close the request once we're finishing reading from it
@ -358,25 +312,16 @@ func validateRecaptcha(
var r recaptchaResponse
body, err := io.ReadAll(resp.Body)
if err != nil {
return &util.JSONResponse{
Code: http.StatusGatewayTimeout,
JSON: jsonerror.Unknown("Error in contacting captcha server" + err.Error()),
}
return err
}
err = json.Unmarshal(body, &r)
if err != nil {
return &util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: jsonerror.BadJSON("Error in unmarshaling captcha server's response: " + err.Error()),
}
return err
}
// Check that we received a "success"
if !r.Success {
return &util.JSONResponse{
Code: http.StatusUnauthorized,
JSON: jsonerror.BadJSON("Invalid captcha response. Please try again."),
}
return ErrInvalidCaptcha
}
return nil
}
@ -508,8 +453,8 @@ func validateApplicationService(
}
// Check username application service is trying to register is valid
if err := validateApplicationServiceUsername(username, cfg.Matrix.ServerName); err != nil {
return "", err
if err := internal.ValidateApplicationServiceUsername(username, cfg.Matrix.ServerName); err != nil {
return "", internal.UsernameResponse(err)
}
// No errors, registration valid
@ -564,15 +509,12 @@ func Register(
if resErr := httputil.UnmarshalJSON(reqBody, &r); resErr != nil {
return *resErr
}
if l, d, err := cfg.Matrix.SplitLocalID('@', r.Username); err == nil {
r.Username, r.ServerName = l, d
}
if req.URL.Query().Get("kind") == "guest" {
return handleGuestRegistration(req, r, cfg, userAPI)
}
// Don't allow numeric usernames less than MAX_INT64.
if _, err := strconv.ParseInt(r.Username, 10, 64); err == nil {
if _, err = strconv.ParseInt(r.Username, 10, 64); err == nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername("Numeric user IDs are reserved"),
@ -584,7 +526,7 @@ func Register(
ServerName: r.ServerName,
}
nres := &userapi.QueryNumericLocalpartResponse{}
if err := userAPI.QueryNumericLocalpart(req.Context(), nreq, nres); err != nil {
if err = userAPI.QueryNumericLocalpart(req.Context(), nreq, nres); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryNumericLocalpart failed")
return jsonerror.InternalServerError()
}
@ -601,8 +543,8 @@ func Register(
case r.Type == authtypes.LoginTypeApplicationService && accessTokenErr == nil:
// Spec-compliant case (the access_token is specified and the login type
// is correctly set, so it's an appservice registration)
if resErr := validateApplicationServiceUsername(r.Username, r.ServerName); resErr != nil {
return *resErr
if err = internal.ValidateApplicationServiceUsername(r.Username, r.ServerName); err != nil {
return *internal.UsernameResponse(err)
}
case accessTokenErr == nil:
// Non-spec-compliant case (the access_token is specified but the login
@ -614,12 +556,12 @@ func Register(
default:
// Spec-compliant case (neither the access_token nor the login type are
// specified, so it's a normal user registration)
if resErr := validateUsername(r.Username, r.ServerName); resErr != nil {
return *resErr
if err = internal.ValidateUsername(r.Username, r.ServerName); err != nil {
return *internal.UsernameResponse(err)
}
}
if resErr := internal.ValidatePassword(r.Password); resErr != nil {
return *resErr
if err = internal.ValidatePassword(r.Password); err != nil {
return *internal.PasswordResponse(err)
}
logger := util.GetLogger(req.Context())
@ -697,7 +639,6 @@ func handleGuestRegistration(
JSON: registerResponse{
UserID: devRes.Device.UserID,
AccessToken: devRes.Device.AccessToken,
HomeServer: res.Account.ServerName,
DeviceID: devRes.Device.ID,
},
}
@ -761,9 +702,18 @@ func handleRegistrationFlow(
switch r.Auth.Type {
case authtypes.LoginTypeRecaptcha:
// Check given captcha response
resErr := validateRecaptcha(cfg, r.Auth.Response, req.RemoteAddr)
if resErr != nil {
return *resErr
err := validateRecaptcha(cfg, r.Auth.Response, req.RemoteAddr)
switch err {
case ErrCaptchaDisabled:
return util.JSONResponse{Code: http.StatusForbidden, JSON: jsonerror.Unknown(err.Error())}
case ErrMissingResponse:
return util.JSONResponse{Code: http.StatusBadRequest, JSON: jsonerror.BadJSON(err.Error())}
case ErrInvalidCaptcha:
return util.JSONResponse{Code: http.StatusUnauthorized, JSON: jsonerror.BadJSON(err.Error())}
case nil:
default:
util.GetLogger(req.Context()).WithError(err).Error("failed to validate recaptcha")
return util.JSONResponse{Code: http.StatusInternalServerError, JSON: jsonerror.InternalServerError()}
}
// Add Recaptcha to the list of completed registration stages
@ -924,8 +874,7 @@ func completeRegistration(
return util.JSONResponse{
Code: http.StatusOK,
JSON: registerResponse{
UserID: userutil.MakeUserID(username, accRes.Account.ServerName),
HomeServer: accRes.Account.ServerName,
UserID: userutil.MakeUserID(username, accRes.Account.ServerName),
},
}
}
@ -958,7 +907,6 @@ func completeRegistration(
result := registerResponse{
UserID: devRes.Device.UserID,
AccessToken: devRes.Device.AccessToken,
HomeServer: accRes.Account.ServerName,
DeviceID: devRes.Device.ID,
}
sessions.addCompletedRegistration(sessionID, result)
@ -1054,8 +1002,8 @@ func RegisterAvailable(
}
}
if err := validateUsername(username, domain); err != nil {
return *err
if err := internal.ValidateUsername(username, domain); err != nil {
return *internal.UsernameResponse(err)
}
// Check if this username is reserved by an application service
@ -1117,11 +1065,11 @@ func handleSharedSecretRegistration(cfg *config.ClientAPI, userAPI userapi.Clien
// downcase capitals
ssrr.User = strings.ToLower(ssrr.User)
if resErr := validateUsername(ssrr.User, cfg.Matrix.ServerName); resErr != nil {
return *resErr
if err = internal.ValidateUsername(ssrr.User, cfg.Matrix.ServerName); err != nil {
return *internal.UsernameResponse(err)
}
if resErr := internal.ValidatePassword(ssrr.Password); resErr != nil {
return *resErr
if err = internal.ValidatePassword(ssrr.Password); err != nil {
return *internal.PasswordResponse(err)
}
deviceID := "shared_secret_registration"

View file

@ -15,12 +15,27 @@
package routing
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"reflect"
"regexp"
"strings"
"testing"
"time"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/keyserver"
"github.com/matrix-org/dendrite/roomserver"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
"github.com/matrix-org/dendrite/userapi"
"github.com/matrix-org/util"
)
var (
@ -264,3 +279,294 @@ func TestSessionCleanUp(t *testing.T) {
}
})
}
func Test_register(t *testing.T) {
testCases := []struct {
name string
kind string
password string
username string
loginType string
forceEmpty bool
registrationDisabled bool
guestsDisabled bool
enableRecaptcha bool
captchaBody string
wantResponse util.JSONResponse
}{
{
name: "disallow guests",
kind: "guest",
guestsDisabled: true,
wantResponse: util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden(`Guest registration is disabled on "test"`),
},
},
{
name: "allow guests",
kind: "guest",
},
{
name: "unknown login type",
loginType: "im.not.known",
wantResponse: util.JSONResponse{
Code: http.StatusNotImplemented,
JSON: jsonerror.Unknown("unknown/unimplemented auth type"),
},
},
{
name: "disabled registration",
registrationDisabled: true,
wantResponse: util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden(`Registration is disabled on "test"`),
},
},
{
name: "successful registration, numeric ID",
username: "",
password: "someRandomPassword",
forceEmpty: true,
},
{
name: "successful registration",
username: "success",
},
{
name: "failing registration - user already exists",
username: "success",
wantResponse: util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.UserInUse("Desired user ID is already taken."),
},
},
{
name: "successful registration uppercase username",
username: "LOWERCASED", // this is going to be lower-cased
},
{
name: "invalid username",
username: "#totalyNotValid",
wantResponse: *internal.UsernameResponse(internal.ErrUsernameInvalid),
},
{
name: "numeric username is forbidden",
username: "1337",
wantResponse: util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername("Numeric user IDs are reserved"),
},
},
{
name: "disabled recaptcha login",
loginType: authtypes.LoginTypeRecaptcha,
wantResponse: util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Unknown(ErrCaptchaDisabled.Error()),
},
},
{
name: "enabled recaptcha, no response defined",
enableRecaptcha: true,
loginType: authtypes.LoginTypeRecaptcha,
wantResponse: util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON(ErrMissingResponse.Error()),
},
},
{
name: "invalid captcha response",
enableRecaptcha: true,
loginType: authtypes.LoginTypeRecaptcha,
captchaBody: `notvalid`,
wantResponse: util.JSONResponse{
Code: http.StatusUnauthorized,
JSON: jsonerror.BadJSON(ErrInvalidCaptcha.Error()),
},
},
{
name: "valid captcha response",
enableRecaptcha: true,
loginType: authtypes.LoginTypeRecaptcha,
captchaBody: `success`,
},
{
name: "captcha invalid from remote",
enableRecaptcha: true,
loginType: authtypes.LoginTypeRecaptcha,
captchaBody: `i should fail for other reasons`,
wantResponse: util.JSONResponse{Code: http.StatusInternalServerError, JSON: jsonerror.InternalServerError()},
},
}
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
base, baseClose := testrig.CreateBaseDendrite(t, dbType)
defer baseClose()
rsAPI := roomserver.NewInternalAPI(base)
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI)
userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil)
keyAPI.SetUserAPI(userAPI)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
if tc.enableRecaptcha {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
t.Fatal(err)
}
response := r.Form.Get("response")
// Respond with valid JSON or no JSON at all to test happy/error cases
switch response {
case "success":
json.NewEncoder(w).Encode(recaptchaResponse{Success: true})
case "notvalid":
json.NewEncoder(w).Encode(recaptchaResponse{Success: false})
default:
}
}))
defer srv.Close()
base.Cfg.ClientAPI.RecaptchaSiteVerifyAPI = srv.URL
}
if err := base.Cfg.Derive(); err != nil {
t.Fatalf("failed to derive config: %s", err)
}
base.Cfg.ClientAPI.RecaptchaEnabled = tc.enableRecaptcha
base.Cfg.ClientAPI.RegistrationDisabled = tc.registrationDisabled
base.Cfg.ClientAPI.GuestsDisabled = tc.guestsDisabled
if tc.kind == "" {
tc.kind = "user"
}
if tc.password == "" && !tc.forceEmpty {
tc.password = "someRandomPassword"
}
if tc.username == "" && !tc.forceEmpty {
tc.username = "valid"
}
if tc.loginType == "" {
tc.loginType = "m.login.dummy"
}
reg := registerRequest{
Password: tc.password,
Username: tc.username,
}
body := &bytes.Buffer{}
err := json.NewEncoder(body).Encode(reg)
if err != nil {
t.Fatal(err)
}
req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("/?kind=%s", tc.kind), body)
resp := Register(req, userAPI, &base.Cfg.ClientAPI)
t.Logf("Resp: %+v", resp)
// The first request should return a userInteractiveResponse
switch r := resp.JSON.(type) {
case userInteractiveResponse:
// Check that the flows are the ones we configured
if !reflect.DeepEqual(r.Flows, base.Cfg.Derived.Registration.Flows) {
t.Fatalf("unexpected registration flows: %+v, want %+v", r.Flows, base.Cfg.Derived.Registration.Flows)
}
case *jsonerror.MatrixError:
if !reflect.DeepEqual(tc.wantResponse, resp) {
t.Fatalf("(%s), unexpected response: %+v, want: %+v", tc.name, resp, tc.wantResponse)
}
return
case registerResponse:
// this should only be possible on guest user registration, never for normal users
if tc.kind != "guest" {
t.Fatalf("got register response on first request: %+v", r)
}
// assert we've got a UserID, AccessToken and DeviceID
if r.UserID == "" {
t.Fatalf("missing userID in response")
}
if r.AccessToken == "" {
t.Fatalf("missing accessToken in response")
}
if r.DeviceID == "" {
t.Fatalf("missing deviceID in response")
}
return
default:
t.Logf("Got response: %T", resp.JSON)
}
// If we reached this, we should have received a UIA response
uia, ok := resp.JSON.(userInteractiveResponse)
if !ok {
t.Fatalf("did not receive a userInteractiveResponse: %T", resp.JSON)
}
t.Logf("%+v", uia)
// Register the user
reg.Auth = authDict{
Type: authtypes.LoginType(tc.loginType),
Session: uia.Session,
}
if tc.captchaBody != "" {
reg.Auth.Response = tc.captchaBody
}
dummy := "dummy"
reg.DeviceID = &dummy
reg.InitialDisplayName = &dummy
reg.Type = authtypes.LoginType(tc.loginType)
err = json.NewEncoder(body).Encode(reg)
if err != nil {
t.Fatal(err)
}
req = httptest.NewRequest(http.MethodPost, "/", body)
resp = Register(req, userAPI, &base.Cfg.ClientAPI)
switch resp.JSON.(type) {
case *jsonerror.MatrixError:
if !reflect.DeepEqual(tc.wantResponse, resp) {
t.Fatalf("unexpected response: %+v, want: %+v", resp, tc.wantResponse)
}
return
case util.JSONResponse:
if !reflect.DeepEqual(tc.wantResponse, resp) {
t.Fatalf("unexpected response: %+v, want: %+v", resp, tc.wantResponse)
}
return
}
rr, ok := resp.JSON.(registerResponse)
if !ok {
t.Fatalf("expected a registerresponse, got %T", resp.JSON)
}
// validate the response
if tc.forceEmpty {
// when not supplying a username, one will be generated. Given this _SHOULD_ be
// the second user, set the username accordingly
reg.Username = "2"
}
wantUserID := strings.ToLower(fmt.Sprintf("@%s:%s", reg.Username, "test"))
if wantUserID != rr.UserID {
t.Fatalf("unexpected userID: %s, want %s", rr.UserID, wantUserID)
}
if rr.DeviceID != *reg.DeviceID {
t.Fatalf("unexpected deviceID: %s, want %s", rr.DeviceID, *reg.DeviceID)
}
if rr.AccessToken == "" {
t.Fatalf("missing accessToken in response")
}
})
}
})
}

View file

@ -639,9 +639,9 @@ func Setup(
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
v3mux.Handle("/auth/{authType}/fallback/web",
httputil.MakeHTMLAPI("auth_fallback", base.EnableMetrics, func(w http.ResponseWriter, req *http.Request) *util.JSONResponse {
httputil.MakeHTMLAPI("auth_fallback", base.EnableMetrics, func(w http.ResponseWriter, req *http.Request) {
vars := mux.Vars(req)
return AuthFallback(w, req, vars["authType"], cfg)
AuthFallback(w, req, vars["authType"], cfg)
}),
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)