diff --git a/appservice/appservice.go b/appservice/appservice.go index fd903b98..5f16c10b 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -131,10 +131,11 @@ func generateAppServiceAccount( } var devRes userapi.PerformDeviceCreationResponse err = userAPI.PerformDeviceCreation(context.Background(), &userapi.PerformDeviceCreationRequest{ - Localpart: as.SenderLocalpart, - AccessToken: as.ASToken, - DeviceID: &as.SenderLocalpart, - DeviceDisplayName: &as.SenderLocalpart, + Localpart: as.SenderLocalpart, + AccessToken: as.ASToken, + DeviceID: &as.SenderLocalpart, + DeviceDisplayName: &as.SenderLocalpart, + NoDeviceListUpdate: true, }, &devRes) return err } diff --git a/clientapi/auth/auth.go b/clientapi/auth/auth.go index b4c39ae3..c850bf91 100644 --- a/clientapi/auth/auth.go +++ b/clientapi/auth/auth.go @@ -70,11 +70,11 @@ func VerifyUserFromRequest( jsonErr := jsonerror.InternalServerError() return nil, &jsonErr } - if res.Err != nil { - if forbidden, ok := res.Err.(*api.ErrorForbidden); ok { + if res.Err != "" { + if strings.HasPrefix(strings.ToLower(res.Err), "forbidden:") { // TODO: use actual error and no string comparison return nil, &util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden(forbidden.Message), + JSON: jsonerror.Forbidden(res.Err), } } } diff --git a/clientapi/routing/key_backup.go b/clientapi/routing/key_backup.go index ce62a047..9d2ff87f 100644 --- a/clientapi/routing/key_backup.go +++ b/clientapi/routing/key_backup.go @@ -62,12 +62,14 @@ func CreateKeyBackupVersion(req *http.Request, userAPI userapi.UserInternalAPI, return *resErr } var performKeyBackupResp userapi.PerformKeyBackupResponse - userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{ + if err := userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{ UserID: device.UserID, Version: "", AuthData: kb.AuthData, Algorithm: kb.Algorithm, - }, &performKeyBackupResp) + }, &performKeyBackupResp); err != nil { + return jsonerror.InternalServerError() + } if performKeyBackupResp.Error != "" { if performKeyBackupResp.BadInput { return util.JSONResponse{ @@ -123,12 +125,14 @@ func ModifyKeyBackupVersionAuthData(req *http.Request, userAPI userapi.UserInter return *resErr } var performKeyBackupResp userapi.PerformKeyBackupResponse - userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{ + if err := userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{ UserID: device.UserID, Version: version, AuthData: kb.AuthData, Algorithm: kb.Algorithm, - }, &performKeyBackupResp) + }, &performKeyBackupResp); err != nil { + return jsonerror.InternalServerError() + } if performKeyBackupResp.Error != "" { if performKeyBackupResp.BadInput { return util.JSONResponse{ @@ -157,11 +161,13 @@ func ModifyKeyBackupVersionAuthData(req *http.Request, userAPI userapi.UserInter // Implements DELETE /_matrix/client/r0/room_keys/version/{version} func DeleteKeyBackupVersion(req *http.Request, userAPI userapi.UserInternalAPI, device *userapi.Device, version string) util.JSONResponse { var performKeyBackupResp userapi.PerformKeyBackupResponse - userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{ + if err := userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{ UserID: device.UserID, Version: version, DeleteBackup: true, - }, &performKeyBackupResp) + }, &performKeyBackupResp); err != nil { + return jsonerror.InternalServerError() + } if performKeyBackupResp.Error != "" { if performKeyBackupResp.BadInput { return util.JSONResponse{ @@ -191,11 +197,13 @@ func UploadBackupKeys( req *http.Request, userAPI userapi.UserInternalAPI, device *userapi.Device, version string, keys *keyBackupSessionRequest, ) util.JSONResponse { var performKeyBackupResp userapi.PerformKeyBackupResponse - userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{ + if err := userAPI.PerformKeyBackup(req.Context(), &userapi.PerformKeyBackupRequest{ UserID: device.UserID, Version: version, Keys: *keys, - }, &performKeyBackupResp) + }, &performKeyBackupResp); err != nil && performKeyBackupResp.Error == "" { + return jsonerror.InternalServerError() + } if performKeyBackupResp.Error != "" { if performKeyBackupResp.BadInput { return util.JSONResponse{ diff --git a/cmd/dendrite-monolith-server/main.go b/cmd/dendrite-monolith-server/main.go index 0e55e7ba..1f85cae9 100644 --- a/cmd/dendrite-monolith-server/main.go +++ b/cmd/dendrite-monolith-server/main.go @@ -67,6 +67,7 @@ func main() { cfg.MediaAPI.InternalAPI.Connect = httpAPIAddr cfg.RoomServer.InternalAPI.Connect = httpAPIAddr cfg.SyncAPI.InternalAPI.Connect = httpAPIAddr + cfg.UserAPI.InternalAPI.Connect = httpAPIAddr options = append(options, basepkg.UseHTTPAPIs) } @@ -102,20 +103,41 @@ func main() { // This is different to rsAPI which can be the http client which doesn't need this dependency rsImpl.SetFederationAPI(fsAPI) - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI) - userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI) - keyAPI.SetUserAPI(userAPI) + keyImpl := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI) + keyAPI := keyImpl + if base.UseHTTPAPIs { + keyserver.AddInternalRoutes(base.InternalAPIMux, keyAPI) + keyAPI = base.KeyServerHTTPClient() + } + + userImpl := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI) + userAPI := userImpl + if base.UseHTTPAPIs { + userapi.AddInternalRoutes(base.InternalAPIMux, userAPI) + userAPI = base.UserAPIClient() + } if traceInternal { userAPI = &uapi.UserInternalAPITrace{ Impl: userAPI, } } - // needs to be after the SetUserAPI call above + + // TODO: This should use userAPI, not userImpl, but the appservice setup races with + // the listeners and panics at startup if it tries to create appservice accounts + // before the listeners are up. + asAPI := appservice.NewInternalAPI(base, userImpl, rsAPI) if base.UseHTTPAPIs { - keyserver.AddInternalRoutes(base.InternalAPIMux, keyAPI) - keyAPI = base.KeyServerHTTPClient() + appservice.AddInternalRoutes(base.InternalAPIMux, asAPI) + asAPI = base.AppserviceHTTPClient() } + // The underlying roomserver implementation needs to be able to call the fedsender. + // This is different to rsAPI which can be the http client which doesn't need this + // dependency. Other components also need updating after their dependencies are up. + rsImpl.SetFederationAPI(fsAPI) + rsImpl.SetAppserviceAPI(asAPI) + keyImpl.SetUserAPI(userAPI) + eduInputAPI := eduserver.NewInternalAPI( base, cache.New(), userAPI, ) @@ -124,13 +146,6 @@ func main() { eduInputAPI = base.EDUServerClient() } - asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI) - if base.UseHTTPAPIs { - appservice.AddInternalRoutes(base.InternalAPIMux, asAPI) - asAPI = base.AppserviceHTTPClient() - } - rsAPI.SetAppserviceAPI(asAPI) - monolith := setup.Monolith{ Config: base.Cfg, AccountDB: accountDB, diff --git a/internal/httputil/http.go b/internal/httputil/http.go index a469c8ac..4527e2b9 100644 --- a/internal/httputil/http.go +++ b/internal/httputil/http.go @@ -23,6 +23,7 @@ import ( "net/url" "strings" + "github.com/matrix-org/dendrite/userapi/api" opentracing "github.com/opentracing/opentracing-go" "github.com/opentracing/opentracing-go/ext" ) @@ -72,6 +73,9 @@ func PostJSON( var errorBody struct { Message string `json:"message"` } + if _, ok := response.(*api.PerformKeyBackupResponse); ok { // TODO: remove this, once cross-boundary errors are a thing + return nil + } if msgerr := json.NewDecoder(res.Body).Decode(&errorBody); msgerr == nil { return fmt.Errorf("internal API: %d from %s: %s", res.StatusCode, apiURL, errorBody.Message) } diff --git a/setup/mscs/msc2836/msc2836_test.go b/setup/mscs/msc2836/msc2836_test.go index 18ab08be..9044823a 100644 --- a/setup/mscs/msc2836/msc2836_test.go +++ b/setup/mscs/msc2836/msc2836_test.go @@ -7,7 +7,6 @@ import ( "crypto/sha256" "encoding/base64" "encoding/json" - "fmt" "io/ioutil" "net/http" "sort" @@ -504,7 +503,7 @@ type testUserAPI struct { func (u *testUserAPI) QueryAccessToken(ctx context.Context, req *userapi.QueryAccessTokenRequest, res *userapi.QueryAccessTokenResponse) error { dev, ok := u.accessTokens[req.AccessToken] if !ok { - res.Err = fmt.Errorf("unknown token") + res.Err = "unknown token" return nil } res.Device = &dev diff --git a/setup/mscs/msc2946/msc2946_test.go b/setup/mscs/msc2946/msc2946_test.go index 441892f3..e8066c34 100644 --- a/setup/mscs/msc2946/msc2946_test.go +++ b/setup/mscs/msc2946/msc2946_test.go @@ -19,7 +19,6 @@ import ( "context" "crypto/ed25519" "encoding/json" - "fmt" "io/ioutil" "net/http" "net/url" @@ -347,7 +346,7 @@ type testUserAPI struct { func (u *testUserAPI) QueryAccessToken(ctx context.Context, req *userapi.QueryAccessTokenRequest, res *userapi.QueryAccessTokenResponse) error { dev, ok := u.accessTokens[req.AccessToken] if !ok { - res.Err = fmt.Errorf("unknown token") + res.Err = "unknown token" return nil } res.Device = &dev diff --git a/userapi/api/api.go b/userapi/api/api.go index 75d06dd6..04609659 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -33,7 +33,7 @@ type UserInternalAPI interface { PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error - PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse) + PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse) error QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error @@ -181,7 +181,7 @@ type QueryAccessTokenRequest struct { // QueryAccessTokenResponse is the response for QueryAccessToken type QueryAccessTokenResponse struct { Device *Device - Err error // e.g ErrorForbidden + Err string // e.g ErrorForbidden } // QueryAccountDataRequest is the request for QueryAccountData @@ -290,6 +290,10 @@ type PerformDeviceCreationRequest struct { IPAddr string // Useragent for this device UserAgent string + // NoDeviceListUpdate determines whether we should avoid sending a device list + // update for this account. Generally the only reason to do this is if the account + // is an appservice account. + NoDeviceListUpdate bool } // PerformDeviceCreationResponse is the response for PerformDeviceCreation diff --git a/userapi/api/api_trace.go b/userapi/api/api_trace.go index 84dcb309..aa069f40 100644 --- a/userapi/api/api_trace.go +++ b/userapi/api/api_trace.go @@ -74,11 +74,14 @@ func (t *UserInternalAPITrace) PerformOpenIDTokenCreation(ctx context.Context, r util.GetLogger(ctx).Infof("PerformOpenIDTokenCreation req=%+v res=%+v", js(req), js(res)) return err } -func (t *UserInternalAPITrace) PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse) { - t.Impl.PerformKeyBackup(ctx, req, res) +func (t *UserInternalAPITrace) PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse) error { + err := t.Impl.PerformKeyBackup(ctx, req, res) + util.GetLogger(ctx).Infof("PerformKeyBackup req=%+v res=%+v", js(req), js(res)) + return err } func (t *UserInternalAPITrace) QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) { t.Impl.QueryKeyBackup(ctx, req, res) + util.GetLogger(ctx).Infof("QueryKeyBackup req=%+v res=%+v", js(req), js(res)) } func (t *UserInternalAPITrace) QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error { err := t.Impl.QueryProfile(ctx, req, res) diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 4ff8f51d..5d91383d 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -119,6 +119,9 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe } res.DeviceCreated = true res.Device = dev + if req.NoDeviceListUpdate { + return nil + } // create empty device keys and upload them to trigger device list changes return a.deviceListUpdate(dev.UserID, []string{dev.ID}) } @@ -358,8 +361,11 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAccessTokenRequest, res *api.QueryAccessTokenResponse) error { if req.AppServiceUserID != "" { appServiceDevice, err := a.queryAppServiceToken(ctx, req.AccessToken, req.AppServiceUserID) + if err != nil { + res.Err = err.Error() + } res.Device = appServiceDevice - res.Err = err + return nil } device, err := a.DeviceDB.GetDeviceByAccessToken(ctx, req.AccessToken) @@ -455,13 +461,16 @@ func (a *UserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOp return nil } -func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) { +func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) error { // Delete metadata if req.DeleteBackup { if req.Version == "" { res.BadInput = true res.Error = "must specify a version to delete" - return + if res.Error != "" { + return fmt.Errorf(res.Error) + } + return nil } exists, err := a.AccountDB.DeleteKeyBackup(ctx, req.UserID, req.Version) if err != nil { @@ -469,7 +478,10 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform } res.Exists = exists res.Version = req.Version - return + if res.Error != "" { + return fmt.Errorf(res.Error) + } + return nil } // Create metadata if req.Version == "" { @@ -479,7 +491,10 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform } res.Exists = err == nil res.Version = version - return + if res.Error != "" { + return fmt.Errorf(res.Error) + } + return nil } // Update metadata if len(req.Keys.Rooms) == 0 { @@ -489,10 +504,17 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform } res.Exists = err == nil res.Version = req.Version - return + if res.Error != "" { + return fmt.Errorf(res.Error) + } + return nil } // Upload Keys for a specific version metadata a.uploadBackupKeys(ctx, req, res) + if res.Error != "" { + return fmt.Errorf(res.Error) + } + return nil } func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) { diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go index a89d1a26..1599d463 100644 --- a/userapi/inthttp/client.go +++ b/userapi/inthttp/client.go @@ -228,7 +228,7 @@ func (h *httpUserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.Que return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) } -func (h *httpUserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) { +func (h *httpUserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) error { span, ctx := opentracing.StartSpanFromContext(ctx, "PerformKeyBackup") defer span.Finish() @@ -237,6 +237,7 @@ func (h *httpUserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Per if err != nil { res.Error = err.Error() } + return nil } func (h *httpUserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) { span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKeyBackup") diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go index 1c1cfdcd..ac05bcd0 100644 --- a/userapi/inthttp/server.go +++ b/userapi/inthttp/server.go @@ -16,6 +16,7 @@ package inthttp import ( "encoding/json" + "fmt" "net/http" "github.com/gorilla/mux" @@ -234,4 +235,32 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(QueryKeyBackupPath, + httputil.MakeInternalAPI("queryKeyBackup", func(req *http.Request) util.JSONResponse { + request := api.QueryKeyBackupRequest{} + response := api.QueryKeyBackupResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + s.QueryKeyBackup(req.Context(), &request, &response) + if response.Error != "" { + return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %s", response.Error)) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(PerformKeyBackupPath, + httputil.MakeInternalAPI("performKeyBackup", func(req *http.Request) util.JSONResponse { + request := api.PerformKeyBackupRequest{} + response := api.PerformKeyBackupResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + err := s.PerformKeyBackup(req.Context(), &request, &response) + if err != nil { + return util.JSONResponse{Code: http.StatusBadRequest, JSON: &response} + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) }