diff --git a/clientapi/auth/login_test.go b/clientapi/auth/login_test.go index b79c573a..044062c4 100644 --- a/clientapi/auth/login_test.go +++ b/clientapi/auth/login_test.go @@ -22,6 +22,7 @@ import ( "testing" "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/setup/config" uapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" @@ -47,7 +48,7 @@ func TestLoginFromJSONReader(t *testing.T) { "password": "herpassword", "device_id": "adevice" }`, - WantUsername: "alice", + WantUsername: "@alice:example.com", WantDeviceID: "adevice", }, { @@ -174,7 +175,7 @@ func (ua *fakeUserInternalAPI) QueryAccountByPassword(ctx context.Context, req * return nil } res.Exists = true - res.Account = &uapi.Account{} + res.Account = &uapi.Account{UserID: userutil.MakeUserID(req.Localpart, req.ServerName)} return nil } diff --git a/clientapi/auth/password.go b/clientapi/auth/password.go index 4de2b443..f2b0383a 100644 --- a/clientapi/auth/password.go +++ b/clientapi/auth/password.go @@ -101,6 +101,8 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, } } + // If we couldn't find the user by the lower cased localpart, try the provided + // localpart as is. if !res.Exists { err = t.GetAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{ Localpart: localpart, @@ -122,5 +124,8 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, } } } + // Set the user, so login.Username() can do the right thing + r.Identifier.User = res.Account.UserID + r.User = res.Account.UserID return &r.Login, nil } diff --git a/clientapi/routing/login_test.go b/clientapi/routing/login_test.go new file mode 100644 index 00000000..d429d7f8 --- /dev/null +++ b/clientapi/routing/login_test.go @@ -0,0 +1,152 @@ +package routing + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/keyserver" + "github.com/matrix-org/dendrite/roomserver" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/dendrite/userapi" + uapi "github.com/matrix-org/dendrite/userapi/api" +) + +func TestLogin(t *testing.T) { + aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin)) + bobUser := &test.User{ID: "@bob:test", AccountType: uapi.AccountTypeUser} + charlie := &test.User{ID: "@Charlie:test", AccountType: uapi.AccountTypeUser} + vhUser := &test.User{ID: "@vhuser:vh1"} + + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, baseClose := testrig.CreateBaseDendrite(t, dbType) + defer baseClose() + base.Cfg.ClientAPI.RateLimiting.Enabled = false + // add a vhost + base.Cfg.Global.VirtualHosts = append(base.Cfg.Global.VirtualHosts, &config.VirtualHost{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ServerName: "vh1"}, + }) + + rsAPI := roomserver.NewInternalAPI(base) + // Needed for /login + keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI) + userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil) + keyAPI.SetUserAPI(userAPI) + + // We mostly need the userAPI for this test, so nil for other APIs/caches etc. + Setup(base, &base.Cfg.ClientAPI, nil, nil, userAPI, nil, nil, nil, nil, nil, keyAPI, nil, &base.Cfg.MSCs, nil) + + // Create password + password := util.RandomString(8) + + // create the users + for _, u := range []*test.User{aliceAdmin, bobUser, vhUser, charlie} { + localpart, serverName, _ := gomatrixserverlib.SplitID('@', u.ID) + userRes := &uapi.PerformAccountCreationResponse{} + + if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{ + AccountType: u.AccountType, + Localpart: localpart, + ServerName: serverName, + Password: password, + }, userRes); err != nil { + t.Errorf("failed to create account: %s", err) + } + if !userRes.AccountCreated { + t.Fatalf("account not created") + } + } + + testCases := []struct { + name string + userID string + wantOK bool + }{ + { + name: "aliceAdmin can login", + userID: aliceAdmin.ID, + wantOK: true, + }, + { + name: "bobUser can login", + userID: bobUser.ID, + wantOK: true, + }, + { + name: "vhuser can login", + userID: vhUser.ID, + wantOK: true, + }, + { + name: "bob with uppercase can login", + userID: "@Bob:test", + wantOK: true, + }, + { + name: "Charlie can login (existing uppercase)", + userID: charlie.ID, + wantOK: true, + }, + { + name: "Charlie can not login with lowercase userID", + userID: strings.ToLower(charlie.ID), + wantOK: false, + }, + } + + ctx := context.Background() + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := test.NewRequest(t, http.MethodPost, "/_matrix/client/v3/login", test.WithJSONBody(t, map[string]interface{}{ + "type": authtypes.LoginTypePassword, + "identifier": map[string]interface{}{ + "type": "m.id.user", + "user": tc.userID, + }, + "password": password, + })) + rec := httptest.NewRecorder() + base.PublicClientAPIMux.ServeHTTP(rec, req) + if tc.wantOK && rec.Code != http.StatusOK { + t.Fatalf("failed to login: %s", rec.Body.String()) + } + + t.Logf("Response: %s", rec.Body.String()) + // get the response + resp := loginResponse{} + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatal(err) + } + // everything OK + if !tc.wantOK && resp.AccessToken == "" { + return + } + if tc.wantOK && resp.AccessToken == "" { + t.Fatalf("expected accessToken after successful login but got none: %+v", resp) + } + + devicesResp := &uapi.QueryDevicesResponse{} + if err := userAPI.QueryDevices(ctx, &uapi.QueryDevicesRequest{UserID: resp.UserID}, devicesResp); err != nil { + t.Fatal(err) + } + for _, dev := range devicesResp.Devices { + // We expect the userID on the device to be the same as resp.UserID + if dev.UserID != resp.UserID { + t.Fatalf("unexpected userID on device: %s", dev.UserID) + } + } + }) + } + }) +}