Merge branch 'master' into matthew/peeking

This commit is contained in:
Neil Alexander 2020-09-04 16:08:12 +01:00
commit 64fe2741c0
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
33 changed files with 289 additions and 101 deletions

View file

@ -130,7 +130,7 @@ func (m *DendriteMonolith) Start() {
asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI) asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI)
stateAPI := currentstateserver.NewInternalAPI(&base.Cfg.CurrentStateServer, base.KafkaConsumer) stateAPI := currentstateserver.NewInternalAPI(&base.Cfg.CurrentStateServer, base.KafkaConsumer)
fsAPI := federationsender.NewInternalAPI( fsAPI := federationsender.NewInternalAPI(
base, federation, rsAPI, stateAPI, keyRing, base, federation, rsAPI, keyRing,
) )
ygg.SetSessionFunc(func(address string) { ygg.SetSessionFunc(func(address string) {

View file

@ -5,6 +5,7 @@ type LoginType string
// The relevant login types implemented in Dendrite // The relevant login types implemented in Dendrite
const ( const (
LoginTypePassword = "m.login.password"
LoginTypeDummy = "m.login.dummy" LoginTypeDummy = "m.login.dummy"
LoginTypeSharedSecret = "org.matrix.login.shared_secret" LoginTypeSharedSecret = "org.matrix.login.shared_secret"
LoginTypeRecaptcha = "m.login.recaptcha" LoginTypeRecaptcha = "m.login.recaptcha"

View file

@ -19,7 +19,6 @@ import (
"net/http" "net/http"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api"
"github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/internal/config"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
@ -94,10 +93,10 @@ func GetMemberships(
func GetJoinedRooms( func GetJoinedRooms(
req *http.Request, req *http.Request,
device *userapi.Device, device *userapi.Device,
stateAPI currentstateAPI.CurrentStateInternalAPI, rsAPI api.RoomserverInternalAPI,
) util.JSONResponse { ) util.JSONResponse {
var res currentstateAPI.QueryRoomsForUserResponse var res api.QueryRoomsForUserResponse
err := stateAPI.QueryRoomsForUser(req.Context(), &currentstateAPI.QueryRoomsForUserRequest{ err := rsAPI.QueryRoomsForUser(req.Context(), &api.QueryRoomsForUserRequest{
UserID: device.UserID, UserID: device.UserID,
WantMembership: "join", WantMembership: "join",
}, &res) }, &res)

View file

@ -0,0 +1,127 @@
package routing
import (
"net/http"
"github.com/matrix-org/dendrite/clientapi/auth"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/config"
"github.com/matrix-org/dendrite/userapi/api"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
type newPasswordRequest struct {
NewPassword string `json:"new_password"`
LogoutDevices bool `json:"logout_devices"`
Auth newPasswordAuth `json:"auth"`
}
type newPasswordAuth struct {
Type string `json:"type"`
Session string `json:"session"`
auth.PasswordRequest
}
func Password(
req *http.Request,
userAPI userapi.UserInternalAPI,
accountDB accounts.Database,
device *api.Device,
cfg *config.ClientAPI,
) util.JSONResponse {
// Check that the existing password is right.
var r newPasswordRequest
r.LogoutDevices = true
// Unmarshal the request.
resErr := httputil.UnmarshalJSONRequest(req, &r)
if resErr != nil {
return *resErr
}
// Retrieve or generate the sessionID
sessionID := r.Auth.Session
if sessionID == "" {
// Generate a new, random session ID
sessionID = util.RandomString(sessionIDLength)
}
// Require password auth to change the password.
if r.Auth.Type != authtypes.LoginTypePassword {
return util.JSONResponse{
Code: http.StatusUnauthorized,
JSON: newUserInteractiveResponse(
sessionID,
[]authtypes.Flow{
{
Stages: []authtypes.LoginType{authtypes.LoginTypePassword},
},
},
nil,
),
}
}
// Check if the existing password is correct.
typePassword := auth.LoginTypePassword{
GetAccountByPassword: accountDB.GetAccountByPassword,
Config: cfg,
}
if _, authErr := typePassword.Login(req.Context(), &r.Auth.PasswordRequest); authErr != nil {
return *authErr
}
AddCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
// Check the new password strength.
if resErr = validatePassword(r.NewPassword); resErr != nil {
return *resErr
}
// Get the local part.
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
}
// Ask the user API to perform the password change.
passwordReq := &userapi.PerformPasswordUpdateRequest{
Localpart: localpart,
Password: r.NewPassword,
}
passwordRes := &userapi.PerformPasswordUpdateResponse{}
if err := userAPI.PerformPasswordUpdate(req.Context(), passwordReq, passwordRes); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("PerformPasswordUpdate failed")
return jsonerror.InternalServerError()
}
if !passwordRes.PasswordUpdated {
util.GetLogger(req.Context()).Error("Expected password to have been updated but wasn't")
return jsonerror.InternalServerError()
}
// If the request asks us to log out all other devices then
// ask the user API to do that.
if r.LogoutDevices {
logoutReq := &userapi.PerformDeviceDeletionRequest{
UserID: device.UserID,
DeviceIDs: nil,
ExceptDeviceID: device.ID,
}
logoutRes := &userapi.PerformDeviceDeletionResponse{}
if err := userAPI.PerformDeviceDeletion(req.Context(), logoutReq, logoutRes); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("PerformDeviceDeletion failed")
return jsonerror.InternalServerError()
}
}
// Return a success code.
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
}
}

View file

@ -140,8 +140,8 @@ func SetAvatarURL(
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
var res currentstateAPI.QueryRoomsForUserResponse var res api.QueryRoomsForUserResponse
err = stateAPI.QueryRoomsForUser(req.Context(), &currentstateAPI.QueryRoomsForUserRequest{ err = rsAPI.QueryRoomsForUser(req.Context(), &api.QueryRoomsForUserRequest{
UserID: device.UserID, UserID: device.UserID,
WantMembership: "join", WantMembership: "join",
}, &res) }, &res)
@ -258,8 +258,8 @@ func SetDisplayName(
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
var res currentstateAPI.QueryRoomsForUserResponse var res api.QueryRoomsForUserResponse
err = stateAPI.QueryRoomsForUser(req.Context(), &currentstateAPI.QueryRoomsForUserRequest{ err = rsAPI.QueryRoomsForUser(req.Context(), &api.QueryRoomsForUserRequest{
UserID: device.UserID, UserID: device.UserID,
WantMembership: "join", WantMembership: "join",
}, &res) }, &res)

View file

@ -118,7 +118,7 @@ func Setup(
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/joined_rooms", r0mux.Handle("/joined_rooms",
httputil.MakeAuthAPI("joined_rooms", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("joined_rooms", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return GetJoinedRooms(req, device, stateAPI) return GetJoinedRooms(req, device, rsAPI)
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/join", r0mux.Handle("/rooms/{roomID}/join",
@ -428,6 +428,15 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/account/password",
httputil.MakeAuthAPI("password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.rateLimit(req); r != nil {
return *r
}
return Password(req, userAPI, accountDB, device, cfg)
}),
).Methods(http.MethodPost, http.MethodOptions)
// Stub endpoints required by Riot // Stub endpoints required by Riot
r0mux.Handle("/login", r0mux.Handle("/login",

View file

@ -162,7 +162,7 @@ func main() {
) )
asAPI := appservice.NewInternalAPI(&base.Base, userAPI, rsAPI) asAPI := appservice.NewInternalAPI(&base.Base, userAPI, rsAPI)
fsAPI := federationsender.NewInternalAPI( fsAPI := federationsender.NewInternalAPI(
&base.Base, federation, rsAPI, stateAPI, keyRing, &base.Base, federation, rsAPI, keyRing,
) )
rsAPI.SetFederationSenderAPI(fsAPI) rsAPI.SetFederationSenderAPI(fsAPI)
provider := newPublicRoomsProvider(base.LibP2PPubsub, rsAPI, stateAPI) provider := newPublicRoomsProvider(base.LibP2PPubsub, rsAPI, stateAPI)

View file

@ -115,7 +115,7 @@ func main() {
asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI) asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI)
stateAPI := currentstateserver.NewInternalAPI(&base.Cfg.CurrentStateServer, base.KafkaConsumer) stateAPI := currentstateserver.NewInternalAPI(&base.Cfg.CurrentStateServer, base.KafkaConsumer)
fsAPI := federationsender.NewInternalAPI( fsAPI := federationsender.NewInternalAPI(
base, federation, rsAPI, stateAPI, keyRing, base, federation, rsAPI, keyRing,
) )
ygg.SetSessionFunc(func(address string) { ygg.SetSessionFunc(func(address string) {

View file

@ -31,7 +31,7 @@ func main() {
rsAPI := base.RoomserverHTTPClient() rsAPI := base.RoomserverHTTPClient()
fsAPI := federationsender.NewInternalAPI( fsAPI := federationsender.NewInternalAPI(
base, federation, rsAPI, base.CurrentStateAPIClient(), keyRing, base, federation, rsAPI, keyRing,
) )
federationsender.AddInternalRoutes(base.InternalAPIMux, fsAPI) federationsender.AddInternalRoutes(base.InternalAPIMux, fsAPI)

View file

@ -98,7 +98,7 @@ func main() {
stateAPI := currentstateserver.NewInternalAPI(&base.Cfg.CurrentStateServer, base.KafkaConsumer) stateAPI := currentstateserver.NewInternalAPI(&base.Cfg.CurrentStateServer, base.KafkaConsumer)
fsAPI := federationsender.NewInternalAPI( fsAPI := federationsender.NewInternalAPI(
base, federation, rsAPI, stateAPI, keyRing, base, federation, rsAPI, keyRing,
) )
if base.UseHTTPAPIs { if base.UseHTTPAPIs {
federationsender.AddInternalRoutes(base.InternalAPIMux, fsAPI) federationsender.AddInternalRoutes(base.InternalAPIMux, fsAPI)

View file

@ -210,7 +210,7 @@ func main() {
asQuery := appservice.NewInternalAPI( asQuery := appservice.NewInternalAPI(
base, userAPI, rsAPI, base, userAPI, rsAPI,
) )
fedSenderAPI := federationsender.NewInternalAPI(base, federation, rsAPI, stateAPI, &keyRing) fedSenderAPI := federationsender.NewInternalAPI(base, federation, rsAPI, &keyRing)
rsAPI.SetFederationSenderAPI(fedSenderAPI) rsAPI.SetFederationSenderAPI(fedSenderAPI)
p2pPublicRoomProvider := NewLibP2PPublicRoomsProvider(node, fedSenderAPI, federation) p2pPublicRoomProvider := NewLibP2PPublicRoomsProvider(node, fedSenderAPI, federation)

View file

@ -21,22 +21,10 @@ import (
) )
type CurrentStateInternalAPI interface { type CurrentStateInternalAPI interface {
// QueryRoomsForUser retrieves a list of room IDs matching the given query.
QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error
// QueryBulkStateContent does a bulk query for state event content in the given rooms. // QueryBulkStateContent does a bulk query for state event content in the given rooms.
QueryBulkStateContent(ctx context.Context, req *QueryBulkStateContentRequest, res *QueryBulkStateContentResponse) error QueryBulkStateContent(ctx context.Context, req *QueryBulkStateContentRequest, res *QueryBulkStateContentResponse) error
} }
type QueryRoomsForUserRequest struct {
UserID string
// The desired membership of the user. If this is the empty string then no rooms are returned.
WantMembership string
}
type QueryRoomsForUserResponse struct {
RoomIDs []string
}
type QueryBulkStateContentRequest struct { type QueryBulkStateContentRequest struct {
// Returns state events in these rooms // Returns state events in these rooms
RoomIDs []string RoomIDs []string

View file

@ -26,15 +26,6 @@ type CurrentStateInternalAPI struct {
DB storage.Database DB storage.Database
} }
func (a *CurrentStateInternalAPI) QueryRoomsForUser(ctx context.Context, req *api.QueryRoomsForUserRequest, res *api.QueryRoomsForUserResponse) error {
roomIDs, err := a.DB.GetRoomsByMembership(ctx, req.UserID, req.WantMembership)
if err != nil {
return err
}
res.RoomIDs = roomIDs
return nil
}
func (a *CurrentStateInternalAPI) QueryBulkStateContent(ctx context.Context, req *api.QueryBulkStateContentRequest, res *api.QueryBulkStateContentResponse) error { func (a *CurrentStateInternalAPI) QueryBulkStateContent(ctx context.Context, req *api.QueryBulkStateContentRequest, res *api.QueryBulkStateContentResponse) error {
events, err := a.DB.GetBulkStateContent(ctx, req.RoomIDs, req.StateTuples, req.AllowWildcards) events, err := a.DB.GetBulkStateContent(ctx, req.RoomIDs, req.StateTuples, req.AllowWildcards)
if err != nil { if err != nil {

View file

@ -50,18 +50,6 @@ type httpCurrentStateInternalAPI struct {
httpClient *http.Client httpClient *http.Client
} }
func (h *httpCurrentStateInternalAPI) QueryRoomsForUser(
ctx context.Context,
request *api.QueryRoomsForUserRequest,
response *api.QueryRoomsForUserResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryRoomsForUser")
defer span.Finish()
apiURL := h.apiURL + QueryRoomsForUserPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
}
func (h *httpCurrentStateInternalAPI) QueryBulkStateContent( func (h *httpCurrentStateInternalAPI) QueryBulkStateContent(
ctx context.Context, ctx context.Context,
request *api.QueryBulkStateContentRequest, request *api.QueryBulkStateContentRequest,

View file

@ -25,19 +25,6 @@ import (
) )
func AddRoutes(internalAPIMux *mux.Router, intAPI api.CurrentStateInternalAPI) { func AddRoutes(internalAPIMux *mux.Router, intAPI api.CurrentStateInternalAPI) {
internalAPIMux.Handle(QueryRoomsForUserPath,
httputil.MakeInternalAPI("queryRoomsForUser", func(req *http.Request) util.JSONResponse {
request := api.QueryRoomsForUserRequest{}
response := api.QueryRoomsForUserResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := intAPI.QueryRoomsForUser(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(QueryBulkStateContentPath, internalAPIMux.Handle(QueryBulkStateContentPath,
httputil.MakeInternalAPI("queryBulkStateContent", func(req *http.Request) util.JSONResponse { httputil.MakeInternalAPI("queryBulkStateContent", func(req *http.Request) util.JSONResponse {
request := api.QueryBulkStateContentRequest{} request := api.QueryBulkStateContentRequest{}

View file

@ -20,12 +20,12 @@ import (
"fmt" "fmt"
"github.com/Shopify/sarama" "github.com/Shopify/sarama"
stateapi "github.com/matrix-org/dendrite/currentstateserver/api"
"github.com/matrix-org/dendrite/federationsender/queue" "github.com/matrix-org/dendrite/federationsender/queue"
"github.com/matrix-org/dendrite/federationsender/storage" "github.com/matrix-org/dendrite/federationsender/storage"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/internal/config"
"github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@ -36,7 +36,7 @@ type KeyChangeConsumer struct {
db storage.Database db storage.Database
queues *queue.OutgoingQueues queues *queue.OutgoingQueues
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
stateAPI stateapi.CurrentStateInternalAPI rsAPI roomserverAPI.RoomserverInternalAPI
} }
// NewKeyChangeConsumer creates a new KeyChangeConsumer. Call Start() to begin consuming from key servers. // NewKeyChangeConsumer creates a new KeyChangeConsumer. Call Start() to begin consuming from key servers.
@ -45,7 +45,7 @@ func NewKeyChangeConsumer(
kafkaConsumer sarama.Consumer, kafkaConsumer sarama.Consumer,
queues *queue.OutgoingQueues, queues *queue.OutgoingQueues,
store storage.Database, store storage.Database,
stateAPI stateapi.CurrentStateInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI,
) *KeyChangeConsumer { ) *KeyChangeConsumer {
c := &KeyChangeConsumer{ c := &KeyChangeConsumer{
consumer: &internal.ContinualConsumer{ consumer: &internal.ContinualConsumer{
@ -57,7 +57,7 @@ func NewKeyChangeConsumer(
queues: queues, queues: queues,
db: store, db: store,
serverName: cfg.Matrix.ServerName, serverName: cfg.Matrix.ServerName,
stateAPI: stateAPI, rsAPI: rsAPI,
} }
c.consumer.ProcessMessage = c.onMessage c.consumer.ProcessMessage = c.onMessage
@ -92,8 +92,8 @@ func (t *KeyChangeConsumer) onMessage(msg *sarama.ConsumerMessage) error {
return nil return nil
} }
var queryRes stateapi.QueryRoomsForUserResponse var queryRes roomserverAPI.QueryRoomsForUserResponse
err = t.stateAPI.QueryRoomsForUser(context.Background(), &stateapi.QueryRoomsForUserRequest{ err = t.rsAPI.QueryRoomsForUser(context.Background(), &roomserverAPI.QueryRoomsForUserRequest{
UserID: m.UserID, UserID: m.UserID,
WantMembership: "join", WantMembership: "join",
}, &queryRes) }, &queryRes)

View file

@ -16,7 +16,6 @@ package federationsender
import ( import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
stateapi "github.com/matrix-org/dendrite/currentstateserver/api"
"github.com/matrix-org/dendrite/federationsender/api" "github.com/matrix-org/dendrite/federationsender/api"
"github.com/matrix-org/dendrite/federationsender/consumers" "github.com/matrix-org/dendrite/federationsender/consumers"
"github.com/matrix-org/dendrite/federationsender/internal" "github.com/matrix-org/dendrite/federationsender/internal"
@ -42,7 +41,6 @@ func NewInternalAPI(
base *setup.BaseDendrite, base *setup.BaseDendrite,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
rsAPI roomserverAPI.RoomserverInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI,
stateAPI stateapi.CurrentStateInternalAPI,
keyRing *gomatrixserverlib.KeyRing, keyRing *gomatrixserverlib.KeyRing,
) api.FederationSenderInternalAPI { ) api.FederationSenderInternalAPI {
cfg := &base.Cfg.FederationSender cfg := &base.Cfg.FederationSender
@ -82,7 +80,7 @@ func NewInternalAPI(
logrus.WithError(err).Panic("failed to start typing server consumer") logrus.WithError(err).Panic("failed to start typing server consumer")
} }
keyConsumer := consumers.NewKeyChangeConsumer( keyConsumer := consumers.NewKeyChangeConsumer(
&base.Cfg.KeyServer, base.KafkaConsumer, queues, federationSenderDB, stateAPI, &base.Cfg.KeyServer, base.KafkaConsumer, queues, federationSenderDB, rsAPI,
) )
if err := keyConsumer.Start(); err != nil { if err := keyConsumer.Start(); err != nil {
logrus.WithError(err).Panic("failed to start key server consumer") logrus.WithError(err).Panic("failed to start key server consumer")

View file

@ -112,11 +112,6 @@ type mockStateAPI struct {
rsAPI *mockRoomserverAPI rsAPI *mockRoomserverAPI
} }
// QueryRoomsForUser retrieves a list of room IDs matching the given query.
func (s *mockStateAPI) QueryRoomsForUser(ctx context.Context, req *stateapi.QueryRoomsForUserRequest, res *stateapi.QueryRoomsForUserResponse) error {
return nil
}
// QueryBulkStateContent does a bulk query for state event content in the given rooms. // QueryBulkStateContent does a bulk query for state event content in the given rooms.
func (s *mockStateAPI) QueryBulkStateContent(ctx context.Context, req *stateapi.QueryBulkStateContentRequest, res *stateapi.QueryBulkStateContentResponse) error { func (s *mockStateAPI) QueryBulkStateContent(ctx context.Context, req *stateapi.QueryBulkStateContentRequest, res *stateapi.QueryBulkStateContentResponse) error {
var res2 api.QueryBulkStateContentResponse var res2 api.QueryBulkStateContentResponse

View file

@ -460,3 +460,8 @@ If user leaves room, remote user changes device and rejoins we see update in /sy
Can search public room list Can search public room list
Can get remote public room list Can get remote public room list
Asking for a remote rooms list, but supplying the local server's name, returns the local rooms list Asking for a remote rooms list, but supplying the local server's name, returns the local rooms list
After changing password, can't log in with old password
After changing password, can log in with new password
After changing password, existing session still works
After changing password, different sessions can optionally be kept
After changing password, a different session no longer works by default

View file

@ -26,6 +26,7 @@ import (
type UserInternalAPI interface { type UserInternalAPI interface {
InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error
PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error
PerformPasswordUpdate(ctx context.Context, req *PerformPasswordUpdateRequest, res *PerformPasswordUpdateResponse) error
PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error
PerformDeviceDeletion(ctx context.Context, req *PerformDeviceDeletionRequest, res *PerformDeviceDeletionResponse) error PerformDeviceDeletion(ctx context.Context, req *PerformDeviceDeletionRequest, res *PerformDeviceDeletionResponse) error
PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error
@ -63,6 +64,10 @@ type PerformDeviceDeletionRequest struct {
UserID string UserID string
// The devices to delete. An empty slice means delete all devices. // The devices to delete. An empty slice means delete all devices.
DeviceIDs []string DeviceIDs []string
// The requesting device ID to exclude from deletion. This is needed
// so that a password change doesn't cause that client to be logged
// out. Only specify when DeviceIDs is empty.
ExceptDeviceID string
} }
type PerformDeviceDeletionResponse struct { type PerformDeviceDeletionResponse struct {
@ -165,6 +170,18 @@ type PerformAccountCreationResponse struct {
Account *Account Account *Account
} }
// PerformAccountCreationRequest is the request for PerformAccountCreation
type PerformPasswordUpdateRequest struct {
Localpart string // Required: The localpart for this account.
Password string // Required: The new password to set.
}
// PerformAccountCreationResponse is the response for PerformAccountCreation
type PerformPasswordUpdateResponse struct {
PasswordUpdated bool
Account *Account
}
// PerformDeviceCreationRequest is the request for PerformDeviceCreation // PerformDeviceCreationRequest is the request for PerformDeviceCreation
type PerformDeviceCreationRequest struct { type PerformDeviceCreationRequest struct {
Localpart string Localpart string

View file

@ -98,6 +98,15 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
res.Account = acc res.Account = acc
return nil return nil
} }
func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error {
if err := a.AccountDB.SetPassword(ctx, req.Localpart, req.Password); err != nil {
return err
}
res.PasswordUpdated = true
return nil
}
func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.PerformDeviceCreationRequest, res *api.PerformDeviceCreationResponse) error { func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.PerformDeviceCreationRequest, res *api.PerformDeviceCreationResponse) error {
util.GetLogger(ctx).WithFields(logrus.Fields{ util.GetLogger(ctx).WithFields(logrus.Fields{
"localpart": req.Localpart, "localpart": req.Localpart,
@ -126,7 +135,7 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe
deletedDeviceIDs := req.DeviceIDs deletedDeviceIDs := req.DeviceIDs
if len(req.DeviceIDs) == 0 { if len(req.DeviceIDs) == 0 {
var devices []api.Device var devices []api.Device
devices, err = a.DeviceDB.RemoveAllDevices(ctx, local) devices, err = a.DeviceDB.RemoveAllDevices(ctx, local, req.ExceptDeviceID)
for _, d := range devices { for _, d := range devices {
deletedDeviceIDs = append(deletedDeviceIDs, d.ID) deletedDeviceIDs = append(deletedDeviceIDs, d.ID)
} }

View file

@ -30,6 +30,7 @@ const (
PerformDeviceCreationPath = "/userapi/performDeviceCreation" PerformDeviceCreationPath = "/userapi/performDeviceCreation"
PerformAccountCreationPath = "/userapi/performAccountCreation" PerformAccountCreationPath = "/userapi/performAccountCreation"
PerformPasswordUpdatePath = "/userapi/performPasswordUpdate"
PerformDeviceDeletionPath = "/userapi/performDeviceDeletion" PerformDeviceDeletionPath = "/userapi/performDeviceDeletion"
PerformDeviceUpdatePath = "/userapi/performDeviceUpdate" PerformDeviceUpdatePath = "/userapi/performDeviceUpdate"
@ -81,6 +82,18 @@ func (h *httpUserInternalAPI) PerformAccountCreation(
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }
func (h *httpUserInternalAPI) PerformPasswordUpdate(
ctx context.Context,
request *api.PerformPasswordUpdateRequest,
response *api.PerformPasswordUpdateResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPasswordUpdate")
defer span.Finish()
apiURL := h.apiURL + PerformPasswordUpdatePath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
}
func (h *httpUserInternalAPI) PerformDeviceCreation( func (h *httpUserInternalAPI) PerformDeviceCreation(
ctx context.Context, ctx context.Context,
request *api.PerformDeviceCreationRequest, request *api.PerformDeviceCreationRequest,

View file

@ -39,6 +39,19 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response} return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}), }),
) )
internalAPIMux.Handle(PerformAccountCreationPath,
httputil.MakeInternalAPI("performPasswordUpdate", func(req *http.Request) util.JSONResponse {
request := api.PerformPasswordUpdateRequest{}
response := api.PerformPasswordUpdateResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformPasswordUpdate(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(PerformDeviceCreationPath, internalAPIMux.Handle(PerformDeviceCreationPath,
httputil.MakeInternalAPI("performDeviceCreation", func(req *http.Request) util.JSONResponse { httputil.MakeInternalAPI("performDeviceCreation", func(req *http.Request) util.JSONResponse {
request := api.PerformDeviceCreationRequest{} request := api.PerformDeviceCreationRequest{}

View file

@ -28,6 +28,7 @@ type Database interface {
internal.PartitionStorer internal.PartitionStorer
GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error) GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error)
GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error) GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
SetPassword(ctx context.Context, localpart string, plaintextPassword string) error
SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error
SetDisplayName(ctx context.Context, localpart string, displayName string) error SetDisplayName(ctx context.Context, localpart string, displayName string) error
// CreateAccount makes a new account with the given login name and password, and creates an empty profile // CreateAccount makes a new account with the given login name and password, and creates an empty profile

View file

@ -47,6 +47,9 @@ CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1;
const insertAccountSQL = "" + const insertAccountSQL = "" +
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)" "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)"
const updatePasswordSQL = "" +
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
const selectAccountByLocalpartSQL = "" + const selectAccountByLocalpartSQL = "" +
"SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1" "SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1"
@ -56,10 +59,9 @@ const selectPasswordHashSQL = "" +
const selectNewNumericLocalpartSQL = "" + const selectNewNumericLocalpartSQL = "" +
"SELECT nextval('numeric_username_seq')" "SELECT nextval('numeric_username_seq')"
// TODO: Update password
type accountsStatements struct { type accountsStatements struct {
insertAccountStmt *sql.Stmt insertAccountStmt *sql.Stmt
updatePasswordStmt *sql.Stmt
selectAccountByLocalpartStmt *sql.Stmt selectAccountByLocalpartStmt *sql.Stmt
selectPasswordHashStmt *sql.Stmt selectPasswordHashStmt *sql.Stmt
selectNewNumericLocalpartStmt *sql.Stmt selectNewNumericLocalpartStmt *sql.Stmt
@ -74,6 +76,9 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil { if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil {
return return
} }
if s.updatePasswordStmt, err = db.Prepare(updatePasswordSQL); err != nil {
return
}
if s.selectAccountByLocalpartStmt, err = db.Prepare(selectAccountByLocalpartSQL); err != nil { if s.selectAccountByLocalpartStmt, err = db.Prepare(selectAccountByLocalpartSQL); err != nil {
return return
} }
@ -114,6 +119,13 @@ func (s *accountsStatements) insertAccount(
}, nil }, nil
} }
func (s *accountsStatements) updatePassword(
ctx context.Context, localpart, passwordHash string,
) (err error) {
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart)
return
}
func (s *accountsStatements) selectPasswordHash( func (s *accountsStatements) selectPasswordHash(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) (hash string, err error) { ) (hash string, err error) {

View file

@ -112,6 +112,17 @@ func (d *Database) SetDisplayName(
return d.profiles.setDisplayName(ctx, localpart, displayName) return d.profiles.setDisplayName(ctx, localpart, displayName)
} }
// SetPassword sets the account password to the given hash.
func (d *Database) SetPassword(
ctx context.Context, localpart, plaintextPassword string,
) error {
hash, err := hashPassword(plaintextPassword)
if err != nil {
return err
}
return d.accounts.updatePassword(ctx, localpart, hash)
}
// CreateGuestAccount makes a new guest account and creates an empty profile // CreateGuestAccount makes a new guest account and creates an empty profile
// for this account. // for this account.
func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, err error) { func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, err error) {

View file

@ -45,6 +45,9 @@ CREATE TABLE IF NOT EXISTS account_accounts (
const insertAccountSQL = "" + const insertAccountSQL = "" +
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)" "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)"
const updatePasswordSQL = "" +
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
const selectAccountByLocalpartSQL = "" + const selectAccountByLocalpartSQL = "" +
"SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1" "SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1"
@ -54,11 +57,10 @@ const selectPasswordHashSQL = "" +
const selectNewNumericLocalpartSQL = "" + const selectNewNumericLocalpartSQL = "" +
"SELECT COUNT(localpart) FROM account_accounts" "SELECT COUNT(localpart) FROM account_accounts"
// TODO: Update password
type accountsStatements struct { type accountsStatements struct {
db *sql.DB db *sql.DB
insertAccountStmt *sql.Stmt insertAccountStmt *sql.Stmt
updatePasswordStmt *sql.Stmt
selectAccountByLocalpartStmt *sql.Stmt selectAccountByLocalpartStmt *sql.Stmt
selectPasswordHashStmt *sql.Stmt selectPasswordHashStmt *sql.Stmt
selectNewNumericLocalpartStmt *sql.Stmt selectNewNumericLocalpartStmt *sql.Stmt
@ -75,6 +77,9 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil { if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil {
return return
} }
if s.updatePasswordStmt, err = db.Prepare(updatePasswordSQL); err != nil {
return
}
if s.selectAccountByLocalpartStmt, err = db.Prepare(selectAccountByLocalpartSQL); err != nil { if s.selectAccountByLocalpartStmt, err = db.Prepare(selectAccountByLocalpartSQL); err != nil {
return return
} }
@ -115,6 +120,13 @@ func (s *accountsStatements) insertAccount(
}, nil }, nil
} }
func (s *accountsStatements) updatePassword(
ctx context.Context, localpart, passwordHash string,
) (err error) {
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart)
return
}
func (s *accountsStatements) selectPasswordHash( func (s *accountsStatements) selectPasswordHash(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) (hash string, err error) { ) (hash string, err error) {

View file

@ -126,6 +126,18 @@ func (d *Database) SetDisplayName(
}) })
} }
// SetPassword sets the account password to the given hash.
func (d *Database) SetPassword(
ctx context.Context, localpart, plaintextPassword string,
) error {
hash, err := hashPassword(plaintextPassword)
if err != nil {
return err
}
err = d.accounts.updatePassword(ctx, localpart, hash)
return err
}
// CreateGuestAccount makes a new guest account and creates an empty profile // CreateGuestAccount makes a new guest account and creates an empty profile
// for this account. // for this account.
func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, err error) { func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, err error) {

View file

@ -36,5 +36,5 @@ type Database interface {
RemoveDevice(ctx context.Context, deviceID, localpart string) error RemoveDevice(ctx context.Context, deviceID, localpart string) error
RemoveDevices(ctx context.Context, localpart string, devices []string) error RemoveDevices(ctx context.Context, localpart string, devices []string) error
// RemoveAllDevices deleted all devices for this user. Returns the devices deleted. // RemoveAllDevices deleted all devices for this user. Returns the devices deleted.
RemoveAllDevices(ctx context.Context, localpart string) (devices []api.Device, err error) RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error)
} }

View file

@ -70,7 +70,7 @@ const selectDeviceByIDSQL = "" +
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2" "SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
const selectDevicesByLocalpartSQL = "" + const selectDevicesByLocalpartSQL = "" +
"SELECT device_id, display_name FROM device_devices WHERE localpart = $1" "SELECT device_id, display_name FROM device_devices WHERE localpart = $1 AND device_id != $2"
const updateDeviceNameSQL = "" + const updateDeviceNameSQL = "" +
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" "UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
@ -79,7 +79,7 @@ const deleteDeviceSQL = "" +
"DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2" "DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2"
const deleteDevicesByLocalpartSQL = "" + const deleteDevicesByLocalpartSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1" "DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2"
const deleteDevicesSQL = "" + const deleteDevicesSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)" "DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)"
@ -179,10 +179,10 @@ func (s *devicesStatements) deleteDevices(
// deleteDevicesByLocalpart removes all devices for the // deleteDevicesByLocalpart removes all devices for the
// given user localpart. // given user localpart.
func (s *devicesStatements) deleteDevicesByLocalpart( func (s *devicesStatements) deleteDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart string, ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
_, err := stmt.ExecContext(ctx, localpart) _, err := stmt.ExecContext(ctx, localpart, exceptDeviceID)
return err return err
} }
@ -251,10 +251,10 @@ func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []s
} }
func (s *devicesStatements) selectDevicesByLocalpart( func (s *devicesStatements) selectDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart string, ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
) ([]api.Device, error) { ) ([]api.Device, error) {
devices := []api.Device{} devices := []api.Device{}
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart) rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID)
if err != nil { if err != nil {
return devices, err return devices, err

View file

@ -68,7 +68,7 @@ func (d *Database) GetDeviceByID(
func (d *Database) GetDevicesByLocalpart( func (d *Database) GetDevicesByLocalpart(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) ([]api.Device, error) { ) ([]api.Device, error) {
return d.devices.selectDevicesByLocalpart(ctx, nil, localpart) return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "")
} }
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
@ -175,14 +175,14 @@ func (d *Database) RemoveDevices(
// database matching the given user ID localpart. // database matching the given user ID localpart.
// If something went wrong during the deletion, it will return the SQL error. // If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveAllDevices( func (d *Database) RemoveAllDevices(
ctx context.Context, localpart string, ctx context.Context, localpart, exceptDeviceID string,
) (devices []api.Device, err error) { ) (devices []api.Device, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart) devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
if err != nil { if err != nil {
return err return err
} }
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows { if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
return err return err
} }
return nil return nil

View file

@ -59,7 +59,7 @@ const selectDeviceByIDSQL = "" +
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2" "SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
const selectDevicesByLocalpartSQL = "" + const selectDevicesByLocalpartSQL = "" +
"SELECT device_id, display_name FROM device_devices WHERE localpart = $1" "SELECT device_id, display_name FROM device_devices WHERE localpart = $1 AND device_id != $2"
const updateDeviceNameSQL = "" + const updateDeviceNameSQL = "" +
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" "UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
@ -68,7 +68,7 @@ const deleteDeviceSQL = "" +
"DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2" "DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2"
const deleteDevicesByLocalpartSQL = "" + const deleteDevicesByLocalpartSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1" "DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2"
const deleteDevicesSQL = "" + const deleteDevicesSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)" "DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)"
@ -182,10 +182,10 @@ func (s *devicesStatements) deleteDevices(
} }
func (s *devicesStatements) deleteDevicesByLocalpart( func (s *devicesStatements) deleteDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart string, ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
_, err := stmt.ExecContext(ctx, localpart) _, err := stmt.ExecContext(ctx, localpart, exceptDeviceID)
return err return err
} }
@ -231,10 +231,10 @@ func (s *devicesStatements) selectDeviceByID(
} }
func (s *devicesStatements) selectDevicesByLocalpart( func (s *devicesStatements) selectDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart string, ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
) ([]api.Device, error) { ) ([]api.Device, error) {
devices := []api.Device{} devices := []api.Device{}
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart) rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID)
if err != nil { if err != nil {
return devices, err return devices, err

View file

@ -72,7 +72,7 @@ func (d *Database) GetDeviceByID(
func (d *Database) GetDevicesByLocalpart( func (d *Database) GetDevicesByLocalpart(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) ([]api.Device, error) { ) ([]api.Device, error) {
return d.devices.selectDevicesByLocalpart(ctx, nil, localpart) return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "")
} }
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
@ -179,14 +179,14 @@ func (d *Database) RemoveDevices(
// database matching the given user ID localpart. // database matching the given user ID localpart.
// If something went wrong during the deletion, it will return the SQL error. // If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveAllDevices( func (d *Database) RemoveAllDevices(
ctx context.Context, localpart string, ctx context.Context, localpart, exceptDeviceID string,
) (devices []api.Device, err error) { ) (devices []api.Device, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart) devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
if err != nil { if err != nil {
return err return err
} }
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows { if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
return err return err
} }
return nil return nil