Virtual hosting schema and logic changes (#2876)

Note that virtual users cannot federate correctly yet.
This commit is contained in:
Neil Alexander 2022-11-11 16:41:37 +00:00 committed by GitHub
parent e177e0ae73
commit 529df30b56
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
62 changed files with 1250 additions and 732 deletions

View file

@ -32,6 +32,7 @@ import (
"github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
) )
// AddInternalRoutes registers HTTP handlers for internal API calls // AddInternalRoutes registers HTTP handlers for internal API calls
@ -74,7 +75,7 @@ func NewInternalAPI(
// events to be sent out. // events to be sent out.
for _, appservice := range base.Cfg.Derived.ApplicationServices { for _, appservice := range base.Cfg.Derived.ApplicationServices {
// Create bot account for this AS if it doesn't already exist // Create bot account for this AS if it doesn't already exist
if err := generateAppServiceAccount(userAPI, appservice); err != nil { if err := generateAppServiceAccount(userAPI, appservice, base.Cfg.Global.ServerName); err != nil {
logrus.WithFields(logrus.Fields{ logrus.WithFields(logrus.Fields{
"appservice": appservice.ID, "appservice": appservice.ID,
}).WithError(err).Panicf("failed to generate bot account for appservice") }).WithError(err).Panicf("failed to generate bot account for appservice")
@ -101,11 +102,13 @@ func NewInternalAPI(
func generateAppServiceAccount( func generateAppServiceAccount(
userAPI userapi.AppserviceUserAPI, userAPI userapi.AppserviceUserAPI,
as config.ApplicationService, as config.ApplicationService,
serverName gomatrixserverlib.ServerName,
) error { ) error {
var accRes userapi.PerformAccountCreationResponse var accRes userapi.PerformAccountCreationResponse
err := userAPI.PerformAccountCreation(context.Background(), &userapi.PerformAccountCreationRequest{ err := userAPI.PerformAccountCreation(context.Background(), &userapi.PerformAccountCreationRequest{
AccountType: userapi.AccountTypeAppService, AccountType: userapi.AccountTypeAppService,
Localpart: as.SenderLocalpart, Localpart: as.SenderLocalpart,
ServerName: serverName,
AppServiceID: as.ID, AppServiceID: as.ID,
OnConflict: userapi.ConflictUpdate, OnConflict: userapi.ConflictUpdate,
}, &accRes) }, &accRes)
@ -115,6 +118,7 @@ func generateAppServiceAccount(
var devRes userapi.PerformDeviceCreationResponse var devRes userapi.PerformDeviceCreationResponse
err = userAPI.PerformDeviceCreation(context.Background(), &userapi.PerformDeviceCreationRequest{ err = userAPI.PerformDeviceCreation(context.Background(), &userapi.PerformDeviceCreationRequest{
Localpart: as.SenderLocalpart, Localpart: as.SenderLocalpart,
ServerName: serverName,
AccessToken: as.ASToken, AccessToken: as.ASToken,
DeviceID: &as.SenderLocalpart, DeviceID: &as.SenderLocalpart,
DeviceDisplayName: &as.SenderLocalpart, DeviceDisplayName: &as.SenderLocalpart,

View file

@ -61,7 +61,7 @@ func (t *LoginTypePassword) LoginFromJSON(ctx context.Context, reqBytes []byte)
func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, *util.JSONResponse) { func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, *util.JSONResponse) {
r := req.(*PasswordRequest) r := req.(*PasswordRequest)
username := strings.ToLower(r.Username()) username := r.Username()
if username == "" { if username == "" {
return nil, &util.JSONResponse{ return nil, &util.JSONResponse{
Code: http.StatusUnauthorized, Code: http.StatusUnauthorized,
@ -74,32 +74,43 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login,
JSON: jsonerror.BadJSON("A password must be supplied."), JSON: jsonerror.BadJSON("A password must be supplied."),
} }
} }
localpart, _, err := userutil.ParseUsernameParam(username, t.Config.Matrix) localpart, domain, err := userutil.ParseUsernameParam(username, t.Config.Matrix)
if err != nil { if err != nil {
return nil, &util.JSONResponse{ return nil, &util.JSONResponse{
Code: http.StatusUnauthorized, Code: http.StatusUnauthorized,
JSON: jsonerror.InvalidUsername(err.Error()), JSON: jsonerror.InvalidUsername(err.Error()),
} }
} }
if !t.Config.Matrix.IsLocalServerName(domain) {
return nil, &util.JSONResponse{
Code: http.StatusUnauthorized,
JSON: jsonerror.InvalidUsername("The server name is not known."),
}
}
// Squash username to all lowercase letters // Squash username to all lowercase letters
res := &api.QueryAccountByPasswordResponse{} res := &api.QueryAccountByPasswordResponse{}
err = t.GetAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{Localpart: strings.ToLower(localpart), PlaintextPassword: r.Password}, res) err = t.GetAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{
Localpart: strings.ToLower(localpart),
ServerName: domain,
PlaintextPassword: r.Password,
}, res)
if err != nil { if err != nil {
return nil, &util.JSONResponse{ return nil, &util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
JSON: jsonerror.Unknown("unable to fetch account by password"), JSON: jsonerror.Unknown("Unable to fetch account by password."),
} }
} }
if !res.Exists { if !res.Exists {
err = t.GetAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{ err = t.GetAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{
Localpart: localpart, Localpart: localpart,
ServerName: domain,
PlaintextPassword: r.Password, PlaintextPassword: r.Password,
}, res) }, res)
if err != nil { if err != nil {
return nil, &util.JSONResponse{ return nil, &util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
JSON: jsonerror.Unknown("unable to fetch account by password"), JSON: jsonerror.Unknown("Unable to fetch account by password."),
} }
} }
// Technically we could tell them if the user does not exist by checking if err == sql.ErrNoRows // Technically we could tell them if the user does not exist by checking if err == sql.ErrNoRows

View file

@ -102,6 +102,7 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
serverName := cfg.Matrix.ServerName
localpart, ok := vars["localpart"] localpart, ok := vars["localpart"]
if !ok { if !ok {
return util.JSONResponse{ return util.JSONResponse{
@ -109,6 +110,9 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap
JSON: jsonerror.MissingArgument("Expecting user localpart."), JSON: jsonerror.MissingArgument("Expecting user localpart."),
} }
} }
if l, s, err := gomatrixserverlib.SplitID('@', localpart); err == nil {
localpart, serverName = l, s
}
request := struct { request := struct {
Password string `json:"password"` Password string `json:"password"`
}{} }{}
@ -126,6 +130,7 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap
} }
updateReq := &userapi.PerformPasswordUpdateRequest{ updateReq := &userapi.PerformPasswordUpdateRequest{
Localpart: localpart, Localpart: localpart,
ServerName: serverName,
Password: request.Password, Password: request.Password,
LogoutDevices: true, LogoutDevices: true,
} }

View file

@ -100,6 +100,7 @@ func completeAuth(
DeviceID: login.DeviceID, DeviceID: login.DeviceID,
AccessToken: token, AccessToken: token,
Localpart: localpart, Localpart: localpart,
ServerName: serverName,
IPAddr: ipAddr, IPAddr: ipAddr,
UserAgent: userAgent, UserAgent: userAgent,
}, &performRes) }, &performRes)

View file

@ -40,13 +40,14 @@ func GetNotifications(
} }
var queryRes userapi.QueryNotificationsResponse var queryRes userapi.QueryNotificationsResponse
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("SplitID failed") util.GetLogger(req.Context()).WithError(err).Error("SplitID failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
err = userAPI.QueryNotifications(req.Context(), &userapi.QueryNotificationsRequest{ err = userAPI.QueryNotifications(req.Context(), &userapi.QueryNotificationsRequest{
Localpart: localpart, Localpart: localpart,
ServerName: domain,
From: req.URL.Query().Get("from"), From: req.URL.Query().Get("from"),
Limit: int(limit), Limit: int(limit),
Only: req.URL.Query().Get("only"), Only: req.URL.Query().Get("only"),

View file

@ -86,7 +86,7 @@ func Password(
} }
// Get the local part. // Get the local part.
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
@ -95,6 +95,7 @@ func Password(
// Ask the user API to perform the password change. // Ask the user API to perform the password change.
passwordReq := &api.PerformPasswordUpdateRequest{ passwordReq := &api.PerformPasswordUpdateRequest{
Localpart: localpart, Localpart: localpart,
ServerName: domain,
Password: r.NewPassword, Password: r.NewPassword,
} }
passwordRes := &api.PerformPasswordUpdateResponse{} passwordRes := &api.PerformPasswordUpdateResponse{}
@ -123,6 +124,7 @@ func Password(
pushersReq := &api.PerformPusherDeletionRequest{ pushersReq := &api.PerformPusherDeletionRequest{
Localpart: localpart, Localpart: localpart,
ServerName: domain,
SessionID: device.SessionID, SessionID: device.SessionID,
} }
if err := userAPI.PerformPusherDeletion(req.Context(), pushersReq, &struct{}{}); err != nil { if err := userAPI.PerformPusherDeletion(req.Context(), pushersReq, &struct{}{}); err != nil {

View file

@ -31,13 +31,14 @@ func GetPushers(
userAPI userapi.ClientUserAPI, userAPI userapi.ClientUserAPI,
) util.JSONResponse { ) util.JSONResponse {
var queryRes userapi.QueryPushersResponse var queryRes userapi.QueryPushersResponse
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("SplitID failed") util.GetLogger(req.Context()).WithError(err).Error("SplitID failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
err = userAPI.QueryPushers(req.Context(), &userapi.QueryPushersRequest{ err = userAPI.QueryPushers(req.Context(), &userapi.QueryPushersRequest{
Localpart: localpart, Localpart: localpart,
ServerName: domain,
}, &queryRes) }, &queryRes)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("QueryPushers failed") util.GetLogger(req.Context()).WithError(err).Error("QueryPushers failed")
@ -59,7 +60,7 @@ func SetPusher(
req *http.Request, device *userapi.Device, req *http.Request, device *userapi.Device,
userAPI userapi.ClientUserAPI, userAPI userapi.ClientUserAPI,
) util.JSONResponse { ) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("SplitID failed") util.GetLogger(req.Context()).WithError(err).Error("SplitID failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
@ -93,6 +94,7 @@ func SetPusher(
} }
body.Localpart = localpart body.Localpart = localpart
body.ServerName = domain
body.SessionID = device.SessionID body.SessionID = device.SessionID
err = userAPI.PerformPusherSet(req.Context(), &body, &struct{}{}) err = userAPI.PerformPusherSet(req.Context(), &body, &struct{}{})
if err != nil { if err != nil {

View file

@ -588,12 +588,15 @@ func Register(
} }
// Auto generate a numeric username if r.Username is empty // Auto generate a numeric username if r.Username is empty
if r.Username == "" { if r.Username == "" {
res := &userapi.QueryNumericLocalpartResponse{} nreq := &userapi.QueryNumericLocalpartRequest{
if err := userAPI.QueryNumericLocalpart(req.Context(), res); err != nil { ServerName: cfg.Matrix.ServerName, // TODO: might not be right
}
nres := &userapi.QueryNumericLocalpartResponse{}
if err := userAPI.QueryNumericLocalpart(req.Context(), nreq, nres); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryNumericLocalpart failed") util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryNumericLocalpart failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
r.Username = strconv.FormatInt(res.ID, 10) r.Username = strconv.FormatInt(nres.ID, 10)
} }
// Is this an appservice registration? It will be if the access // Is this an appservice registration? It will be if the access
@ -676,6 +679,7 @@ func handleGuestRegistration(
var devRes userapi.PerformDeviceCreationResponse var devRes userapi.PerformDeviceCreationResponse
err = userAPI.PerformDeviceCreation(req.Context(), &userapi.PerformDeviceCreationRequest{ err = userAPI.PerformDeviceCreation(req.Context(), &userapi.PerformDeviceCreationRequest{
Localpart: res.Account.Localpart, Localpart: res.Account.Localpart,
ServerName: res.Account.ServerName,
DeviceDisplayName: r.InitialDisplayName, DeviceDisplayName: r.InitialDisplayName,
AccessToken: token, AccessToken: token,
IPAddr: req.RemoteAddr, IPAddr: req.RemoteAddr,

View file

@ -157,7 +157,7 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
dendriteAdminRouter.Handle("/admin/resetPassword/{localpart}", dendriteAdminRouter.Handle("/admin/resetPassword/{userID}",
httputil.MakeAdminAPI("admin_reset_password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAdminAPI("admin_reset_password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return AdminResetPassword(req, cfg, device, userAPI) return AdminResetPassword(req, cfg, device, userAPI)
}), }),

View file

@ -286,6 +286,7 @@ func getSenderDevice(
err := userAPI.PerformAccountCreation(ctx, &userapi.PerformAccountCreationRequest{ err := userAPI.PerformAccountCreation(ctx, &userapi.PerformAccountCreationRequest{
AccountType: userapi.AccountTypeUser, AccountType: userapi.AccountTypeUser,
Localpart: cfg.Matrix.ServerNotices.LocalPart, Localpart: cfg.Matrix.ServerNotices.LocalPart,
ServerName: cfg.Matrix.ServerName,
OnConflict: userapi.ConflictUpdate, OnConflict: userapi.ConflictUpdate,
}, &accRes) }, &accRes)
if err != nil { if err != nil {
@ -296,6 +297,7 @@ func getSenderDevice(
avatarRes := &userapi.PerformSetAvatarURLResponse{} avatarRes := &userapi.PerformSetAvatarURLResponse{}
if err = userAPI.SetAvatarURL(ctx, &userapi.PerformSetAvatarURLRequest{ if err = userAPI.SetAvatarURL(ctx, &userapi.PerformSetAvatarURLRequest{
Localpart: cfg.Matrix.ServerNotices.LocalPart, Localpart: cfg.Matrix.ServerNotices.LocalPart,
ServerName: cfg.Matrix.ServerName,
AvatarURL: cfg.Matrix.ServerNotices.AvatarURL, AvatarURL: cfg.Matrix.ServerNotices.AvatarURL,
}, avatarRes); err != nil { }, avatarRes); err != nil {
util.GetLogger(ctx).WithError(err).Error("userAPI.SetAvatarURL failed") util.GetLogger(ctx).WithError(err).Error("userAPI.SetAvatarURL failed")
@ -308,6 +310,7 @@ func getSenderDevice(
displayNameRes := &userapi.PerformUpdateDisplayNameResponse{} displayNameRes := &userapi.PerformUpdateDisplayNameResponse{}
if err = userAPI.SetDisplayName(ctx, &userapi.PerformUpdateDisplayNameRequest{ if err = userAPI.SetDisplayName(ctx, &userapi.PerformUpdateDisplayNameRequest{
Localpart: cfg.Matrix.ServerNotices.LocalPart, Localpart: cfg.Matrix.ServerNotices.LocalPart,
ServerName: cfg.Matrix.ServerName,
DisplayName: cfg.Matrix.ServerNotices.DisplayName, DisplayName: cfg.Matrix.ServerNotices.DisplayName,
}, displayNameRes); err != nil { }, displayNameRes); err != nil {
util.GetLogger(ctx).WithError(err).Error("userAPI.SetDisplayName failed") util.GetLogger(ctx).WithError(err).Error("userAPI.SetDisplayName failed")
@ -353,6 +356,7 @@ func getSenderDevice(
var devRes userapi.PerformDeviceCreationResponse var devRes userapi.PerformDeviceCreationResponse
err = userAPI.PerformDeviceCreation(ctx, &userapi.PerformDeviceCreationRequest{ err = userAPI.PerformDeviceCreation(ctx, &userapi.PerformDeviceCreationRequest{
Localpart: cfg.Matrix.ServerNotices.LocalPart, Localpart: cfg.Matrix.ServerNotices.LocalPart,
ServerName: cfg.Matrix.ServerName,
DeviceDisplayName: &cfg.Matrix.ServerNotices.LocalPart, DeviceDisplayName: &cfg.Matrix.ServerNotices.LocalPart,
AccessToken: token, AccessToken: token,
NoDeviceListUpdate: true, NoDeviceListUpdate: true,

View file

@ -136,7 +136,7 @@ func CheckAndSave3PIDAssociation(
} }
// Save the association in the database // Save the association in the database
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
@ -145,6 +145,7 @@ func CheckAndSave3PIDAssociation(
if err = threePIDAPI.PerformSaveThreePIDAssociation(req.Context(), &api.PerformSaveThreePIDAssociationRequest{ if err = threePIDAPI.PerformSaveThreePIDAssociation(req.Context(), &api.PerformSaveThreePIDAssociationRequest{
ThreePID: address, ThreePID: address,
Localpart: localpart, Localpart: localpart,
ServerName: domain,
Medium: medium, Medium: medium,
}, &struct{}{}); err != nil { }, &struct{}{}); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("threePIDAPI.PerformSaveThreePIDAssociation failed") util.GetLogger(req.Context()).WithError(err).Error("threePIDAPI.PerformSaveThreePIDAssociation failed")
@ -161,7 +162,7 @@ func CheckAndSave3PIDAssociation(
func GetAssociated3PIDs( func GetAssociated3PIDs(
req *http.Request, threepidAPI api.ClientUserAPI, device *api.Device, req *http.Request, threepidAPI api.ClientUserAPI, device *api.Device,
) util.JSONResponse { ) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
@ -170,6 +171,7 @@ func GetAssociated3PIDs(
res := &api.QueryThreePIDsForLocalpartResponse{} res := &api.QueryThreePIDsForLocalpartResponse{}
err = threepidAPI.QueryThreePIDsForLocalpart(req.Context(), &api.QueryThreePIDsForLocalpartRequest{ err = threepidAPI.QueryThreePIDsForLocalpart(req.Context(), &api.QueryThreePIDsForLocalpartRequest{
Localpart: localpart, Localpart: localpart,
ServerName: domain,
}, res) }, res)
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("threepidAPI.QueryThreePIDsForLocalpart failed") util.GetLogger(req.Context()).WithError(err).Error("threepidAPI.QueryThreePIDsForLocalpart failed")

View file

@ -120,15 +120,23 @@ func NewInternalAPI(
js, nats := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) js, nats := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream)
signingInfo := map[gomatrixserverlib.ServerName]*queue.SigningInfo{}
for _, serverName := range append(
[]gomatrixserverlib.ServerName{base.Cfg.Global.ServerName},
base.Cfg.Global.SecondaryServerNames...,
) {
signingInfo[serverName] = &queue.SigningInfo{
KeyID: cfg.Matrix.KeyID,
PrivateKey: cfg.Matrix.PrivateKey,
ServerName: serverName,
}
}
queues := queue.NewOutgoingQueues( queues := queue.NewOutgoingQueues(
federationDB, base.ProcessContext, federationDB, base.ProcessContext,
cfg.Matrix.DisableFederation, cfg.Matrix.DisableFederation,
cfg.Matrix.ServerName, federation, rsAPI, &stats, cfg.Matrix.ServerName, federation, rsAPI, &stats,
&queue.SigningInfo{ signingInfo,
KeyID: cfg.Matrix.KeyID,
PrivateKey: cfg.Matrix.PrivateKey,
ServerName: cfg.Matrix.ServerName,
},
) )
rsConsumer := consumers.NewOutputRoomEventConsumer( rsConsumer := consumers.NewOutputRoomEventConsumer(

View file

@ -137,7 +137,7 @@ func (m *MockRoundTripper) RoundTrip(req *http.Request) (res *http.Response, err
} }
// Get the keys and JSON-ify them. // Get the keys and JSON-ify them.
keys := routing.LocalKeys(s.config) keys := routing.LocalKeys(s.config, gomatrixserverlib.ServerName(req.Host))
body, err := json.MarshalIndent(keys.JSON, "", " ") body, err := json.MarshalIndent(keys.JSON, "", " ")
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -50,7 +50,7 @@ type destinationQueue struct {
queues *OutgoingQueues queues *OutgoingQueues
db storage.Database db storage.Database
process *process.ProcessContext process *process.ProcessContext
signing *SigningInfo signing map[gomatrixserverlib.ServerName]*SigningInfo
rsAPI api.FederationRoomserverAPI rsAPI api.FederationRoomserverAPI
client fedapi.FederationClient // federation client client fedapi.FederationClient // federation client
origin gomatrixserverlib.ServerName // origin of requests origin gomatrixserverlib.ServerName // origin of requests

View file

@ -46,7 +46,7 @@ type OutgoingQueues struct {
origin gomatrixserverlib.ServerName origin gomatrixserverlib.ServerName
client fedapi.FederationClient client fedapi.FederationClient
statistics *statistics.Statistics statistics *statistics.Statistics
signing *SigningInfo signing map[gomatrixserverlib.ServerName]*SigningInfo
queuesMutex sync.Mutex // protects the below queuesMutex sync.Mutex // protects the below
queues map[gomatrixserverlib.ServerName]*destinationQueue queues map[gomatrixserverlib.ServerName]*destinationQueue
} }
@ -91,7 +91,7 @@ func NewOutgoingQueues(
client fedapi.FederationClient, client fedapi.FederationClient,
rsAPI api.FederationRoomserverAPI, rsAPI api.FederationRoomserverAPI,
statistics *statistics.Statistics, statistics *statistics.Statistics,
signing *SigningInfo, signing map[gomatrixserverlib.ServerName]*SigningInfo,
) *OutgoingQueues { ) *OutgoingQueues {
queues := &OutgoingQueues{ queues := &OutgoingQueues{
disabled: disabled, disabled: disabled,
@ -199,11 +199,10 @@ func (oqs *OutgoingQueues) SendEvent(
log.Trace("Federation is disabled, not sending event") log.Trace("Federation is disabled, not sending event")
return nil return nil
} }
if origin != oqs.origin { if _, ok := oqs.signing[origin]; !ok {
// TODO: Support virtual hosting; gh issue #577.
return fmt.Errorf( return fmt.Errorf(
"sendevent: unexpected server to send as: got %q expected %q", "sendevent: unexpected server to send as %q",
origin, oqs.origin, origin,
) )
} }
@ -214,7 +213,9 @@ func (oqs *OutgoingQueues) SendEvent(
destmap[d] = struct{}{} destmap[d] = struct{}{}
} }
delete(destmap, oqs.origin) delete(destmap, oqs.origin)
delete(destmap, oqs.signing.ServerName) for local := range oqs.signing {
delete(destmap, local)
}
// Check if any of the destinations are prohibited by server ACLs. // Check if any of the destinations are prohibited by server ACLs.
for destination := range destmap { for destination := range destmap {
@ -288,11 +289,10 @@ func (oqs *OutgoingQueues) SendEDU(
log.Trace("Federation is disabled, not sending EDU") log.Trace("Federation is disabled, not sending EDU")
return nil return nil
} }
if origin != oqs.origin { if _, ok := oqs.signing[origin]; !ok {
// TODO: Support virtual hosting; gh issue #577.
return fmt.Errorf( return fmt.Errorf(
"sendevent: unexpected server to send as: got %q expected %q", "sendevent: unexpected server to send as %q",
origin, oqs.origin, origin,
) )
} }
@ -303,7 +303,9 @@ func (oqs *OutgoingQueues) SendEDU(
destmap[d] = struct{}{} destmap[d] = struct{}{}
} }
delete(destmap, oqs.origin) delete(destmap, oqs.origin)
delete(destmap, oqs.signing.ServerName) for local := range oqs.signing {
delete(destmap, local)
}
// There is absolutely no guarantee that the EDU will have a room_id // There is absolutely no guarantee that the EDU will have a room_id
// field, as it is not required by the spec. However, if it *does* // field, as it is not required by the spec. However, if it *does*

View file

@ -350,10 +350,12 @@ func testSetup(failuresUntilBlacklist uint32, shouldTxSucceed bool, t *testing.T
} }
rs := &stubFederationRoomServerAPI{} rs := &stubFederationRoomServerAPI{}
stats := statistics.NewStatistics(db, failuresUntilBlacklist) stats := statistics.NewStatistics(db, failuresUntilBlacklist)
signingInfo := &SigningInfo{ signingInfo := map[gomatrixserverlib.ServerName]*SigningInfo{
"localhost": {
KeyID: "ed21019:auto", KeyID: "ed21019:auto",
PrivateKey: test.PrivateKeyA, PrivateKey: test.PrivateKeyA,
ServerName: "localhost", ServerName: "localhost",
},
} }
queues := NewOutgoingQueues(db, processContext, false, "localhost", fc, rs, &stats, signingInfo) queues := NewOutgoingQueues(db, processContext, false, "localhost", fc, rs, &stats, signingInfo)

View file

@ -16,6 +16,7 @@ package routing
import ( import (
"encoding/json" "encoding/json"
"fmt"
"net/http" "net/http"
"time" "time"
@ -134,18 +135,21 @@ func ClaimOneTimeKeys(
// LocalKeys returns the local keys for the server. // LocalKeys returns the local keys for the server.
// See https://matrix.org/docs/spec/server_server/unstable.html#publishing-keys // See https://matrix.org/docs/spec/server_server/unstable.html#publishing-keys
func LocalKeys(cfg *config.FederationAPI) util.JSONResponse { func LocalKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerName) util.JSONResponse {
keys, err := localKeys(cfg, time.Now().Add(cfg.Matrix.KeyValidityPeriod)) keys, err := localKeys(cfg, serverName, time.Now().Add(cfg.Matrix.KeyValidityPeriod))
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
return util.JSONResponse{Code: http.StatusOK, JSON: keys} return util.JSONResponse{Code: http.StatusOK, JSON: keys}
} }
func localKeys(cfg *config.FederationAPI, validUntil time.Time) (*gomatrixserverlib.ServerKeys, error) { func localKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerName, validUntil time.Time) (*gomatrixserverlib.ServerKeys, error) {
var keys gomatrixserverlib.ServerKeys var keys gomatrixserverlib.ServerKeys
if !cfg.Matrix.IsLocalServerName(serverName) {
return nil, fmt.Errorf("server name not known")
}
keys.ServerName = cfg.Matrix.ServerName keys.ServerName = serverName
keys.ValidUntilTS = gomatrixserverlib.AsTimestamp(validUntil) keys.ValidUntilTS = gomatrixserverlib.AsTimestamp(validUntil)
publicKey := cfg.Matrix.PrivateKey.Public().(ed25519.PublicKey) publicKey := cfg.Matrix.PrivateKey.Public().(ed25519.PublicKey)
@ -172,7 +176,7 @@ func localKeys(cfg *config.FederationAPI, validUntil time.Time) (*gomatrixserver
} }
keys.Raw, err = gomatrixserverlib.SignJSON( keys.Raw, err = gomatrixserverlib.SignJSON(
string(cfg.Matrix.ServerName), cfg.Matrix.KeyID, cfg.Matrix.PrivateKey, toSign, string(serverName), cfg.Matrix.KeyID, cfg.Matrix.PrivateKey, toSign,
) )
if err != nil { if err != nil {
return nil, err return nil, err
@ -186,6 +190,14 @@ func NotaryKeys(
fsAPI federationAPI.FederationInternalAPI, fsAPI federationAPI.FederationInternalAPI,
req *gomatrixserverlib.PublicKeyNotaryLookupRequest, req *gomatrixserverlib.PublicKeyNotaryLookupRequest,
) util.JSONResponse { ) util.JSONResponse {
serverName := gomatrixserverlib.ServerName(httpReq.Host) // TODO: this is not ideal
if !cfg.Matrix.IsLocalServerName(serverName) {
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.NotFound("Server name not known"),
}
}
if req == nil { if req == nil {
req = &gomatrixserverlib.PublicKeyNotaryLookupRequest{} req = &gomatrixserverlib.PublicKeyNotaryLookupRequest{}
if reqErr := clienthttputil.UnmarshalJSONRequest(httpReq, &req); reqErr != nil { if reqErr := clienthttputil.UnmarshalJSONRequest(httpReq, &req); reqErr != nil {
@ -201,7 +213,7 @@ func NotaryKeys(
for serverName, kidToCriteria := range req.ServerKeys { for serverName, kidToCriteria := range req.ServerKeys {
var keyList []gomatrixserverlib.ServerKeys var keyList []gomatrixserverlib.ServerKeys
if serverName == cfg.Matrix.ServerName { if serverName == cfg.Matrix.ServerName {
if k, err := localKeys(cfg, time.Now().Add(cfg.Matrix.KeyValidityPeriod)); err == nil { if k, err := localKeys(cfg, serverName, time.Now().Add(cfg.Matrix.KeyValidityPeriod)); err == nil {
keyList = append(keyList, *k) keyList = append(keyList, *k)
} else { } else {
return util.ErrorResponse(err) return util.ErrorResponse(err)

View file

@ -74,7 +74,7 @@ func Setup(
} }
localKeys := httputil.MakeExternalAPI("localkeys", func(req *http.Request) util.JSONResponse { localKeys := httputil.MakeExternalAPI("localkeys", func(req *http.Request) util.JSONResponse {
return LocalKeys(cfg) return LocalKeys(cfg, gomatrixserverlib.ServerName(req.Host))
}) })
notaryKeys := httputil.MakeExternalAPI("notarykeys", func(req *http.Request) util.JSONResponse { notaryKeys := httputil.MakeExternalAPI("notarykeys", func(req *http.Request) util.JSONResponse {

View file

@ -33,12 +33,13 @@ import (
"github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/producers" "github.com/matrix-org/dendrite/keyserver/producers"
"github.com/matrix-org/dendrite/keyserver/storage" "github.com/matrix-org/dendrite/keyserver/storage"
"github.com/matrix-org/dendrite/setup/config"
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
) )
type KeyInternalAPI struct { type KeyInternalAPI struct {
DB storage.Database DB storage.Database
ThisServer gomatrixserverlib.ServerName Cfg *config.KeyServer
FedClient fedsenderapi.KeyserverFederationAPI FedClient fedsenderapi.KeyserverFederationAPI
UserAPI userapi.KeyserverUserAPI UserAPI userapi.KeyserverUserAPI
Producer *producers.KeyChange Producer *producers.KeyChange
@ -95,8 +96,11 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC
nested[userID] = val nested[userID] = val
domainToDeviceKeys[string(serverName)] = nested domainToDeviceKeys[string(serverName)] = nested
} }
for domain, local := range domainToDeviceKeys {
if !a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) {
continue
}
// claim local keys // claim local keys
if local, ok := domainToDeviceKeys[string(a.ThisServer)]; ok {
keys, err := a.DB.ClaimKeys(ctx, local) keys, err := a.DB.ClaimKeys(ctx, local)
if err != nil { if err != nil {
res.Error = &api.KeyError{ res.Error = &api.KeyError{
@ -117,7 +121,7 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC
res.OneTimeKeys[key.UserID][key.DeviceID][keyID] = keyJSON res.OneTimeKeys[key.UserID][key.DeviceID][keyID] = keyJSON
} }
} }
delete(domainToDeviceKeys, string(a.ThisServer)) delete(domainToDeviceKeys, domain)
} }
if len(domainToDeviceKeys) > 0 { if len(domainToDeviceKeys) > 0 {
a.claimRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys) a.claimRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys)
@ -258,7 +262,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
} }
domain := string(serverName) domain := string(serverName)
// query local devices // query local devices
if serverName == a.ThisServer { if a.Cfg.Matrix.IsLocalServerName(serverName) {
deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false) deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false)
if err != nil { if err != nil {
res.Error = &api.KeyError{ res.Error = &api.KeyError{
@ -437,13 +441,13 @@ func (a *KeyInternalAPI) queryRemoteKeys(
domains := map[string]struct{}{} domains := map[string]struct{}{}
for domain := range domainToDeviceKeys { for domain := range domainToDeviceKeys {
if domain == string(a.ThisServer) { if a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) {
continue continue
} }
domains[domain] = struct{}{} domains[domain] = struct{}{}
} }
for domain := range domainToCrossSigningKeys { for domain := range domainToCrossSigningKeys {
if domain == string(a.ThisServer) { if a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) {
continue continue
} }
domains[domain] = struct{}{} domains[domain] = struct{}{}
@ -689,7 +693,7 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per
if err != nil { if err != nil {
continue // ignore invalid users continue // ignore invalid users
} }
if serverName != a.ThisServer { if !a.Cfg.Matrix.IsLocalServerName(serverName) {
continue // ignore remote users continue // ignore remote users
} }
if len(key.KeyJSON) == 0 { if len(key.KeyJSON) == 0 {

View file

@ -54,7 +54,7 @@ func NewInternalAPI(
} }
ap := &internal.KeyInternalAPI{ ap := &internal.KeyInternalAPI{
DB: db, DB: db,
ThisServer: cfg.Matrix.ServerName, Cfg: cfg,
FedClient: fedClient, FedClient: fedClient,
Producer: keyChangeProducer, Producer: keyChangeProducer,
} }

View file

@ -78,7 +78,7 @@ type ClientUserAPI interface {
QueryAcccessTokenAPI QueryAcccessTokenAPI
LoginTokenInternalAPI LoginTokenInternalAPI
UserLoginAPI UserLoginAPI
QueryNumericLocalpart(ctx context.Context, res *QueryNumericLocalpartResponse) error QueryNumericLocalpart(ctx context.Context, req *QueryNumericLocalpartRequest, res *QueryNumericLocalpartResponse) error
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
@ -336,6 +336,7 @@ type PerformAccountCreationResponse struct {
// PerformAccountCreationRequest is the request for PerformAccountCreation // PerformAccountCreationRequest is the request for PerformAccountCreation
type PerformPasswordUpdateRequest struct { type PerformPasswordUpdateRequest struct {
Localpart string // Required: The localpart for this account. Localpart string // Required: The localpart for this account.
ServerName gomatrixserverlib.ServerName // Required: The domain for this account.
Password string // Required: The new password to set. Password string // Required: The new password to set.
LogoutDevices bool // Optional: Whether to log out all user devices. LogoutDevices bool // Optional: Whether to log out all user devices.
} }
@ -519,6 +520,7 @@ const (
type QueryPushersRequest struct { type QueryPushersRequest struct {
Localpart string Localpart string
ServerName gomatrixserverlib.ServerName
} }
type QueryPushersResponse struct { type QueryPushersResponse struct {
@ -528,11 +530,13 @@ type QueryPushersResponse struct {
type PerformPusherSetRequest struct { type PerformPusherSetRequest struct {
Pusher // Anonymous field because that's how clientapi unmarshals it. Pusher // Anonymous field because that's how clientapi unmarshals it.
Localpart string Localpart string
ServerName gomatrixserverlib.ServerName
Append bool `json:"append"` Append bool `json:"append"`
} }
type PerformPusherDeletionRequest struct { type PerformPusherDeletionRequest struct {
Localpart string Localpart string
ServerName gomatrixserverlib.ServerName
SessionID int64 SessionID int64
} }
@ -572,6 +576,7 @@ type QueryPushRulesResponse struct {
type QueryNotificationsRequest struct { type QueryNotificationsRequest struct {
Localpart string `json:"localpart"` // Required. Localpart string `json:"localpart"` // Required.
ServerName gomatrixserverlib.ServerName `json:"server_name"` // Required.
From string `json:"from,omitempty"` From string `json:"from,omitempty"`
Limit int `json:"limit,omitempty"` Limit int `json:"limit,omitempty"`
Only string `json:"only,omitempty"` Only string `json:"only,omitempty"`
@ -601,12 +606,17 @@ type PerformSetAvatarURLResponse struct {
Changed bool `json:"changed"` Changed bool `json:"changed"`
} }
type QueryNumericLocalpartRequest struct {
ServerName gomatrixserverlib.ServerName
}
type QueryNumericLocalpartResponse struct { type QueryNumericLocalpartResponse struct {
ID int64 ID int64
} }
type QueryAccountAvailabilityRequest struct { type QueryAccountAvailabilityRequest struct {
Localpart string Localpart string
ServerName gomatrixserverlib.ServerName
} }
type QueryAccountAvailabilityResponse struct { type QueryAccountAvailabilityResponse struct {
@ -614,7 +624,9 @@ type QueryAccountAvailabilityResponse struct {
} }
type QueryAccountByPasswordRequest struct { type QueryAccountByPasswordRequest struct {
Localpart, PlaintextPassword string Localpart string
ServerName gomatrixserverlib.ServerName
PlaintextPassword string
} }
type QueryAccountByPasswordResponse struct { type QueryAccountByPasswordResponse struct {
@ -639,10 +651,12 @@ type QueryLocalpartForThreePIDRequest struct {
type QueryLocalpartForThreePIDResponse struct { type QueryLocalpartForThreePIDResponse struct {
Localpart string Localpart string
ServerName gomatrixserverlib.ServerName
} }
type QueryThreePIDsForLocalpartRequest struct { type QueryThreePIDsForLocalpartRequest struct {
Localpart string Localpart string
ServerName gomatrixserverlib.ServerName
} }
type QueryThreePIDsForLocalpartResponse struct { type QueryThreePIDsForLocalpartResponse struct {
@ -652,5 +666,8 @@ type QueryThreePIDsForLocalpartResponse struct {
type PerformForgetThreePIDRequest QueryLocalpartForThreePIDRequest type PerformForgetThreePIDRequest QueryLocalpartForThreePIDRequest
type PerformSaveThreePIDAssociationRequest struct { type PerformSaveThreePIDAssociationRequest struct {
ThreePID, Localpart, Medium string ThreePID string
Localpart string
ServerName gomatrixserverlib.ServerName
Medium string
} }

View file

@ -156,8 +156,8 @@ func (t *UserInternalAPITrace) SetAvatarURL(ctx context.Context, req *PerformSet
return err return err
} }
func (t *UserInternalAPITrace) QueryNumericLocalpart(ctx context.Context, res *QueryNumericLocalpartResponse) error { func (t *UserInternalAPITrace) QueryNumericLocalpart(ctx context.Context, req *QueryNumericLocalpartRequest, res *QueryNumericLocalpartResponse) error {
err := t.Impl.QueryNumericLocalpart(ctx, res) err := t.Impl.QueryNumericLocalpart(ctx, req, res)
util.GetLogger(ctx).Infof("QueryNumericLocalpart req= res=%+v", js(res)) util.GetLogger(ctx).Infof("QueryNumericLocalpart req= res=%+v", js(res))
return err return err
} }

View file

@ -104,7 +104,7 @@ func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats
return false return false
} }
updated, err := s.db.SetNotificationsRead(ctx, localpart, roomID, uint64(gomatrixserverlib.AsTimestamp(metadata.Timestamp)), true) updated, err := s.db.SetNotificationsRead(ctx, localpart, domain, roomID, uint64(gomatrixserverlib.AsTimestamp(metadata.Timestamp)), true)
if err != nil { if err != nil {
log.WithError(err).Error("userapi EDU consumer") log.WithError(err).Error("userapi EDU consumer")
return false return false
@ -118,7 +118,7 @@ func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats
if !updated { if !updated {
return true return true
} }
if err = util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, s.db); err != nil { if err = util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, domain, s.db); err != nil {
log.WithError(err).Error("userapi EDU consumer: NotifyUserCounts failed") log.WithError(err).Error("userapi EDU consumer: NotifyUserCounts failed")
return false return false
} }

View file

@ -192,25 +192,25 @@ func (s *OutputRoomEventConsumer) storeMessageStats(ctx context.Context, eventTy
func (s *OutputRoomEventConsumer) handleRoomUpgrade(ctx context.Context, oldRoomID, newRoomID string, localMembers []*localMembership, roomSize int) error { func (s *OutputRoomEventConsumer) handleRoomUpgrade(ctx context.Context, oldRoomID, newRoomID string, localMembers []*localMembership, roomSize int) error {
for _, membership := range localMembers { for _, membership := range localMembers {
// Copy any existing push rules from old -> new room // Copy any existing push rules from old -> new room
if err := s.copyPushrules(ctx, oldRoomID, newRoomID, membership.Localpart); err != nil { if err := s.copyPushrules(ctx, oldRoomID, newRoomID, membership.Localpart, membership.Domain); err != nil {
return err return err
} }
// preserve m.direct room state // preserve m.direct room state
if err := s.updateMDirect(ctx, oldRoomID, newRoomID, membership.Localpart, roomSize); err != nil { if err := s.updateMDirect(ctx, oldRoomID, newRoomID, membership.Localpart, membership.Domain, roomSize); err != nil {
return err return err
} }
// copy existing m.tag entries, if any // copy existing m.tag entries, if any
if err := s.copyTags(ctx, oldRoomID, newRoomID, membership.Localpart); err != nil { if err := s.copyTags(ctx, oldRoomID, newRoomID, membership.Localpart, membership.Domain); err != nil {
return err return err
} }
} }
return nil return nil
} }
func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID, newRoomID string, localpart string) error { func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID, newRoomID string, localpart string, serverName gomatrixserverlib.ServerName) error {
pushRules, err := s.db.QueryPushRules(ctx, localpart) pushRules, err := s.db.QueryPushRules(ctx, localpart, serverName)
if err != nil { if err != nil {
return fmt.Errorf("failed to query pushrules for user: %w", err) return fmt.Errorf("failed to query pushrules for user: %w", err)
} }
@ -229,7 +229,7 @@ func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID,
if err != nil { if err != nil {
return err return err
} }
if err = s.db.SaveAccountData(ctx, localpart, "", "m.push_rules", rules); err != nil { if err = s.db.SaveAccountData(ctx, localpart, serverName, "", "m.push_rules", rules); err != nil {
return fmt.Errorf("failed to update pushrules: %w", err) return fmt.Errorf("failed to update pushrules: %w", err)
} }
} }
@ -237,13 +237,13 @@ func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID,
} }
// updateMDirect copies the "is_direct" flag from oldRoomID to newROomID // updateMDirect copies the "is_direct" flag from oldRoomID to newROomID
func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID, newRoomID, localpart string, roomSize int) error { func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID, newRoomID, localpart string, serverName gomatrixserverlib.ServerName, roomSize int) error {
// this is most likely not a DM, so skip updating m.direct state // this is most likely not a DM, so skip updating m.direct state
if roomSize > 2 { if roomSize > 2 {
return nil return nil
} }
// Get direct message state // Get direct message state
directChatsRaw, err := s.db.GetAccountDataByType(ctx, localpart, "", "m.direct") directChatsRaw, err := s.db.GetAccountDataByType(ctx, localpart, serverName, "", "m.direct")
if err != nil { if err != nil {
return fmt.Errorf("failed to get m.direct from database: %w", err) return fmt.Errorf("failed to get m.direct from database: %w", err)
} }
@ -267,7 +267,7 @@ func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID,
if err != nil { if err != nil {
return true return true
} }
if err = s.db.SaveAccountData(ctx, localpart, "", "m.direct", data); err != nil { if err = s.db.SaveAccountData(ctx, localpart, serverName, "", "m.direct", data); err != nil {
return true return true
} }
} }
@ -279,15 +279,15 @@ func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID,
return nil return nil
} }
func (s *OutputRoomEventConsumer) copyTags(ctx context.Context, oldRoomID, newRoomID, localpart string) error { func (s *OutputRoomEventConsumer) copyTags(ctx context.Context, oldRoomID, newRoomID, localpart string, serverName gomatrixserverlib.ServerName) error {
tag, err := s.db.GetAccountDataByType(ctx, localpart, oldRoomID, "m.tag") tag, err := s.db.GetAccountDataByType(ctx, localpart, serverName, oldRoomID, "m.tag")
if err != nil && !errors.Is(err, sql.ErrNoRows) { if err != nil && !errors.Is(err, sql.ErrNoRows) {
return err return err
} }
if tag == nil { if tag == nil {
return nil return nil
} }
return s.db.SaveAccountData(ctx, localpart, newRoomID, "m.tag", tag) return s.db.SaveAccountData(ctx, localpart, serverName, newRoomID, "m.tag", tag)
} }
func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, streamPos uint64) error { func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, streamPos uint64) error {
@ -492,11 +492,11 @@ func unmarshalCanonicalAlias(event *gomatrixserverlib.HeaderedEvent) (string, er
func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, mem *localMembership, roomSize int, roomName string, streamPos uint64) error { func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, mem *localMembership, roomSize int, roomName string, streamPos uint64) error {
actions, err := s.evaluatePushRules(ctx, event, mem, roomSize) actions, err := s.evaluatePushRules(ctx, event, mem, roomSize)
if err != nil { if err != nil {
return err return fmt.Errorf("s.evaluatePushRules: %w", err)
} }
a, tweaks, err := pushrules.ActionsToTweaks(actions) a, tweaks, err := pushrules.ActionsToTweaks(actions)
if err != nil { if err != nil {
return err return fmt.Errorf("pushrules.ActionsToTweaks: %w", err)
} }
// TODO: support coalescing. // TODO: support coalescing.
if a != pushrules.NotifyAction && a != pushrules.CoalesceAction { if a != pushrules.NotifyAction && a != pushrules.CoalesceAction {
@ -508,9 +508,9 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatr
return nil return nil
} }
devicesByURLAndFormat, profileTag, err := s.localPushDevices(ctx, mem.Localpart, tweaks) devicesByURLAndFormat, profileTag, err := s.localPushDevices(ctx, mem.Localpart, mem.Domain, tweaks)
if err != nil { if err != nil {
return err return fmt.Errorf("s.localPushDevices: %w", err)
} }
n := &api.Notification{ n := &api.Notification{
@ -527,18 +527,18 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatr
RoomID: event.RoomID(), RoomID: event.RoomID(),
TS: gomatrixserverlib.AsTimestamp(time.Now()), TS: gomatrixserverlib.AsTimestamp(time.Now()),
} }
if err = s.db.InsertNotification(ctx, mem.Localpart, event.EventID(), streamPos, tweaks, n); err != nil { if err = s.db.InsertNotification(ctx, mem.Localpart, mem.Domain, event.EventID(), streamPos, tweaks, n); err != nil {
return err return fmt.Errorf("s.db.InsertNotification: %w", err)
} }
if err = s.syncProducer.GetAndSendNotificationData(ctx, mem.UserID, event.RoomID()); err != nil { if err = s.syncProducer.GetAndSendNotificationData(ctx, mem.UserID, event.RoomID()); err != nil {
return err return fmt.Errorf("s.syncProducer.GetAndSendNotificationData: %w", err)
} }
// We do this after InsertNotification. Thus, this should always return >=1. // We do this after InsertNotification. Thus, this should always return >=1.
userNumUnreadNotifs, err := s.db.GetNotificationCount(ctx, mem.Localpart, tables.AllNotifications) userNumUnreadNotifs, err := s.db.GetNotificationCount(ctx, mem.Localpart, mem.Domain, tables.AllNotifications)
if err != nil { if err != nil {
return err return fmt.Errorf("s.db.GetNotificationCount: %w", err)
} }
log.WithFields(log.Fields{ log.WithFields(log.Fields{
@ -589,7 +589,7 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatr
} }
if len(rejected) > 0 { if len(rejected) > 0 {
s.deleteRejectedPushers(ctx, rejected, mem.Localpart) s.deleteRejectedPushers(ctx, rejected, mem.Localpart, mem.Domain)
} }
}() }()
@ -606,7 +606,7 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *
} }
// Get accountdata to check if the event.Sender() is ignored by mem.LocalPart // Get accountdata to check if the event.Sender() is ignored by mem.LocalPart
data, err := s.db.GetAccountDataByType(ctx, mem.Localpart, "", "m.ignored_user_list") data, err := s.db.GetAccountDataByType(ctx, mem.Localpart, mem.Domain, "", "m.ignored_user_list")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -621,7 +621,7 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *
return nil, fmt.Errorf("user %s is ignored", sender) return nil, fmt.Errorf("user %s is ignored", sender)
} }
} }
ruleSets, err := s.db.QueryPushRules(ctx, mem.Localpart) ruleSets, err := s.db.QueryPushRules(ctx, mem.Localpart, mem.Domain)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -693,10 +693,10 @@ func (rse *ruleSetEvalContext) HasPowerLevel(userID, levelKey string) (bool, err
// localPushDevices pushes to the configured devices of a local // localPushDevices pushes to the configured devices of a local
// user. The map keys are [url][format]. // user. The map keys are [url][format].
func (s *OutputRoomEventConsumer) localPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) { func (s *OutputRoomEventConsumer) localPushDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) {
pusherDevices, err := util.GetPushDevices(ctx, localpart, tweaks, s.db) pusherDevices, err := util.GetPushDevices(ctx, localpart, serverName, tweaks, s.db)
if err != nil { if err != nil {
return nil, "", err return nil, "", fmt.Errorf("util.GetPushDevices: %w", err)
} }
var profileTag string var profileTag string
@ -791,7 +791,7 @@ func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *gomatri
} }
// deleteRejectedPushers deletes the pushers associated with the given devices. // deleteRejectedPushers deletes the pushers associated with the given devices.
func (s *OutputRoomEventConsumer) deleteRejectedPushers(ctx context.Context, devices []*pushgateway.Device, localpart string) { func (s *OutputRoomEventConsumer) deleteRejectedPushers(ctx context.Context, devices []*pushgateway.Device, localpart string, serverName gomatrixserverlib.ServerName) {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"localpart": localpart, "localpart": localpart,
"app_id0": devices[0].AppID, "app_id0": devices[0].AppID,
@ -799,7 +799,7 @@ func (s *OutputRoomEventConsumer) deleteRejectedPushers(ctx context.Context, dev
}).Warnf("Deleting pushers rejected by the HTTP push gateway") }).Warnf("Deleting pushers rejected by the HTTP push gateway")
for _, d := range devices { for _, d := range devices {
if err := s.db.RemovePusher(ctx, d.AppID, d.PushKey, localpart); err != nil { if err := s.db.RemovePusher(ctx, d.AppID, d.PushKey, localpart, serverName); err != nil {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"localpart": localpart, "localpart": localpart,
}).WithError(err).Errorf("Unable to delete rejected pusher") }).WithError(err).Errorf("Unable to delete rejected pusher")

View file

@ -68,7 +68,7 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc
if req.DataType == "" { if req.DataType == "" {
return fmt.Errorf("data type must not be empty") return fmt.Errorf("data type must not be empty")
} }
if err := a.DB.SaveAccountData(ctx, local, req.RoomID, req.DataType, req.AccountData); err != nil { if err := a.DB.SaveAccountData(ctx, local, domain, req.RoomID, req.DataType, req.AccountData); err != nil {
util.GetLogger(ctx).WithError(err).Error("a.DB.SaveAccountData failed") util.GetLogger(ctx).WithError(err).Error("a.DB.SaveAccountData failed")
return fmt.Errorf("failed to save account data: %w", err) return fmt.Errorf("failed to save account data: %w", err)
} }
@ -108,7 +108,7 @@ func (a *UserInternalAPI) setFullyRead(ctx context.Context, req *api.InputAccoun
return nil return nil
} }
deleted, err := a.DB.DeleteNotificationsUpTo(ctx, localpart, req.RoomID, uint64(gomatrixserverlib.AsTimestamp(time.Now()))) deleted, err := a.DB.DeleteNotificationsUpTo(ctx, localpart, domain, req.RoomID, uint64(gomatrixserverlib.AsTimestamp(time.Now())))
if err != nil { if err != nil {
logrus.WithError(err).Errorf("UserInternalAPI.setFullyRead: DeleteNotificationsUpTo failed") logrus.WithError(err).Errorf("UserInternalAPI.setFullyRead: DeleteNotificationsUpTo failed")
return err return err
@ -124,7 +124,7 @@ func (a *UserInternalAPI) setFullyRead(ctx context.Context, req *api.InputAccoun
return nil return nil
} }
if err = userapiUtil.NotifyUserCountsAsync(ctx, a.PgClient, localpart, a.DB); err != nil { if err = userapiUtil.NotifyUserCountsAsync(ctx, a.PgClient, localpart, domain, a.DB); err != nil {
logrus.WithError(err).Error("UserInternalAPI.setFullyRead: NotifyUserCounts failed") logrus.WithError(err).Error("UserInternalAPI.setFullyRead: NotifyUserCounts failed")
return err return err
} }
@ -175,8 +175,10 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
if serverName == "" { if serverName == "" {
serverName = a.Config.Matrix.ServerName serverName = a.Config.Matrix.ServerName
} }
// XXXX: Use the server name here if !a.Config.Matrix.IsLocalServerName(serverName) {
acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType) return fmt.Errorf("server name %s is not local", serverName)
}
acc, err := a.DB.CreateAccount(ctx, req.Localpart, serverName, req.Password, req.AppServiceID, req.AccountType)
if err != nil { if err != nil {
if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists
switch req.OnConflict { switch req.OnConflict {
@ -215,8 +217,8 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
return nil return nil
} }
if _, _, err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil { if _, _, err = a.DB.SetDisplayName(ctx, req.Localpart, serverName, req.Localpart); err != nil {
return err return fmt.Errorf("a.DB.SetDisplayName: %w", err)
} }
postRegisterJoinRooms(a.Cfg, acc, a.RSAPI) postRegisterJoinRooms(a.Cfg, acc, a.RSAPI)
@ -227,11 +229,14 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
} }
func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error { func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error {
if err := a.DB.SetPassword(ctx, req.Localpart, req.Password); err != nil { if !a.Config.Matrix.IsLocalServerName(req.ServerName) {
return fmt.Errorf("server name %s is not local", req.ServerName)
}
if err := a.DB.SetPassword(ctx, req.Localpart, req.ServerName, req.Password); err != nil {
return err return err
} }
if req.LogoutDevices { if req.LogoutDevices {
if _, err := a.DB.RemoveAllDevices(context.Background(), req.Localpart, ""); err != nil { if _, err := a.DB.RemoveAllDevices(context.Background(), req.Localpart, req.ServerName, ""); err != nil {
return err return err
} }
} }
@ -244,14 +249,15 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe
if serverName == "" { if serverName == "" {
serverName = a.Config.Matrix.ServerName serverName = a.Config.Matrix.ServerName
} }
_ = serverName if !a.Config.Matrix.IsLocalServerName(serverName) {
// XXXX: Use the server name here return fmt.Errorf("server name %s is not local", serverName)
}
util.GetLogger(ctx).WithFields(logrus.Fields{ util.GetLogger(ctx).WithFields(logrus.Fields{
"localpart": req.Localpart, "localpart": req.Localpart,
"device_id": req.DeviceID, "device_id": req.DeviceID,
"display_name": req.DeviceDisplayName, "display_name": req.DeviceDisplayName,
}).Info("PerformDeviceCreation") }).Info("PerformDeviceCreation")
dev, err := a.DB.CreateDevice(ctx, req.Localpart, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent) dev, err := a.DB.CreateDevice(ctx, req.Localpart, serverName, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent)
if err != nil { if err != nil {
return err return err
} }
@ -276,12 +282,12 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe
deletedDeviceIDs := req.DeviceIDs deletedDeviceIDs := req.DeviceIDs
if len(req.DeviceIDs) == 0 { if len(req.DeviceIDs) == 0 {
var devices []api.Device var devices []api.Device
devices, err = a.DB.RemoveAllDevices(ctx, local, req.ExceptDeviceID) devices, err = a.DB.RemoveAllDevices(ctx, local, domain, req.ExceptDeviceID)
for _, d := range devices { for _, d := range devices {
deletedDeviceIDs = append(deletedDeviceIDs, d.ID) deletedDeviceIDs = append(deletedDeviceIDs, d.ID)
} }
} else { } else {
err = a.DB.RemoveDevices(ctx, local, req.DeviceIDs) err = a.DB.RemoveDevices(ctx, local, domain, req.DeviceIDs)
} }
if err != nil { if err != nil {
return err return err
@ -335,23 +341,29 @@ func (a *UserInternalAPI) PerformLastSeenUpdate(
req *api.PerformLastSeenUpdateRequest, req *api.PerformLastSeenUpdateRequest,
res *api.PerformLastSeenUpdateResponse, res *api.PerformLastSeenUpdateResponse,
) error { ) error {
localpart, _, err := gomatrixserverlib.SplitID('@', req.UserID) localpart, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
if err != nil { if err != nil {
return fmt.Errorf("gomatrixserverlib.SplitID: %w", err) return fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
} }
if err := a.DB.UpdateDeviceLastSeen(ctx, localpart, req.DeviceID, req.RemoteAddr, req.UserAgent); err != nil { if !a.Config.Matrix.IsLocalServerName(domain) {
return fmt.Errorf("server name %s is not local", domain)
}
if err := a.DB.UpdateDeviceLastSeen(ctx, localpart, domain, req.DeviceID, req.RemoteAddr, req.UserAgent); err != nil {
return fmt.Errorf("a.DeviceDB.UpdateDeviceLastSeen: %w", err) return fmt.Errorf("a.DeviceDB.UpdateDeviceLastSeen: %w", err)
} }
return nil return nil
} }
func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.PerformDeviceUpdateRequest, res *api.PerformDeviceUpdateResponse) error { func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.PerformDeviceUpdateRequest, res *api.PerformDeviceUpdateResponse) error {
localpart, _, err := gomatrixserverlib.SplitID('@', req.RequestingUserID) localpart, domain, err := gomatrixserverlib.SplitID('@', req.RequestingUserID)
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed") util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed")
return err return err
} }
dev, err := a.DB.GetDeviceByID(ctx, localpart, req.DeviceID) if !a.Config.Matrix.IsLocalServerName(domain) {
return fmt.Errorf("server name %s is not local", domain)
}
dev, err := a.DB.GetDeviceByID(ctx, localpart, domain, req.DeviceID)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
res.DeviceExists = false res.DeviceExists = false
return nil return nil
@ -366,7 +378,7 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
return nil return nil
} }
err = a.DB.UpdateDevice(ctx, localpart, req.DeviceID, req.DisplayName) err = a.DB.UpdateDevice(ctx, localpart, domain, req.DeviceID, req.DisplayName)
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed") util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed")
return err return err
@ -406,7 +418,7 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil
if !a.Config.Matrix.IsLocalServerName(domain) { if !a.Config.Matrix.IsLocalServerName(domain) {
return fmt.Errorf("cannot query profile of remote users (server name %s)", domain) return fmt.Errorf("cannot query profile of remote users (server name %s)", domain)
} }
prof, err := a.DB.GetProfileByLocalpart(ctx, local) prof, err := a.DB.GetProfileByLocalpart(ctx, local, domain)
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil return nil
@ -457,7 +469,7 @@ func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevice
if !a.Config.Matrix.IsLocalServerName(domain) { if !a.Config.Matrix.IsLocalServerName(domain) {
return fmt.Errorf("cannot query devices of remote users (server name %s)", domain) return fmt.Errorf("cannot query devices of remote users (server name %s)", domain)
} }
devs, err := a.DB.GetDevicesByLocalpart(ctx, local) devs, err := a.DB.GetDevicesByLocalpart(ctx, local, domain)
if err != nil { if err != nil {
return err return err
} }
@ -476,7 +488,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc
} }
if req.DataType != "" { if req.DataType != "" {
var data json.RawMessage var data json.RawMessage
data, err = a.DB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType) data, err = a.DB.GetAccountDataByType(ctx, local, domain, req.RoomID, req.DataType)
if err != nil { if err != nil {
return err return err
} }
@ -494,7 +506,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc
} }
return nil return nil
} }
global, rooms, err := a.DB.GetAccountData(ctx, local) global, rooms, err := a.DB.GetAccountData(ctx, local, domain)
if err != nil { if err != nil {
return err return err
} }
@ -527,7 +539,7 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc
if !a.Config.Matrix.IsLocalServerName(domain) { if !a.Config.Matrix.IsLocalServerName(domain) {
return nil return nil
} }
acc, err := a.DB.GetAccountByLocalpart(ctx, localPart) acc, err := a.DB.GetAccountByLocalpart(ctx, localPart, domain)
if err != nil { if err != nil {
return err return err
} }
@ -561,14 +573,14 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe
AccountType: api.AccountTypeAppService, AccountType: api.AccountTypeAppService,
} }
localpart, _, err := userutil.ParseUsernameParam(appServiceUserID, a.Config.Matrix) localpart, domain, err := userutil.ParseUsernameParam(appServiceUserID, a.Config.Matrix)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if localpart != "" { // AS is masquerading as another user if localpart != "" { // AS is masquerading as another user
// Verify that the user is registered // Verify that the user is registered
account, err := a.DB.GetAccountByLocalpart(ctx, localpart) account, err := a.DB.GetAccountByLocalpart(ctx, localpart, domain)
// Verify that the account exists and either appServiceID matches or // Verify that the account exists and either appServiceID matches or
// it belongs to the appservice user namespaces // it belongs to the appservice user namespaces
if err == nil && (account.AppServiceID == appService.ID || appService.IsInterestedInUserID(appServiceUserID)) { if err == nil && (account.AppServiceID == appService.ID || appService.IsInterestedInUserID(appServiceUserID)) {
@ -620,7 +632,7 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a
return err return err
} }
err := a.DB.DeactivateAccount(ctx, req.Localpart) err := a.DB.DeactivateAccount(ctx, req.Localpart, serverName)
res.AccountDeactivated = err == nil res.AccountDeactivated = err == nil
return err return err
} }
@ -783,7 +795,7 @@ func (a *UserInternalAPI) QueryNotifications(ctx context.Context, req *api.Query
if req.Only == "highlight" { if req.Only == "highlight" {
filter = tables.HighlightNotifications filter = tables.HighlightNotifications
} }
notifs, lastID, err := a.DB.GetNotifications(ctx, req.Localpart, fromID, req.Limit, filter) notifs, lastID, err := a.DB.GetNotifications(ctx, req.Localpart, req.ServerName, fromID, req.Limit, filter)
if err != nil { if err != nil {
return err return err
} }
@ -811,23 +823,23 @@ func (a *UserInternalAPI) PerformPusherSet(ctx context.Context, req *api.Perform
} }
} }
if req.Pusher.Kind == "" { if req.Pusher.Kind == "" {
return a.DB.RemovePusher(ctx, req.Pusher.AppID, req.Pusher.PushKey, req.Localpart) return a.DB.RemovePusher(ctx, req.Pusher.AppID, req.Pusher.PushKey, req.Localpart, req.ServerName)
} }
if req.Pusher.PushKeyTS == 0 { if req.Pusher.PushKeyTS == 0 {
req.Pusher.PushKeyTS = int64(time.Now().Unix()) req.Pusher.PushKeyTS = int64(time.Now().Unix())
} }
return a.DB.UpsertPusher(ctx, req.Pusher, req.Localpart) return a.DB.UpsertPusher(ctx, req.Pusher, req.Localpart, req.ServerName)
} }
func (a *UserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.PerformPusherDeletionRequest, res *struct{}) error { func (a *UserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.PerformPusherDeletionRequest, res *struct{}) error {
pushers, err := a.DB.GetPushers(ctx, req.Localpart) pushers, err := a.DB.GetPushers(ctx, req.Localpart, req.ServerName)
if err != nil { if err != nil {
return err return err
} }
for i := range pushers { for i := range pushers {
logrus.Warnf("pusher session: %d, req session: %d", pushers[i].SessionID, req.SessionID) logrus.Warnf("pusher session: %d, req session: %d", pushers[i].SessionID, req.SessionID)
if pushers[i].SessionID != req.SessionID { if pushers[i].SessionID != req.SessionID {
err := a.DB.RemovePusher(ctx, pushers[i].AppID, pushers[i].PushKey, req.Localpart) err := a.DB.RemovePusher(ctx, pushers[i].AppID, pushers[i].PushKey, req.Localpart, req.ServerName)
if err != nil { if err != nil {
return err return err
} }
@ -838,7 +850,7 @@ func (a *UserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.Pe
func (a *UserInternalAPI) QueryPushers(ctx context.Context, req *api.QueryPushersRequest, res *api.QueryPushersResponse) error { func (a *UserInternalAPI) QueryPushers(ctx context.Context, req *api.QueryPushersRequest, res *api.QueryPushersResponse) error {
var err error var err error
res.Pushers, err = a.DB.GetPushers(ctx, req.Localpart) res.Pushers, err = a.DB.GetPushers(ctx, req.Localpart, req.ServerName)
return err return err
} }
@ -864,11 +876,11 @@ func (a *UserInternalAPI) PerformPushRulesPut(
} }
func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPushRulesRequest, res *api.QueryPushRulesResponse) error { func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPushRulesRequest, res *api.QueryPushRulesResponse) error {
localpart, _, err := gomatrixserverlib.SplitID('@', req.UserID) localpart, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
if err != nil { if err != nil {
return fmt.Errorf("failed to split user ID %q for push rules", req.UserID) return fmt.Errorf("failed to split user ID %q for push rules", req.UserID)
} }
pushRules, err := a.DB.QueryPushRules(ctx, localpart) pushRules, err := a.DB.QueryPushRules(ctx, localpart, domain)
if err != nil { if err != nil {
return fmt.Errorf("failed to query push rules: %w", err) return fmt.Errorf("failed to query push rules: %w", err)
} }
@ -877,14 +889,14 @@ func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPush
} }
func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, req *api.PerformSetAvatarURLRequest, res *api.PerformSetAvatarURLResponse) error { func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, req *api.PerformSetAvatarURLRequest, res *api.PerformSetAvatarURLResponse) error {
profile, changed, err := a.DB.SetAvatarURL(ctx, req.Localpart, req.AvatarURL) profile, changed, err := a.DB.SetAvatarURL(ctx, req.Localpart, req.ServerName, req.AvatarURL)
res.Profile = profile res.Profile = profile
res.Changed = changed res.Changed = changed
return err return err
} }
func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.QueryNumericLocalpartResponse) error { func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, req *api.QueryNumericLocalpartRequest, res *api.QueryNumericLocalpartResponse) error {
id, err := a.DB.GetNewNumericLocalpart(ctx) id, err := a.DB.GetNewNumericLocalpart(ctx, req.ServerName)
if err != nil { if err != nil {
return err return err
} }
@ -894,12 +906,12 @@ func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.Qu
func (a *UserInternalAPI) QueryAccountAvailability(ctx context.Context, req *api.QueryAccountAvailabilityRequest, res *api.QueryAccountAvailabilityResponse) error { func (a *UserInternalAPI) QueryAccountAvailability(ctx context.Context, req *api.QueryAccountAvailabilityRequest, res *api.QueryAccountAvailabilityResponse) error {
var err error var err error
res.Available, err = a.DB.CheckAccountAvailability(ctx, req.Localpart) res.Available, err = a.DB.CheckAccountAvailability(ctx, req.Localpart, req.ServerName)
return err return err
} }
func (a *UserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.QueryAccountByPasswordRequest, res *api.QueryAccountByPasswordResponse) error { func (a *UserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.QueryAccountByPasswordRequest, res *api.QueryAccountByPasswordResponse) error {
acc, err := a.DB.GetAccountByPassword(ctx, req.Localpart, req.PlaintextPassword) acc, err := a.DB.GetAccountByPassword(ctx, req.Localpart, req.ServerName, req.PlaintextPassword)
switch err { switch err {
case sql.ErrNoRows: // user does not exist case sql.ErrNoRows: // user does not exist
return nil return nil
@ -915,23 +927,24 @@ func (a *UserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.Q
} }
func (a *UserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, res *api.PerformUpdateDisplayNameResponse) error { func (a *UserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, res *api.PerformUpdateDisplayNameResponse) error {
profile, changed, err := a.DB.SetDisplayName(ctx, req.Localpart, req.DisplayName) profile, changed, err := a.DB.SetDisplayName(ctx, req.Localpart, req.ServerName, req.DisplayName)
res.Profile = profile res.Profile = profile
res.Changed = changed res.Changed = changed
return err return err
} }
func (a *UserInternalAPI) QueryLocalpartForThreePID(ctx context.Context, req *api.QueryLocalpartForThreePIDRequest, res *api.QueryLocalpartForThreePIDResponse) error { func (a *UserInternalAPI) QueryLocalpartForThreePID(ctx context.Context, req *api.QueryLocalpartForThreePIDRequest, res *api.QueryLocalpartForThreePIDResponse) error {
localpart, err := a.DB.GetLocalpartForThreePID(ctx, req.ThreePID, req.Medium) localpart, domain, err := a.DB.GetLocalpartForThreePID(ctx, req.ThreePID, req.Medium)
if err != nil { if err != nil {
return err return err
} }
res.Localpart = localpart res.Localpart = localpart
res.ServerName = domain
return nil return nil
} }
func (a *UserInternalAPI) QueryThreePIDsForLocalpart(ctx context.Context, req *api.QueryThreePIDsForLocalpartRequest, res *api.QueryThreePIDsForLocalpartResponse) error { func (a *UserInternalAPI) QueryThreePIDsForLocalpart(ctx context.Context, req *api.QueryThreePIDsForLocalpartRequest, res *api.QueryThreePIDsForLocalpartResponse) error {
r, err := a.DB.GetThreePIDsForLocalpart(ctx, req.Localpart) r, err := a.DB.GetThreePIDsForLocalpart(ctx, req.Localpart, req.ServerName)
if err != nil { if err != nil {
return err return err
} }
@ -944,7 +957,7 @@ func (a *UserInternalAPI) PerformForgetThreePID(ctx context.Context, req *api.Pe
} }
func (a *UserInternalAPI) PerformSaveThreePIDAssociation(ctx context.Context, req *api.PerformSaveThreePIDAssociationRequest, res *struct{}) error { func (a *UserInternalAPI) PerformSaveThreePIDAssociation(ctx context.Context, req *api.PerformSaveThreePIDAssociationRequest, res *struct{}) error {
return a.DB.SaveThreePIDAssociation(ctx, req.ThreePID, req.Localpart, req.Medium) return a.DB.SaveThreePIDAssociation(ctx, req.ThreePID, req.Localpart, req.ServerName, req.Medium)
} }
const pushRulesAccountDataType = "m.push_rules" const pushRulesAccountDataType = "m.push_rules"

View file

@ -66,7 +66,7 @@ func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLog
if !a.Config.Matrix.IsLocalServerName(domain) { if !a.Config.Matrix.IsLocalServerName(domain) {
return fmt.Errorf("cannot return a login token for a remote user (server name %s)", domain) return fmt.Errorf("cannot return a login token for a remote user (server name %s)", domain)
} }
if _, err := a.DB.GetAccountByLocalpart(ctx, localpart); err != nil { if _, err := a.DB.GetAccountByLocalpart(ctx, localpart, domain); err != nil {
res.Data = nil res.Data = nil
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil return nil

View file

@ -355,11 +355,12 @@ func (h *httpUserInternalAPI) SetAvatarURL(
func (h *httpUserInternalAPI) QueryNumericLocalpart( func (h *httpUserInternalAPI) QueryNumericLocalpart(
ctx context.Context, ctx context.Context,
request *api.QueryNumericLocalpartRequest,
response *api.QueryNumericLocalpartResponse, response *api.QueryNumericLocalpartResponse,
) error { ) error {
return httputil.CallInternalRPCAPI( return httputil.CallInternalRPCAPI(
"QueryNumericLocalpart", h.apiURL+QueryNumericLocalpartPath, "QueryNumericLocalpart", h.apiURL+QueryNumericLocalpartPath,
h.httpClient, ctx, &struct{}{}, response, h.httpClient, ctx, request, response,
) )
} }

View file

@ -15,12 +15,9 @@
package inthttp package inthttp
import ( import (
"net/http"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/util"
) )
// nolint: gocyclo // nolint: gocyclo
@ -152,15 +149,9 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
httputil.MakeInternalRPCAPI("UserAPIPerformSetAvatarURL", s.SetAvatarURL), httputil.MakeInternalRPCAPI("UserAPIPerformSetAvatarURL", s.SetAvatarURL),
) )
// TODO: Look at the shape of this internalAPIMux.Handle(
internalAPIMux.Handle(QueryNumericLocalpartPath, QueryNumericLocalpartPath,
httputil.MakeInternalAPI("UserAPIQueryNumericLocalpart", func(req *http.Request) util.JSONResponse { httputil.MakeInternalRPCAPI("UserAPIQueryNumericLocalpart", s.QueryNumericLocalpart),
response := api.QueryNumericLocalpartResponse{}
if err := s.QueryNumericLocalpart(req.Context(), &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
) )
internalAPIMux.Handle( internalAPIMux.Handle(

View file

@ -61,12 +61,12 @@ func (p *SyncAPI) SendAccountData(userID string, data eventutil.AccountData) err
// GetAndSendNotificationData reads the database and sends data about unread // GetAndSendNotificationData reads the database and sends data about unread
// notifications to the Sync API server. // notifications to the Sync API server.
func (p *SyncAPI) GetAndSendNotificationData(ctx context.Context, userID, roomID string) error { func (p *SyncAPI) GetAndSendNotificationData(ctx context.Context, userID, roomID string) error {
localpart, _, err := gomatrixserverlib.SplitID('@', userID) localpart, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil { if err != nil {
return err return err
} }
ntotal, nhighlight, err := p.db.GetRoomNotificationCounts(ctx, localpart, roomID) ntotal, nhighlight, err := p.db.GetRoomNotificationCounts(ctx, localpart, domain, roomID)
if err != nil { if err != nil {
return err return err
} }

View file

@ -29,40 +29,40 @@ import (
) )
type Profile interface { type Profile interface {
GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error) GetProfileByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*authtypes.Profile, error)
SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error) SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
SetAvatarURL(ctx context.Context, localpart string, avatarURL string) (*authtypes.Profile, bool, error) SetAvatarURL(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, avatarURL string) (*authtypes.Profile, bool, error)
SetDisplayName(ctx context.Context, localpart string, displayName string) (*authtypes.Profile, bool, error) SetDisplayName(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, displayName string) (*authtypes.Profile, bool, error)
} }
type Account interface { type Account interface {
// CreateAccount makes a new account with the given login name and password, and creates an empty profile // CreateAccount makes a new account with the given login name and password, and creates an empty profile
// for this account. If no password is supplied, the account will be a passwordless account. If the // for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, ErrUserExists. // account already exists, it will return nil, ErrUserExists.
CreateAccount(ctx context.Context, localpart string, plaintextPassword string, appserviceID string, accountType api.AccountType) (*api.Account, error) CreateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, plaintextPassword string, appserviceID string, accountType api.AccountType) (*api.Account, error)
GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error) GetAccountByPassword(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, plaintextPassword string) (*api.Account, error)
GetNewNumericLocalpart(ctx context.Context) (int64, error) GetNewNumericLocalpart(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error)
CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) CheckAccountAvailability(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (bool, error)
GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error) GetAccountByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*api.Account, error)
DeactivateAccount(ctx context.Context, localpart string) (err error) DeactivateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (err error)
SetPassword(ctx context.Context, localpart string, plaintextPassword string) error SetPassword(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, plaintextPassword string) error
} }
type AccountData interface { type AccountData interface {
SaveAccountData(ctx context.Context, localpart, roomID, dataType string, content json.RawMessage) error SaveAccountData(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID, dataType string, content json.RawMessage) error
GetAccountData(ctx context.Context, localpart string) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error) GetAccountData(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error)
// GetAccountDataByType returns account data matching a given // GetAccountDataByType returns account data matching a given
// localpart, room ID and type. // localpart, room ID and type.
// If no account data could be found, returns nil // If no account data could be found, returns nil
// Returns an error if there was an issue with the retrieval // Returns an error if there was an issue with the retrieval
GetAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data json.RawMessage, err error) GetAccountDataByType(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID, dataType string) (data json.RawMessage, err error)
QueryPushRules(ctx context.Context, localpart string) (*pushrules.AccountRuleSets, error) QueryPushRules(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*pushrules.AccountRuleSets, error)
} }
type Device interface { type Device interface {
GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error) GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error)
GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error) GetDeviceByID(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID string) (*api.Device, error)
GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error) GetDevicesByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) ([]api.Device, error)
GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error)
// CreateDevice makes a new device associated with the given user ID localpart. // CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked // If there is already a device with the same device ID for this user, that access token will be revoked
@ -70,12 +70,12 @@ type Device interface {
// an error will be returned. // an error will be returned.
// If no device ID is given one is generated. // If no device ID is given one is generated.
// Returns the device on success. // Returns the device on success.
CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error) CreateDevice(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error)
UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error UpdateDevice(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID string, displayName *string) error
UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr, userAgent string) error UpdateDeviceLastSeen(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error
RemoveDevices(ctx context.Context, localpart string, devices []string) error RemoveDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, devices []string) error
// RemoveAllDevices deleted all devices for this user. Returns the devices deleted. // RemoveAllDevices deleted all devices for this user. Returns the devices deleted.
RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error) RemoveAllDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, exceptDeviceID string) (devices []api.Device, err error)
} }
type KeyBackup interface { type KeyBackup interface {
@ -107,26 +107,26 @@ type OpenID interface {
} }
type Pusher interface { type Pusher interface {
UpsertPusher(ctx context.Context, p api.Pusher, localpart string) error UpsertPusher(ctx context.Context, p api.Pusher, localpart string, serverName gomatrixserverlib.ServerName) error
GetPushers(ctx context.Context, localpart string) ([]api.Pusher, error) GetPushers(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) ([]api.Pusher, error)
RemovePusher(ctx context.Context, appid, pushkey, localpart string) error RemovePusher(ctx context.Context, appid, pushkey, localpart string, serverName gomatrixserverlib.ServerName) error
RemovePushers(ctx context.Context, appid, pushkey string) error RemovePushers(ctx context.Context, appid, pushkey string) error
} }
type ThreePID interface { type ThreePID interface {
SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error) SaveThreePIDAssociation(ctx context.Context, threepid, localpart string, serverName gomatrixserverlib.ServerName, medium string) (err error)
RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error) RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error)
GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, err error) GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, serverName gomatrixserverlib.ServerName, err error)
GetThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error) GetThreePIDsForLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (threepids []authtypes.ThreePID, err error)
} }
type Notification interface { type Notification interface {
InsertNotification(ctx context.Context, localpart, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error InsertNotification(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error
DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos uint64) (affected bool, err error) DeleteNotificationsUpTo(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, err error)
SetNotificationsRead(ctx context.Context, localpart, roomID string, pos uint64, read bool) (affected bool, err error) SetNotificationsRead(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, read bool) (affected bool, err error)
GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) GetNotifications(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error)
GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error) GetNotificationCount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, filter tables.NotificationFilter) (int64, error)
GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error) GetRoomNotificationCounts(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, _ error)
DeleteOldNotifications(ctx context.Context) error DeleteOldNotifications(ctx context.Context) error
} }

View file

@ -22,6 +22,7 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
) )
const accountDataSchema = ` const accountDataSchema = `
@ -29,27 +30,28 @@ const accountDataSchema = `
CREATE TABLE IF NOT EXISTS userapi_account_datas ( CREATE TABLE IF NOT EXISTS userapi_account_datas (
-- The Matrix user ID localpart for this account -- The Matrix user ID localpart for this account
localpart TEXT NOT NULL, localpart TEXT NOT NULL,
server_name TEXT NOT NULL,
-- The room ID for this data (empty string if not specific to a room) -- The room ID for this data (empty string if not specific to a room)
room_id TEXT, room_id TEXT,
-- The account data type -- The account data type
type TEXT NOT NULL, type TEXT NOT NULL,
-- The account data content -- The account data content
content TEXT NOT NULL, content TEXT NOT NULL
PRIMARY KEY(localpart, room_id, type)
); );
CREATE UNIQUE INDEX IF NOT EXISTS userapi_account_datas_idx ON userapi_account_datas(localpart, server_name, room_id, type);
` `
const insertAccountDataSQL = ` const insertAccountDataSQL = `
INSERT INTO userapi_account_datas(localpart, room_id, type, content) VALUES($1, $2, $3, $4) INSERT INTO userapi_account_datas(localpart, server_name, room_id, type, content) VALUES($1, $2, $3, $4, $5)
ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = EXCLUDED.content ON CONFLICT (localpart, server_name, room_id, type) DO UPDATE SET content = EXCLUDED.content
` `
const selectAccountDataSQL = "" + const selectAccountDataSQL = "" +
"SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1" "SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1 AND server_name = $2"
const selectAccountDataByTypeSQL = "" + const selectAccountDataByTypeSQL = "" +
"SELECT content FROM userapi_account_datas WHERE localpart = $1 AND room_id = $2 AND type = $3" "SELECT content FROM userapi_account_datas WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND type = $4"
type accountDataStatements struct { type accountDataStatements struct {
insertAccountDataStmt *sql.Stmt insertAccountDataStmt *sql.Stmt
@ -71,21 +73,24 @@ func NewPostgresAccountDataTable(db *sql.DB) (tables.AccountDataTable, error) {
} }
func (s *accountDataStatements) InsertAccountData( func (s *accountDataStatements) InsertAccountData(
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage, ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
roomID, dataType string, content json.RawMessage,
) (err error) { ) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertAccountDataStmt) stmt := sqlutil.TxStmt(txn, s.insertAccountDataStmt)
_, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content) _, err = stmt.ExecContext(ctx, localpart, serverName, roomID, dataType, content)
return return
} }
func (s *accountDataStatements) SelectAccountData( func (s *accountDataStatements) SelectAccountData(
ctx context.Context, localpart string, ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
) ( ) (
/* global */ map[string]json.RawMessage, /* global */ map[string]json.RawMessage,
/* rooms */ map[string]map[string]json.RawMessage, /* rooms */ map[string]map[string]json.RawMessage,
error, error,
) { ) {
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart) rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart, serverName)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -117,11 +122,13 @@ func (s *accountDataStatements) SelectAccountData(
} }
func (s *accountDataStatements) SelectAccountDataByType( func (s *accountDataStatements) SelectAccountDataByType(
ctx context.Context, localpart, roomID, dataType string, ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
roomID, dataType string,
) (data json.RawMessage, err error) { ) (data json.RawMessage, err error) {
var bytes []byte var bytes []byte
stmt := s.selectAccountDataByTypeStmt stmt := s.selectAccountDataByTypeStmt
if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&bytes); err != nil { if err = stmt.QueryRowContext(ctx, localpart, serverName, roomID, dataType).Scan(&bytes); err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, nil return nil, nil
} }

View file

@ -17,6 +17,7 @@ package postgres
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt"
"time" "time"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -34,7 +35,8 @@ const accountsSchema = `
-- Stores data about accounts. -- Stores data about accounts.
CREATE TABLE IF NOT EXISTS userapi_accounts ( CREATE TABLE IF NOT EXISTS userapi_accounts (
-- The Matrix user ID localpart for this account -- The Matrix user ID localpart for this account
localpart TEXT NOT NULL PRIMARY KEY, localpart TEXT NOT NULL,
server_name TEXT NOT NULL,
-- When this account was first created, as a unix timestamp (ms resolution). -- When this account was first created, as a unix timestamp (ms resolution).
created_ts BIGINT NOT NULL, created_ts BIGINT NOT NULL,
-- The password hash for this account. Can be NULL if this is a passwordless account. -- The password hash for this account. Can be NULL if this is a passwordless account.
@ -48,25 +50,27 @@ CREATE TABLE IF NOT EXISTS userapi_accounts (
-- TODO: -- TODO:
-- upgraded_ts, devices, any email reset stuff? -- upgraded_ts, devices, any email reset stuff?
); );
CREATE UNIQUE INDEX IF NOT EXISTS userapi_accounts_idx ON userapi_accounts(localpart, server_name);
` `
const insertAccountSQL = "" + const insertAccountSQL = "" +
"INSERT INTO userapi_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)" "INSERT INTO userapi_accounts(localpart, server_name, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5, $6)"
const updatePasswordSQL = "" + const updatePasswordSQL = "" +
"UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2" "UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2 AND server_name = $3"
const deactivateAccountSQL = "" + const deactivateAccountSQL = "" +
"UPDATE userapi_accounts SET is_deactivated = TRUE WHERE localpart = $1" "UPDATE userapi_accounts SET is_deactivated = TRUE WHERE localpart = $1 AND server_name = $2"
const selectAccountByLocalpartSQL = "" + const selectAccountByLocalpartSQL = "" +
"SELECT localpart, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1" "SELECT localpart, server_name, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1 AND server_name = $2"
const selectPasswordHashSQL = "" + const selectPasswordHashSQL = "" +
"SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND is_deactivated = FALSE" "SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND server_name = $2 AND is_deactivated = FALSE"
const selectNewNumericLocalpartSQL = "" + const selectNewNumericLocalpartSQL = "" +
"SELECT COALESCE(MAX(localpart::bigint), 0) FROM userapi_accounts WHERE localpart ~ '^[0-9]{1,}$'" "SELECT COALESCE(MAX(localpart::bigint), 0) FROM userapi_accounts WHERE localpart ~ '^[0-9]{1,}$' AND server_name = $1"
type accountsStatements struct { type accountsStatements struct {
insertAccountStmt *sql.Stmt insertAccountStmt *sql.Stmt
@ -117,59 +121,62 @@ func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerNam
// this account will be passwordless. Returns an error if this account already exists. Returns the account // this account will be passwordless. Returns an error if this account already exists. Returns the account
// on success. // on success.
func (s *accountsStatements) InsertAccount( func (s *accountsStatements) InsertAccount(
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType, ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
hash, appserviceID string, accountType api.AccountType,
) (*api.Account, error) { ) (*api.Account, error) {
createdTimeMS := time.Now().UnixNano() / 1000000 createdTimeMS := time.Now().UnixNano() / 1000000
stmt := sqlutil.TxStmt(txn, s.insertAccountStmt) stmt := sqlutil.TxStmt(txn, s.insertAccountStmt)
var err error var err error
if accountType != api.AccountTypeAppService { if accountType != api.AccountTypeAppService {
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType) _, err = stmt.ExecContext(ctx, localpart, serverName, createdTimeMS, hash, nil, accountType)
} else { } else {
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType) _, err = stmt.ExecContext(ctx, localpart, serverName, createdTimeMS, hash, appserviceID, accountType)
} }
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("insertAccountStmt: %w", err)
} }
return &api.Account{ return &api.Account{
Localpart: localpart, Localpart: localpart,
UserID: userutil.MakeUserID(localpart, s.serverName), UserID: userutil.MakeUserID(localpart, serverName),
ServerName: s.serverName, ServerName: serverName,
AppServiceID: appserviceID, AppServiceID: appserviceID,
AccountType: accountType, AccountType: accountType,
}, nil }, nil
} }
func (s *accountsStatements) UpdatePassword( func (s *accountsStatements) UpdatePassword(
ctx context.Context, localpart, passwordHash string, ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
passwordHash string,
) (err error) { ) (err error) {
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart) _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart, serverName)
return return
} }
func (s *accountsStatements) DeactivateAccount( func (s *accountsStatements) DeactivateAccount(
ctx context.Context, localpart string, ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
) (err error) { ) (err error) {
_, err = s.deactivateAccountStmt.ExecContext(ctx, localpart) _, err = s.deactivateAccountStmt.ExecContext(ctx, localpart, serverName)
return return
} }
func (s *accountsStatements) SelectPasswordHash( func (s *accountsStatements) SelectPasswordHash(
ctx context.Context, localpart string, ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
) (hash string, err error) { ) (hash string, err error) {
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash) err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart, serverName).Scan(&hash)
return return
} }
func (s *accountsStatements) SelectAccountByLocalpart( func (s *accountsStatements) SelectAccountByLocalpart(
ctx context.Context, localpart string, ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
) (*api.Account, error) { ) (*api.Account, error) {
var appserviceIDPtr sql.NullString var appserviceIDPtr sql.NullString
var acc api.Account var acc api.Account
stmt := s.selectAccountByLocalpartStmt stmt := s.selectAccountByLocalpartStmt
err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &acc.AccountType) err := stmt.QueryRowContext(ctx, localpart, serverName).Scan(&acc.Localpart, &acc.ServerName, &appserviceIDPtr, &acc.AccountType)
if err != nil { if err != nil {
if err != sql.ErrNoRows { if err != sql.ErrNoRows {
log.WithError(err).Error("Unable to retrieve user from the db") log.WithError(err).Error("Unable to retrieve user from the db")
@ -180,19 +187,17 @@ func (s *accountsStatements) SelectAccountByLocalpart(
acc.AppServiceID = appserviceIDPtr.String acc.AppServiceID = appserviceIDPtr.String
} }
acc.UserID = userutil.MakeUserID(localpart, s.serverName) acc.UserID = userutil.MakeUserID(acc.Localpart, acc.ServerName)
acc.ServerName = s.serverName
return &acc, nil return &acc, nil
} }
func (s *accountsStatements) SelectNewNumericLocalpart( func (s *accountsStatements) SelectNewNumericLocalpart(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (id int64, err error) { ) (id int64, err error) {
stmt := s.selectNewNumericLocalpartStmt stmt := s.selectNewNumericLocalpartStmt
if txn != nil { if txn != nil {
stmt = sqlutil.TxStmt(txn, stmt) stmt = sqlutil.TxStmt(txn, stmt)
} }
err = stmt.QueryRowContext(ctx).Scan(&id) err = stmt.QueryRowContext(ctx, serverName).Scan(&id)
return id + 1, err return id + 1, err
} }

View file

@ -0,0 +1,81 @@
package deltas
import (
"context"
"database/sql"
"fmt"
"github.com/lib/pq"
"github.com/matrix-org/gomatrixserverlib"
)
var serverNamesTables = []string{
"userapi_accounts",
"userapi_account_datas",
"userapi_devices",
"userapi_notifications",
"userapi_openid_tokens",
"userapi_profiles",
"userapi_pushers",
"userapi_threepids",
}
// These tables have a PRIMARY KEY constraint which we need to drop so
// that we can recreate a new unique index that contains the server name.
// If the new key doesn't exist (i.e. the database was created before the
// table rename migration) we'll try to drop the old one instead.
var serverNamesDropPK = map[string]string{
"userapi_accounts": "account_accounts",
"userapi_account_datas": "account_data",
"userapi_profiles": "account_profiles",
}
// These indices are out of date so let's drop them. They will get recreated
// automatically.
var serverNamesDropIndex = []string{
"userapi_pusher_localpart_idx",
"userapi_pusher_app_id_pushkey_localpart_idx",
}
// I know what you're thinking: you're wondering "why doesn't this use $1
// and pass variadic parameters to ExecContext?" — the answer is because
// PostgreSQL doesn't expect the table name to be specified as a substituted
// argument in that way so it results in a syntax error in the query.
func UpServerNames(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error {
for _, table := range serverNamesTables {
q := fmt.Sprintf(
"ALTER TABLE IF EXISTS %s ADD COLUMN IF NOT EXISTS server_name TEXT NOT NULL DEFAULT '';",
pq.QuoteIdentifier(table),
)
if _, err := tx.ExecContext(ctx, q); err != nil {
return fmt.Errorf("add server name to %q error: %w", table, err)
}
}
for newTable, oldTable := range serverNamesDropPK {
q := fmt.Sprintf(
"ALTER TABLE IF EXISTS %s DROP CONSTRAINT IF EXISTS %s;",
pq.QuoteIdentifier(newTable), pq.QuoteIdentifier(newTable+"_pkey"),
)
if _, err := tx.ExecContext(ctx, q); err != nil {
return fmt.Errorf("drop new PK from %q error: %w", newTable, err)
}
q = fmt.Sprintf(
"ALTER TABLE IF EXISTS %s DROP CONSTRAINT IF EXISTS %s;",
pq.QuoteIdentifier(newTable), pq.QuoteIdentifier(oldTable+"_pkey"),
)
if _, err := tx.ExecContext(ctx, q); err != nil {
return fmt.Errorf("drop old PK from %q error: %w", newTable, err)
}
}
for _, index := range serverNamesDropIndex {
q := fmt.Sprintf(
"DROP INDEX IF EXISTS %s;",
pq.QuoteIdentifier(index),
)
if _, err := tx.ExecContext(ctx, q); err != nil {
return fmt.Errorf("drop index %q error: %w", index, err)
}
}
return nil
}

View file

@ -0,0 +1,28 @@
package deltas
import (
"context"
"database/sql"
"fmt"
"github.com/lib/pq"
"github.com/matrix-org/gomatrixserverlib"
)
// I know what you're thinking: you're wondering "why doesn't this use $1
// and pass variadic parameters to ExecContext?" — the answer is because
// PostgreSQL doesn't expect the table name to be specified as a substituted
// argument in that way so it results in a syntax error in the query.
func UpServerNamesPopulate(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error {
for _, table := range serverNamesTables {
q := fmt.Sprintf(
"UPDATE %s SET server_name = %s WHERE server_name = '';",
pq.QuoteIdentifier(table), pq.QuoteLiteral(string(serverName)),
)
if _, err := tx.ExecContext(ctx, q); err != nil {
return fmt.Errorf("write server names to %q error: %w", table, err)
}
}
return nil
}

View file

@ -17,6 +17,7 @@ package postgres
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt"
"time" "time"
"github.com/lib/pq" "github.com/lib/pq"
@ -50,6 +51,7 @@ CREATE TABLE IF NOT EXISTS userapi_devices (
-- as it is smaller, makes it clearer that we only manage devices for our own users, and may make -- as it is smaller, makes it clearer that we only manage devices for our own users, and may make
-- migration to different domain names easier. -- migration to different domain names easier.
localpart TEXT NOT NULL, localpart TEXT NOT NULL,
server_name TEXT NOT NULL,
-- When this devices was first recognised on the network, as a unix timestamp (ms resolution). -- When this devices was first recognised on the network, as a unix timestamp (ms resolution).
created_ts BIGINT NOT NULL, created_ts BIGINT NOT NULL,
-- The display name, human friendlier than device_id and updatable -- The display name, human friendlier than device_id and updatable
@ -65,39 +67,39 @@ CREATE TABLE IF NOT EXISTS userapi_devices (
); );
-- Device IDs must be unique for a given user. -- Device IDs must be unique for a given user.
CREATE UNIQUE INDEX IF NOT EXISTS userapi_device_localpart_id_idx ON userapi_devices(localpart, device_id); CREATE UNIQUE INDEX IF NOT EXISTS userapi_device_localpart_id_idx ON userapi_devices(localpart, server_name, device_id);
` `
const insertDeviceSQL = "" + const insertDeviceSQL = "" +
"INSERT INTO userapi_devices(device_id, localpart, access_token, created_ts, display_name, last_seen_ts, ip, user_agent) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" + "INSERT INTO userapi_devices(device_id, localpart, server_name, access_token, created_ts, display_name, last_seen_ts, ip, user_agent) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" +
" RETURNING session_id" " RETURNING session_id"
const selectDeviceByTokenSQL = "" + const selectDeviceByTokenSQL = "" +
"SELECT session_id, device_id, localpart FROM userapi_devices WHERE access_token = $1" "SELECT session_id, device_id, localpart, server_name FROM userapi_devices WHERE access_token = $1"
const selectDeviceByIDSQL = "" + const selectDeviceByIDSQL = "" +
"SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 and device_id = $2" "SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = $3"
const selectDevicesByLocalpartSQL = "" + const selectDevicesByLocalpartSQL = "" +
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC" "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3 ORDER BY last_seen_ts DESC"
const updateDeviceNameSQL = "" + const updateDeviceNameSQL = "" +
"UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" "UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND server_name = $3 AND device_id = $4"
const deleteDeviceSQL = "" + const deleteDeviceSQL = "" +
"DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2" "DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2 AND server_name = $3"
const deleteDevicesByLocalpartSQL = "" + const deleteDevicesByLocalpartSQL = "" +
"DELETE FROM userapi_devices WHERE localpart = $1 AND device_id != $2" "DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3"
const deleteDevicesSQL = "" + const deleteDevicesSQL = "" +
"DELETE FROM userapi_devices WHERE localpart = $1 AND device_id = ANY($2)" "DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = ANY($3)"
const selectDevicesByIDSQL = "" + const selectDevicesByIDSQL = "" +
"SELECT device_id, localpart, display_name, last_seen_ts FROM userapi_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC" "SELECT device_id, localpart, server_name, display_name, last_seen_ts FROM userapi_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC"
const updateDeviceLastSeen = "" + const updateDeviceLastSeen = "" +
"UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5" "UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND server_name = $5 AND device_id = $6"
type devicesStatements struct { type devicesStatements struct {
insertDeviceStmt *sql.Stmt insertDeviceStmt *sql.Stmt
@ -148,18 +150,19 @@ func NewPostgresDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName
// Returns an error if the user already has a device with the given device ID. // Returns an error if the user already has a device with the given device ID.
// Returns the device on success. // Returns the device on success.
func (s *devicesStatements) InsertDevice( func (s *devicesStatements) InsertDevice(
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, ctx context.Context, txn *sql.Tx, id string,
displayName *string, ipAddr, userAgent string, localpart string, serverName gomatrixserverlib.ServerName,
accessToken string, displayName *string, ipAddr, userAgent string,
) (*api.Device, error) { ) (*api.Device, error) {
createdTimeMS := time.Now().UnixNano() / 1000000 createdTimeMS := time.Now().UnixNano() / 1000000
var sessionID int64 var sessionID int64
stmt := sqlutil.TxStmt(txn, s.insertDeviceStmt) stmt := sqlutil.TxStmt(txn, s.insertDeviceStmt)
if err := stmt.QueryRowContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, createdTimeMS, ipAddr, userAgent).Scan(&sessionID); err != nil { if err := stmt.QueryRowContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, createdTimeMS, ipAddr, userAgent).Scan(&sessionID); err != nil {
return nil, err return nil, fmt.Errorf("insertDeviceStmt: %w", err)
} }
return &api.Device{ return &api.Device{
ID: id, ID: id,
UserID: userutil.MakeUserID(localpart, s.serverName), UserID: userutil.MakeUserID(localpart, serverName),
AccessToken: accessToken, AccessToken: accessToken,
SessionID: sessionID, SessionID: sessionID,
LastSeenTS: createdTimeMS, LastSeenTS: createdTimeMS,
@ -170,38 +173,45 @@ func (s *devicesStatements) InsertDevice(
// deleteDevice removes a single device by id and user localpart. // deleteDevice removes a single device by id and user localpart.
func (s *devicesStatements) DeleteDevice( func (s *devicesStatements) DeleteDevice(
ctx context.Context, txn *sql.Tx, id, localpart string, ctx context.Context, txn *sql.Tx, id string,
localpart string, serverName gomatrixserverlib.ServerName,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt) stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
_, err := stmt.ExecContext(ctx, id, localpart) _, err := stmt.ExecContext(ctx, id, localpart, serverName)
return err return err
} }
// deleteDevices removes a single or multiple devices by ids and user localpart. // deleteDevices removes a single or multiple devices by ids and user localpart.
// Returns an error if the execution failed. // Returns an error if the execution failed.
func (s *devicesStatements) DeleteDevices( func (s *devicesStatements) DeleteDevices(
ctx context.Context, txn *sql.Tx, localpart string, devices []string, ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
devices []string,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesStmt) stmt := sqlutil.TxStmt(txn, s.deleteDevicesStmt)
_, err := stmt.ExecContext(ctx, localpart, pq.Array(devices)) _, err := stmt.ExecContext(ctx, localpart, serverName, pq.Array(devices))
return err return err
} }
// deleteDevicesByLocalpart removes all devices for the // deleteDevicesByLocalpart removes all devices for the
// given user localpart. // given user localpart.
func (s *devicesStatements) DeleteDevicesByLocalpart( func (s *devicesStatements) DeleteDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
exceptDeviceID string,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
_, err := stmt.ExecContext(ctx, localpart, exceptDeviceID) _, err := stmt.ExecContext(ctx, localpart, serverName, exceptDeviceID)
return err return err
} }
func (s *devicesStatements) UpdateDeviceName( func (s *devicesStatements) UpdateDeviceName(
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string, ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
deviceID string, displayName *string,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt) stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
_, err := stmt.ExecContext(ctx, displayName, localpart, deviceID) _, err := stmt.ExecContext(ctx, displayName, localpart, serverName, deviceID)
return err return err
} }
@ -210,10 +220,11 @@ func (s *devicesStatements) SelectDeviceByToken(
) (*api.Device, error) { ) (*api.Device, error) {
var dev api.Device var dev api.Device
var localpart string var localpart string
var serverName gomatrixserverlib.ServerName
stmt := s.selectDeviceByTokenStmt stmt := s.selectDeviceByTokenStmt
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart) err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart, &serverName)
if err == nil { if err == nil {
dev.UserID = userutil.MakeUserID(localpart, s.serverName) dev.UserID = userutil.MakeUserID(localpart, serverName)
dev.AccessToken = accessToken dev.AccessToken = accessToken
} }
return &dev, err return &dev, err
@ -222,16 +233,18 @@ func (s *devicesStatements) SelectDeviceByToken(
// selectDeviceByID retrieves a device from the database with the given user // selectDeviceByID retrieves a device from the database with the given user
// localpart and deviceID // localpart and deviceID
func (s *devicesStatements) SelectDeviceByID( func (s *devicesStatements) SelectDeviceByID(
ctx context.Context, localpart, deviceID string, ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
deviceID string,
) (*api.Device, error) { ) (*api.Device, error) {
var dev api.Device var dev api.Device
var displayName, ip sql.NullString var displayName, ip sql.NullString
var lastseenTS sql.NullInt64 var lastseenTS sql.NullInt64
stmt := s.selectDeviceByIDStmt stmt := s.selectDeviceByIDStmt
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName, &lastseenTS, &ip) err := stmt.QueryRowContext(ctx, localpart, serverName, deviceID).Scan(&displayName, &lastseenTS, &ip)
if err == nil { if err == nil {
dev.ID = deviceID dev.ID = deviceID
dev.UserID = userutil.MakeUserID(localpart, s.serverName) dev.UserID = userutil.MakeUserID(localpart, serverName)
if displayName.Valid { if displayName.Valid {
dev.DisplayName = displayName.String dev.DisplayName = displayName.String
} }
@ -254,10 +267,11 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s
var devices []api.Device var devices []api.Device
var dev api.Device var dev api.Device
var localpart string var localpart string
var serverName gomatrixserverlib.ServerName
var lastseents sql.NullInt64 var lastseents sql.NullInt64
var displayName sql.NullString var displayName sql.NullString
for rows.Next() { for rows.Next() {
if err := rows.Scan(&dev.ID, &localpart, &displayName, &lastseents); err != nil { if err := rows.Scan(&dev.ID, &localpart, &serverName, &displayName, &lastseents); err != nil {
return nil, err return nil, err
} }
if displayName.Valid { if displayName.Valid {
@ -266,17 +280,19 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s
if lastseents.Valid { if lastseents.Valid {
dev.LastSeenTS = lastseents.Int64 dev.LastSeenTS = lastseents.Int64
} }
dev.UserID = userutil.MakeUserID(localpart, s.serverName) dev.UserID = userutil.MakeUserID(localpart, serverName)
devices = append(devices, dev) devices = append(devices, dev)
} }
return devices, rows.Err() return devices, rows.Err()
} }
func (s *devicesStatements) SelectDevicesByLocalpart( func (s *devicesStatements) SelectDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
exceptDeviceID string,
) ([]api.Device, error) { ) ([]api.Device, error) {
devices := []api.Device{} devices := []api.Device{}
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID) rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, serverName, exceptDeviceID)
if err != nil { if err != nil {
return devices, err return devices, err
@ -307,16 +323,16 @@ func (s *devicesStatements) SelectDevicesByLocalpart(
dev.UserAgent = useragent.String dev.UserAgent = useragent.String
} }
dev.UserID = userutil.MakeUserID(localpart, s.serverName) dev.UserID = userutil.MakeUserID(localpart, serverName)
devices = append(devices, dev) devices = append(devices, dev)
} }
return devices, rows.Err() return devices, rows.Err()
} }
func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr, userAgent string) error { func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error {
lastSeenTs := time.Now().UnixNano() / 1000000 lastSeenTs := time.Now().UnixNano() / 1000000
stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt) stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt)
_, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, deviceID) _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, serverName, deviceID)
return err return err
} }

View file

@ -43,6 +43,7 @@ const notificationSchema = `
CREATE TABLE IF NOT EXISTS userapi_notifications ( CREATE TABLE IF NOT EXISTS userapi_notifications (
id BIGSERIAL PRIMARY KEY, id BIGSERIAL PRIMARY KEY,
localpart TEXT NOT NULL, localpart TEXT NOT NULL,
server_name TEXT NOT NULL,
room_id TEXT NOT NULL, room_id TEXT NOT NULL,
event_id TEXT NOT NULL, event_id TEXT NOT NULL,
stream_pos BIGINT NOT NULL, stream_pos BIGINT NOT NULL,
@ -52,33 +53,33 @@ CREATE TABLE IF NOT EXISTS userapi_notifications (
read BOOLEAN NOT NULL DEFAULT FALSE read BOOLEAN NOT NULL DEFAULT FALSE
); );
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, room_id, event_id); CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, server_name, room_id, event_id);
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, room_id, id); CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, server_name, room_id, id);
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, id); CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, server_name, id);
` `
const insertNotificationSQL = "" + const insertNotificationSQL = "" +
"INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7)" "INSERT INTO userapi_notifications (localpart, server_name, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
const deleteNotificationsUpToSQL = "" + const deleteNotificationsUpToSQL = "" +
"DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND stream_pos <= $3" "DELETE FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND stream_pos <= $4"
const updateNotificationReadSQL = "" + const updateNotificationReadSQL = "" +
"UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND stream_pos <= $4 AND read <> $1" "UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND server_name = $3 AND room_id = $4 AND stream_pos <= $5 AND read <> $1"
const selectNotificationSQL = "" + const selectNotificationSQL = "" +
"SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" + "SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND id > $3 AND (" +
"(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" + "(($4 & 1) <> 0 AND highlight) OR (($4 & 2) <> 0 AND NOT highlight)" +
") AND NOT read ORDER BY localpart, id LIMIT $4" ") AND NOT read ORDER BY localpart, id LIMIT $5"
const selectNotificationCountSQL = "" + const selectNotificationCountSQL = "" +
"SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND (" + "SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND (" +
"(($2 & 1) <> 0 AND highlight) OR (($2 & 2) <> 0 AND NOT highlight)" + "(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" +
") AND NOT read" ") AND NOT read"
const selectRoomNotificationCountsSQL = "" + const selectRoomNotificationCountsSQL = "" +
"SELECT COUNT(*), COUNT(*) FILTER (WHERE highlight) FROM userapi_notifications " + "SELECT COUNT(*), COUNT(*) FILTER (WHERE highlight) FROM userapi_notifications " +
"WHERE localpart = $1 AND room_id = $2 AND NOT read" "WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND NOT read"
const cleanNotificationsSQL = "" + const cleanNotificationsSQL = "" +
"DELETE FROM userapi_notifications WHERE" + "DELETE FROM userapi_notifications WHERE" +
@ -111,7 +112,7 @@ func (s *notificationsStatements) Clean(ctx context.Context, txn *sql.Tx) error
} }
// Insert inserts a notification into the database. // Insert inserts a notification into the database.
func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error { func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, highlight bool, n *api.Notification) error {
roomID, tsMS := n.RoomID, n.TS roomID, tsMS := n.RoomID, n.TS
nn := *n nn := *n
// Clears out fields that have their own columns to (1) shrink the // Clears out fields that have their own columns to (1) shrink the
@ -122,13 +123,13 @@ func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, local
if err != nil { if err != nil {
return err return err
} }
_, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, pos, tsMS, highlight, string(bs)) _, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, serverName, roomID, eventID, pos, tsMS, highlight, string(bs))
return err return err
} }
// DeleteUpTo deletes all previous notifications, up to and including the event. // DeleteUpTo deletes all previous notifications, up to and including the event.
func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error) { func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, _ error) {
res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos) res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, serverName, roomID, pos)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -141,8 +142,8 @@ func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, l
} }
// UpdateRead updates the "read" value for an event. // UpdateRead updates the "read" value for an event.
func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error) { func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, v bool) (affected bool, _ error) {
res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos) res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, serverName, roomID, pos)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -154,8 +155,8 @@ func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, l
return nrows > 0, nil return nrows > 0, nil
} }
func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) { func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, fromID, uint32(filter), limit) rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, serverName, fromID, uint32(filter), limit)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
@ -197,12 +198,12 @@ func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, local
return notifs, maxID, rows.Err() return notifs, maxID, rows.Err()
} }
func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (count int64, err error) { func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, filter tables.NotificationFilter) (count int64, err error) {
err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, uint32(filter)).Scan(&count) err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, serverName, uint32(filter)).Scan(&count)
return return
} }
func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, err error) { func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, err error) {
err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, roomID).Scan(&total, &highlight) err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, serverName, roomID).Scan(&total, &highlight)
return return
} }

View file

@ -3,6 +3,7 @@ package postgres
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
@ -18,16 +19,17 @@ CREATE TABLE IF NOT EXISTS userapi_openid_tokens (
token TEXT NOT NULL PRIMARY KEY, token TEXT NOT NULL PRIMARY KEY,
-- The Matrix user ID for this account -- The Matrix user ID for this account
localpart TEXT NOT NULL, localpart TEXT NOT NULL,
server_name TEXT NOT NULL,
-- When the token expires, as a unix timestamp (ms resolution). -- When the token expires, as a unix timestamp (ms resolution).
token_expires_at_ms BIGINT NOT NULL token_expires_at_ms BIGINT NOT NULL
); );
` `
const insertOpenIDTokenSQL = "" + const insertOpenIDTokenSQL = "" +
"INSERT INTO userapi_openid_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)" "INSERT INTO userapi_openid_tokens(token, localpart, server_name, token_expires_at_ms) VALUES ($1, $2, $3, $4)"
const selectOpenIDTokenSQL = "" + const selectOpenIDTokenSQL = "" +
"SELECT localpart, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1" "SELECT localpart, server_name, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1"
type openIDTokenStatements struct { type openIDTokenStatements struct {
insertTokenStmt *sql.Stmt insertTokenStmt *sql.Stmt
@ -54,11 +56,11 @@ func NewPostgresOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName)
func (s *openIDTokenStatements) InsertOpenIDToken( func (s *openIDTokenStatements) InsertOpenIDToken(
ctx context.Context, ctx context.Context,
txn *sql.Tx, txn *sql.Tx,
token, localpart string, token, localpart string, serverName gomatrixserverlib.ServerName,
expiresAtMS int64, expiresAtMS int64,
) (err error) { ) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertTokenStmt) stmt := sqlutil.TxStmt(txn, s.insertTokenStmt)
_, err = stmt.ExecContext(ctx, token, localpart, expiresAtMS) _, err = stmt.ExecContext(ctx, token, localpart, serverName, expiresAtMS)
return return
} }
@ -69,10 +71,13 @@ func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes(
token string, token string,
) (*api.OpenIDTokenAttributes, error) { ) (*api.OpenIDTokenAttributes, error) {
var openIDTokenAttrs api.OpenIDTokenAttributes var openIDTokenAttrs api.OpenIDTokenAttributes
var localpart string
var serverName gomatrixserverlib.ServerName
err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan( err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan(
&openIDTokenAttrs.UserID, &localpart, &serverName,
&openIDTokenAttrs.ExpiresAtMS, &openIDTokenAttrs.ExpiresAtMS,
) )
openIDTokenAttrs.UserID = fmt.Sprintf("@%s:%s", localpart, serverName)
if err != nil { if err != nil {
if err != sql.ErrNoRows { if err != sql.ErrNoRows {
log.WithError(err).Error("Unable to retrieve token from the db") log.WithError(err).Error("Unable to retrieve token from the db")

View file

@ -23,42 +23,46 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
) )
const profilesSchema = ` const profilesSchema = `
-- Stores data about accounts profiles. -- Stores data about accounts profiles.
CREATE TABLE IF NOT EXISTS userapi_profiles ( CREATE TABLE IF NOT EXISTS userapi_profiles (
-- The Matrix user ID localpart for this account -- The Matrix user ID localpart for this account
localpart TEXT NOT NULL PRIMARY KEY, localpart TEXT NOT NULL,
server_name TEXT NOT NULL,
-- The display name for this account -- The display name for this account
display_name TEXT, display_name TEXT,
-- The URL of the avatar for this account -- The URL of the avatar for this account
avatar_url TEXT avatar_url TEXT
); );
CREATE UNIQUE INDEX IF NOT EXISTS userapi_profiles_idx ON userapi_profiles(localpart, server_name);
` `
const insertProfileSQL = "" + const insertProfileSQL = "" +
"INSERT INTO userapi_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)" "INSERT INTO userapi_profiles(localpart, server_name, display_name, avatar_url) VALUES ($1, $2, $3, $4)"
const selectProfileByLocalpartSQL = "" + const selectProfileByLocalpartSQL = "" +
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1" "SELECT localpart, server_name, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1 AND server_name = $2"
const setAvatarURLSQL = "" + const setAvatarURLSQL = "" +
"UPDATE userapi_profiles AS new" + "UPDATE userapi_profiles AS new" +
" SET avatar_url = $1" + " SET avatar_url = $1" +
" FROM userapi_profiles AS old" + " FROM userapi_profiles AS old" +
" WHERE new.localpart = $2" + " WHERE new.localpart = $2 AND new.server_name = $3" +
" RETURNING new.display_name, old.avatar_url <> new.avatar_url" " RETURNING new.display_name, old.avatar_url <> new.avatar_url"
const setDisplayNameSQL = "" + const setDisplayNameSQL = "" +
"UPDATE userapi_profiles AS new" + "UPDATE userapi_profiles AS new" +
" SET display_name = $1" + " SET display_name = $1" +
" FROM userapi_profiles AS old" + " FROM userapi_profiles AS old" +
" WHERE new.localpart = $2" + " WHERE new.localpart = $2 AND new.server_name = $3" +
" RETURNING new.avatar_url, old.display_name <> new.display_name" " RETURNING new.avatar_url, old.display_name <> new.display_name"
const selectProfilesBySearchSQL = "" + const selectProfilesBySearchSQL = "" +
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" "SELECT localpart, server_name, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
type profilesStatements struct { type profilesStatements struct {
serverNoticesLocalpart string serverNoticesLocalpart string
@ -87,18 +91,20 @@ func NewPostgresProfilesTable(db *sql.DB, serverNoticesLocalpart string) (tables
} }
func (s *profilesStatements) InsertProfile( func (s *profilesStatements) InsertProfile(
ctx context.Context, txn *sql.Tx, localpart string, ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
) (err error) { ) (err error) {
_, err = sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "") _, err = sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, serverName, "", "")
return return
} }
func (s *profilesStatements) SelectProfileByLocalpart( func (s *profilesStatements) SelectProfileByLocalpart(
ctx context.Context, localpart string, ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
) (*authtypes.Profile, error) { ) (*authtypes.Profile, error) {
var profile authtypes.Profile var profile authtypes.Profile
err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart).Scan( err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart, serverName).Scan(
&profile.Localpart, &profile.DisplayName, &profile.AvatarURL, &profile.Localpart, &profile.ServerName, &profile.DisplayName, &profile.AvatarURL,
) )
if err != nil { if err != nil {
return nil, err return nil, err
@ -107,28 +113,34 @@ func (s *profilesStatements) SelectProfileByLocalpart(
} }
func (s *profilesStatements) SetAvatarURL( func (s *profilesStatements) SetAvatarURL(
ctx context.Context, txn *sql.Tx, localpart string, avatarURL string, ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
avatarURL string,
) (*authtypes.Profile, bool, error) { ) (*authtypes.Profile, bool, error) {
profile := &authtypes.Profile{ profile := &authtypes.Profile{
Localpart: localpart, Localpart: localpart,
ServerName: string(serverName),
AvatarURL: avatarURL, AvatarURL: avatarURL,
} }
var changed bool var changed bool
stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt) stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
err := stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName, &changed) err := stmt.QueryRowContext(ctx, avatarURL, localpart, serverName).Scan(&profile.DisplayName, &changed)
return profile, changed, err return profile, changed, err
} }
func (s *profilesStatements) SetDisplayName( func (s *profilesStatements) SetDisplayName(
ctx context.Context, txn *sql.Tx, localpart string, displayName string, ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
displayName string,
) (*authtypes.Profile, bool, error) { ) (*authtypes.Profile, bool, error) {
profile := &authtypes.Profile{ profile := &authtypes.Profile{
Localpart: localpart, Localpart: localpart,
ServerName: string(serverName),
DisplayName: displayName, DisplayName: displayName,
} }
var changed bool var changed bool
stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt) stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
err := stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL, &changed) err := stmt.QueryRowContext(ctx, displayName, localpart, serverName).Scan(&profile.AvatarURL, &changed)
return profile, changed, err return profile, changed, err
} }
@ -146,7 +158,7 @@ func (s *profilesStatements) SelectProfilesBySearch(
defer internal.CloseAndLogIfError(ctx, rows, "selectProfilesBySearch: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "selectProfilesBySearch: rows.close() failed")
for rows.Next() { for rows.Next() {
var profile authtypes.Profile var profile authtypes.Profile
if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL); err != nil { if err := rows.Scan(&profile.Localpart, &profile.ServerName, &profile.DisplayName, &profile.AvatarURL); err != nil {
return nil, err return nil, err
} }
if profile.Localpart != s.serverNoticesLocalpart { if profile.Localpart != s.serverNoticesLocalpart {

View file

@ -25,6 +25,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
) )
// See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers // See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers
@ -33,6 +34,7 @@ CREATE TABLE IF NOT EXISTS userapi_pushers (
id BIGSERIAL PRIMARY KEY, id BIGSERIAL PRIMARY KEY,
-- The Matrix user ID localpart for this pusher -- The Matrix user ID localpart for this pusher
localpart TEXT NOT NULL, localpart TEXT NOT NULL,
server_name TEXT NOT NULL,
session_id BIGINT DEFAULT NULL, session_id BIGINT DEFAULT NULL,
profile_tag TEXT, profile_tag TEXT,
kind TEXT NOT NULL, kind TEXT NOT NULL,
@ -49,22 +51,22 @@ CREATE TABLE IF NOT EXISTS userapi_pushers (
CREATE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey); CREATE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey);
-- For faster retrieving by localpart. -- For faster retrieving by localpart.
CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart); CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart, server_name);
-- Pushkey must be unique for a given user and app. -- Pushkey must be unique for a given user and app.
CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart); CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart, server_name);
` `
const insertPusherSQL = "" + const insertPusherSQL = "" +
"INSERT INTO userapi_pushers (localpart, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" + "INSERT INTO userapi_pushers (localpart, server_name, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" +
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" + "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)" +
"ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11" "ON CONFLICT (app_id, pushkey, localpart, server_name) DO UPDATE SET session_id = $3, pushkey_ts_ms = $5, kind = $6, app_display_name = $8, device_display_name = $9, profile_tag = $10, lang = $11, data = $12"
const selectPushersSQL = "" + const selectPushersSQL = "" +
"SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1" "SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1 AND server_name = $2"
const deletePusherSQL = "" + const deletePusherSQL = "" +
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3" "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3 AND server_name = $4"
const deletePushersByAppIdAndPushKeySQL = "" + const deletePushersByAppIdAndPushKeySQL = "" +
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2" "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2"
@ -95,18 +97,19 @@ type pushersStatements struct {
// Returns nil error success. // Returns nil error success.
func (s *pushersStatements) InsertPusher( func (s *pushersStatements) InsertPusher(
ctx context.Context, txn *sql.Tx, session_id int64, ctx context.Context, txn *sql.Tx, session_id int64,
pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string, pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data,
localpart string, serverName gomatrixserverlib.ServerName,
) error { ) error {
_, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data) _, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, serverName, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
logrus.Debugf("Created pusher %d", session_id)
return err return err
} }
func (s *pushersStatements) SelectPushers( func (s *pushersStatements) SelectPushers(
ctx context.Context, txn *sql.Tx, localpart string, ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
) ([]api.Pusher, error) { ) ([]api.Pusher, error) {
pushers := []api.Pusher{} pushers := []api.Pusher{}
rows, err := sqlutil.TxStmt(txn, s.selectPushersStmt).QueryContext(ctx, localpart) rows, err := sqlutil.TxStmt(txn, s.selectPushersStmt).QueryContext(ctx, localpart, serverName)
if err != nil { if err != nil {
return pushers, err return pushers, err
@ -143,9 +146,10 @@ func (s *pushersStatements) SelectPushers(
// deletePusher removes a single pusher by pushkey and user localpart. // deletePusher removes a single pusher by pushkey and user localpart.
func (s *pushersStatements) DeletePusher( func (s *pushersStatements) DeletePusher(
ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string, ctx context.Context, txn *sql.Tx, appid, pushkey,
localpart string, serverName gomatrixserverlib.ServerName,
) error { ) error {
_, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart) _, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart, serverName)
return err return err
} }

View file

@ -15,6 +15,8 @@
package postgres package postgres
import ( import (
"context"
"database/sql"
"fmt" "fmt"
"time" "time"
@ -43,18 +45,24 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
Up: deltas.UpRenameTables, Up: deltas.UpRenameTables,
Down: deltas.DownRenameTables, Down: deltas.DownRenameTables,
}) })
m.AddMigrations(sqlutil.Migration{
Version: "userapi: server names",
Up: func(ctx context.Context, txn *sql.Tx) error {
return deltas.UpServerNames(ctx, txn, serverName)
},
})
if err = m.Up(base.Context()); err != nil { if err = m.Up(base.Context()); err != nil {
return nil, err return nil, err
} }
accountDataTable, err := NewPostgresAccountDataTable(db)
if err != nil {
return nil, fmt.Errorf("NewPostgresAccountDataTable: %w", err)
}
accountsTable, err := NewPostgresAccountsTable(db, serverName) accountsTable, err := NewPostgresAccountsTable(db, serverName)
if err != nil { if err != nil {
return nil, fmt.Errorf("NewPostgresAccountsTable: %w", err) return nil, fmt.Errorf("NewPostgresAccountsTable: %w", err)
} }
accountDataTable, err := NewPostgresAccountDataTable(db)
if err != nil {
return nil, fmt.Errorf("NewPostgresAccountDataTable: %w", err)
}
devicesTable, err := NewPostgresDevicesTable(db, serverName) devicesTable, err := NewPostgresDevicesTable(db, serverName)
if err != nil { if err != nil {
return nil, fmt.Errorf("NewPostgresDevicesTable: %w", err) return nil, fmt.Errorf("NewPostgresDevicesTable: %w", err)
@ -95,6 +103,18 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
if err != nil { if err != nil {
return nil, fmt.Errorf("NewPostgresStatsTable: %w", err) return nil, fmt.Errorf("NewPostgresStatsTable: %w", err)
} }
m = sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "userapi: server names populate",
Up: func(ctx context.Context, txn *sql.Tx) error {
return deltas.UpServerNamesPopulate(ctx, txn, serverName)
},
})
if err = m.Up(base.Context()); err != nil {
return nil, err
}
return &shared.Database{ return &shared.Database{
AccountDatas: accountDataTable, AccountDatas: accountDataTable,
Accounts: accountsTable, Accounts: accountsTable,

View file

@ -20,6 +20,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
) )
@ -33,21 +34,22 @@ CREATE TABLE IF NOT EXISTS userapi_threepids (
medium TEXT NOT NULL DEFAULT 'email', medium TEXT NOT NULL DEFAULT 'email',
-- The localpart of the Matrix user ID associated to this 3PID -- The localpart of the Matrix user ID associated to this 3PID
localpart TEXT NOT NULL, localpart TEXT NOT NULL,
server_name TEXT NOT NULL,
PRIMARY KEY(threepid, medium) PRIMARY KEY(threepid, medium)
); );
CREATE INDEX IF NOT EXISTS userapi_threepid_idx ON userapi_threepids(localpart); CREATE INDEX IF NOT EXISTS userapi_threepid_idx ON userapi_threepids(localpart, server_name);
` `
const selectLocalpartForThreePIDSQL = "" + const selectLocalpartForThreePIDSQL = "" +
"SELECT localpart FROM userapi_threepids WHERE threepid = $1 AND medium = $2" "SELECT localpart, server_name FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
const selectThreePIDsForLocalpartSQL = "" + const selectThreePIDsForLocalpartSQL = "" +
"SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1" "SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1 AND server_name = $2"
const insertThreePIDSQL = "" + const insertThreePIDSQL = "" +
"INSERT INTO userapi_threepids (threepid, medium, localpart) VALUES ($1, $2, $3)" "INSERT INTO userapi_threepids (threepid, medium, localpart, server_name) VALUES ($1, $2, $3, $4)"
const deleteThreePIDSQL = "" + const deleteThreePIDSQL = "" +
"DELETE FROM userapi_threepids WHERE threepid = $1 AND medium = $2" "DELETE FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
@ -75,19 +77,20 @@ func NewPostgresThreePIDTable(db *sql.DB) (tables.ThreePIDTable, error) {
func (s *threepidStatements) SelectLocalpartForThreePID( func (s *threepidStatements) SelectLocalpartForThreePID(
ctx context.Context, txn *sql.Tx, threepid string, medium string, ctx context.Context, txn *sql.Tx, threepid string, medium string,
) (localpart string, err error) { ) (localpart string, serverName gomatrixserverlib.ServerName, err error) {
stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt) stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt)
err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart) err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart, &serverName)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return "", nil return "", "", nil
} }
return return
} }
func (s *threepidStatements) SelectThreePIDsForLocalpart( func (s *threepidStatements) SelectThreePIDsForLocalpart(
ctx context.Context, localpart string, ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
) (threepids []authtypes.ThreePID, err error) { ) (threepids []authtypes.ThreePID, err error) {
rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart) rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart, serverName)
if err != nil { if err != nil {
return return
} }
@ -109,10 +112,11 @@ func (s *threepidStatements) SelectThreePIDsForLocalpart(
} }
func (s *threepidStatements) InsertThreePID( func (s *threepidStatements) InsertThreePID(
ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, ctx context.Context, txn *sql.Tx, threepid, medium,
localpart string, serverName gomatrixserverlib.ServerName,
) (err error) { ) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt) stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt)
_, err = stmt.ExecContext(ctx, threepid, medium, localpart) _, err = stmt.ExecContext(ctx, threepid, medium, localpart, serverName)
return return
} }

View file

@ -68,9 +68,10 @@ const (
// GetAccountByPassword returns the account associated with the given localpart and password. // GetAccountByPassword returns the account associated with the given localpart and password.
// Returns sql.ErrNoRows if no account exists which matches the given localpart. // Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByPassword( func (d *Database) GetAccountByPassword(
ctx context.Context, localpart, plaintextPassword string, ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
plaintextPassword string,
) (*api.Account, error) { ) (*api.Account, error) {
hash, err := d.Accounts.SelectPasswordHash(ctx, localpart) hash, err := d.Accounts.SelectPasswordHash(ctx, localpart, serverName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -80,24 +81,27 @@ func (d *Database) GetAccountByPassword(
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil { if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil {
return nil, err return nil, err
} }
return d.Accounts.SelectAccountByLocalpart(ctx, localpart) return d.Accounts.SelectAccountByLocalpart(ctx, localpart, serverName)
} }
// GetProfileByLocalpart returns the profile associated with the given localpart. // GetProfileByLocalpart returns the profile associated with the given localpart.
// Returns sql.ErrNoRows if no profile exists which matches the given localpart. // Returns sql.ErrNoRows if no profile exists which matches the given localpart.
func (d *Database) GetProfileByLocalpart( func (d *Database) GetProfileByLocalpart(
ctx context.Context, localpart string, ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
) (*authtypes.Profile, error) { ) (*authtypes.Profile, error) {
return d.Profiles.SelectProfileByLocalpart(ctx, localpart) return d.Profiles.SelectProfileByLocalpart(ctx, localpart, serverName)
} }
// SetAvatarURL updates the avatar URL of the profile associated with the given // SetAvatarURL updates the avatar URL of the profile associated with the given
// localpart. Returns an error if something went wrong with the SQL query // localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetAvatarURL( func (d *Database) SetAvatarURL(
ctx context.Context, localpart string, avatarURL string, ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
avatarURL string,
) (profile *authtypes.Profile, changed bool, err error) { ) (profile *authtypes.Profile, changed bool, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
profile, changed, err = d.Profiles.SetAvatarURL(ctx, txn, localpart, avatarURL) profile, changed, err = d.Profiles.SetAvatarURL(ctx, txn, localpart, serverName, avatarURL)
return err return err
}) })
return return
@ -106,10 +110,12 @@ func (d *Database) SetAvatarURL(
// SetDisplayName updates the display name of the profile associated with the given // SetDisplayName updates the display name of the profile associated with the given
// localpart. Returns an error if something went wrong with the SQL query // localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetDisplayName( func (d *Database) SetDisplayName(
ctx context.Context, localpart string, displayName string, ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
displayName string,
) (profile *authtypes.Profile, changed bool, err error) { ) (profile *authtypes.Profile, changed bool, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
profile, changed, err = d.Profiles.SetDisplayName(ctx, txn, localpart, displayName) profile, changed, err = d.Profiles.SetDisplayName(ctx, txn, localpart, serverName, displayName)
return err return err
}) })
return return
@ -117,14 +123,15 @@ func (d *Database) SetDisplayName(
// SetPassword sets the account password to the given hash. // SetPassword sets the account password to the given hash.
func (d *Database) SetPassword( func (d *Database) SetPassword(
ctx context.Context, localpart, plaintextPassword string, ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
plaintextPassword string,
) error { ) error {
hash, err := d.hashPassword(plaintextPassword) hash, err := d.hashPassword(plaintextPassword)
if err != nil { if err != nil {
return err return err
} }
return d.Writer.Do(nil, nil, func(txn *sql.Tx) error { return d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
return d.Accounts.UpdatePassword(ctx, localpart, hash) return d.Accounts.UpdatePassword(ctx, localpart, serverName, hash)
}) })
} }
@ -132,21 +139,22 @@ func (d *Database) SetPassword(
// for this account. If no password is supplied, the account will be a passwordless account. If the // for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, ErrUserExists. // account already exists, it will return nil, ErrUserExists.
func (d *Database) CreateAccount( func (d *Database) CreateAccount(
ctx context.Context, localpart, plaintextPassword, appserviceID string, accountType api.AccountType, ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
plaintextPassword, appserviceID string, accountType api.AccountType,
) (acc *api.Account, err error) { ) (acc *api.Account, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
// For guest accounts, we create a new numeric local part // For guest accounts, we create a new numeric local part
if accountType == api.AccountTypeGuest { if accountType == api.AccountTypeGuest {
var numLocalpart int64 var numLocalpart int64
numLocalpart, err = d.Accounts.SelectNewNumericLocalpart(ctx, txn) numLocalpart, err = d.Accounts.SelectNewNumericLocalpart(ctx, txn, serverName)
if err != nil { if err != nil {
return err return fmt.Errorf("d.Accounts.SelectNewNumericLocalpart: %w", err)
} }
localpart = strconv.FormatInt(numLocalpart, 10) localpart = strconv.FormatInt(numLocalpart, 10)
plaintextPassword = "" plaintextPassword = ""
appserviceID = "" appserviceID = ""
} }
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, accountType) acc, err = d.createAccount(ctx, txn, localpart, serverName, plaintextPassword, appserviceID, accountType)
return err return err
}) })
return return
@ -155,7 +163,9 @@ func (d *Database) CreateAccount(
// WARNING! This function assumes that the relevant mutexes have already // WARNING! This function assumes that the relevant mutexes have already
// been taken out by the caller (e.g. CreateAccount or CreateGuestAccount). // been taken out by the caller (e.g. CreateAccount or CreateGuestAccount).
func (d *Database) createAccount( func (d *Database) createAccount(
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, accountType api.AccountType, ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
plaintextPassword, appserviceID string, accountType api.AccountType,
) (*api.Account, error) { ) (*api.Account, error) {
var err error var err error
var account *api.Account var account *api.Account
@ -167,28 +177,28 @@ func (d *Database) createAccount(
return nil, err return nil, err
} }
} }
if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, hash, appserviceID, accountType); err != nil { if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, serverName, hash, appserviceID, accountType); err != nil {
return nil, sqlutil.ErrUserExists return nil, sqlutil.ErrUserExists
} }
if err = d.Profiles.InsertProfile(ctx, txn, localpart); err != nil { if err = d.Profiles.InsertProfile(ctx, txn, localpart, serverName); err != nil {
return nil, err return nil, fmt.Errorf("d.Profiles.InsertProfile: %w", err)
} }
pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, d.ServerName) pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, serverName)
prbs, err := json.Marshal(pushRuleSets) prbs, err := json.Marshal(pushRuleSets)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("json.Marshal: %w", err)
} }
if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(prbs)); err != nil { if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, serverName, "", "m.push_rules", json.RawMessage(prbs)); err != nil {
return nil, err return nil, fmt.Errorf("d.AccountDatas.InsertAccountData: %w", err)
} }
return account, nil return account, nil
} }
func (d *Database) QueryPushRules( func (d *Database) QueryPushRules(
ctx context.Context, ctx context.Context,
localpart string, localpart string, serverName gomatrixserverlib.ServerName,
) (*pushrules.AccountRuleSets, error) { ) (*pushrules.AccountRuleSets, error) {
data, err := d.AccountDatas.SelectAccountDataByType(ctx, localpart, "", "m.push_rules") data, err := d.AccountDatas.SelectAccountDataByType(ctx, localpart, serverName, "", "m.push_rules")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -196,13 +206,13 @@ func (d *Database) QueryPushRules(
// If we didn't find any default push rules then we should just generate some // If we didn't find any default push rules then we should just generate some
// fresh ones. // fresh ones.
if len(data) == 0 { if len(data) == 0 {
pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, d.ServerName) pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, serverName)
prbs, err := json.Marshal(pushRuleSets) prbs, err := json.Marshal(pushRuleSets)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to marshal default push rules: %w", err) return nil, fmt.Errorf("failed to marshal default push rules: %w", err)
} }
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if dbErr := d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", prbs); dbErr != nil { if dbErr := d.AccountDatas.InsertAccountData(ctx, txn, localpart, serverName, "", "m.push_rules", prbs); dbErr != nil {
return fmt.Errorf("failed to save default push rules: %w", dbErr) return fmt.Errorf("failed to save default push rules: %w", dbErr)
} }
return nil return nil
@ -225,22 +235,23 @@ func (d *Database) QueryPushRules(
// update the corresponding row with the new content // update the corresponding row with the new content
// Returns a SQL error if there was an issue with the insertion/update // Returns a SQL error if there was an issue with the insertion/update
func (d *Database) SaveAccountData( func (d *Database) SaveAccountData(
ctx context.Context, localpart, roomID, dataType string, content json.RawMessage, ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
roomID, dataType string, content json.RawMessage,
) error { ) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.AccountDatas.InsertAccountData(ctx, txn, localpart, roomID, dataType, content) return d.AccountDatas.InsertAccountData(ctx, txn, localpart, serverName, roomID, dataType, content)
}) })
} }
// GetAccountData returns account data related to a given localpart // GetAccountData returns account data related to a given localpart
// If no account data could be found, returns an empty arrays // If no account data could be found, returns an empty arrays
// Returns an error if there was an issue with the retrieval // Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountData(ctx context.Context, localpart string) ( func (d *Database) GetAccountData(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (
global map[string]json.RawMessage, global map[string]json.RawMessage,
rooms map[string]map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage,
err error, err error,
) { ) {
return d.AccountDatas.SelectAccountData(ctx, localpart) return d.AccountDatas.SelectAccountData(ctx, localpart, serverName)
} }
// GetAccountDataByType returns account data matching a given // GetAccountDataByType returns account data matching a given
@ -248,18 +259,19 @@ func (d *Database) GetAccountData(ctx context.Context, localpart string) (
// If no account data could be found, returns nil // If no account data could be found, returns nil
// Returns an error if there was an issue with the retrieval // Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountDataByType( func (d *Database) GetAccountDataByType(
ctx context.Context, localpart, roomID, dataType string, ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
roomID, dataType string,
) (data json.RawMessage, err error) { ) (data json.RawMessage, err error) {
return d.AccountDatas.SelectAccountDataByType( return d.AccountDatas.SelectAccountDataByType(
ctx, localpart, roomID, dataType, ctx, localpart, serverName, roomID, dataType,
) )
} }
// GetNewNumericLocalpart generates and returns a new unused numeric localpart // GetNewNumericLocalpart generates and returns a new unused numeric localpart
func (d *Database) GetNewNumericLocalpart( func (d *Database) GetNewNumericLocalpart(
ctx context.Context, ctx context.Context, serverName gomatrixserverlib.ServerName,
) (int64, error) { ) (int64, error) {
return d.Accounts.SelectNewNumericLocalpart(ctx, nil) return d.Accounts.SelectNewNumericLocalpart(ctx, nil, serverName)
} }
func (d *Database) hashPassword(plaintext string) (hash string, err error) { func (d *Database) hashPassword(plaintext string) (hash string, err error) {
@ -276,10 +288,12 @@ var Err3PIDInUse = errors.New("this third-party identifier is already in use")
// If the third-party identifier is already part of an association, returns Err3PIDInUse. // If the third-party identifier is already part of an association, returns Err3PIDInUse.
// Returns an error if there was a problem talking to the database. // Returns an error if there was a problem talking to the database.
func (d *Database) SaveThreePIDAssociation( func (d *Database) SaveThreePIDAssociation(
ctx context.Context, threepid, localpart, medium string, ctx context.Context, threepid string,
localpart string, serverName gomatrixserverlib.ServerName,
medium string,
) (err error) { ) (err error) {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
user, err := d.ThreePIDs.SelectLocalpartForThreePID( user, _, err := d.ThreePIDs.SelectLocalpartForThreePID(
ctx, txn, threepid, medium, ctx, txn, threepid, medium,
) )
if err != nil { if err != nil {
@ -290,7 +304,7 @@ func (d *Database) SaveThreePIDAssociation(
return Err3PIDInUse return Err3PIDInUse
} }
return d.ThreePIDs.InsertThreePID(ctx, txn, threepid, medium, localpart) return d.ThreePIDs.InsertThreePID(ctx, txn, threepid, medium, localpart, serverName)
}) })
} }
@ -313,7 +327,7 @@ func (d *Database) RemoveThreePIDAssociation(
// Returns an error if there was a problem talking to the database. // Returns an error if there was a problem talking to the database.
func (d *Database) GetLocalpartForThreePID( func (d *Database) GetLocalpartForThreePID(
ctx context.Context, threepid string, medium string, ctx context.Context, threepid string, medium string,
) (localpart string, err error) { ) (localpart string, serverName gomatrixserverlib.ServerName, err error) {
return d.ThreePIDs.SelectLocalpartForThreePID(ctx, nil, threepid, medium) return d.ThreePIDs.SelectLocalpartForThreePID(ctx, nil, threepid, medium)
} }
@ -322,16 +336,17 @@ func (d *Database) GetLocalpartForThreePID(
// If no association is known for this user, returns an empty slice. // If no association is known for this user, returns an empty slice.
// Returns an error if there was an issue talking to the database. // Returns an error if there was an issue talking to the database.
func (d *Database) GetThreePIDsForLocalpart( func (d *Database) GetThreePIDsForLocalpart(
ctx context.Context, localpart string, ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
) (threepids []authtypes.ThreePID, err error) { ) (threepids []authtypes.ThreePID, err error) {
return d.ThreePIDs.SelectThreePIDsForLocalpart(ctx, localpart) return d.ThreePIDs.SelectThreePIDsForLocalpart(ctx, localpart, serverName)
} }
// CheckAccountAvailability checks if the username/localpart is already present // CheckAccountAvailability checks if the username/localpart is already present
// in the database. // in the database.
// If the DB returns sql.ErrNoRows the Localpart isn't taken. // If the DB returns sql.ErrNoRows the Localpart isn't taken.
func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) { func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (bool, error) {
_, err := d.Accounts.SelectAccountByLocalpart(ctx, localpart) _, err := d.Accounts.SelectAccountByLocalpart(ctx, localpart, serverName)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return true, nil return true, nil
} }
@ -341,12 +356,12 @@ func (d *Database) CheckAccountAvailability(ctx context.Context, localpart strin
// GetAccountByLocalpart returns the account associated with the given localpart. // GetAccountByLocalpart returns the account associated with the given localpart.
// This function assumes the request is authenticated or the account data is used only internally. // This function assumes the request is authenticated or the account data is used only internally.
// Returns sql.ErrNoRows if no account exists which matches the given localpart. // Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
) (*api.Account, error) { ) (*api.Account, error) {
// try to get the account with lowercase localpart (majority) // try to get the account with lowercase localpart (majority)
acc, err := d.Accounts.SelectAccountByLocalpart(ctx, strings.ToLower(localpart)) acc, err := d.Accounts.SelectAccountByLocalpart(ctx, strings.ToLower(localpart), serverName)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
acc, err = d.Accounts.SelectAccountByLocalpart(ctx, localpart) // try with localpart as passed by the request acc, err = d.Accounts.SelectAccountByLocalpart(ctx, localpart, serverName) // try with localpart as passed by the request
} }
return acc, err return acc, err
} }
@ -359,20 +374,24 @@ func (d *Database) SearchProfiles(ctx context.Context, searchString string, limi
} }
// DeactivateAccount deactivates the user's account, removing all ability for the user to login again. // DeactivateAccount deactivates the user's account, removing all ability for the user to login again.
func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) { func (d *Database) DeactivateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (err error) {
return d.Writer.Do(nil, nil, func(txn *sql.Tx) error { return d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
return d.Accounts.DeactivateAccount(ctx, localpart) return d.Accounts.DeactivateAccount(ctx, localpart, serverName)
}) })
} }
// CreateOpenIDToken persists a new token that was issued for OpenID Connect // CreateOpenIDToken persists a new token that was issued for OpenID Connect
func (d *Database) CreateOpenIDToken( func (d *Database) CreateOpenIDToken(
ctx context.Context, ctx context.Context,
token, localpart string, token, userID string,
) (int64, error) { ) (int64, error) {
localpart, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
return 0, nil
}
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.OpenIDTokenLifetimeMS expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.OpenIDTokenLifetimeMS
err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.OpenIDTokens.InsertOpenIDToken(ctx, txn, token, localpart, expiresAtMS) return d.OpenIDTokens.InsertOpenIDToken(ctx, txn, token, localpart, domain, expiresAtMS)
}) })
return expiresAtMS, err return expiresAtMS, err
} }
@ -539,16 +558,19 @@ func (d *Database) GetDeviceByAccessToken(
// GetDeviceByID returns the device matching the given ID. // GetDeviceByID returns the device matching the given ID.
// Returns sql.ErrNoRows if no matching device was found. // Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByID( func (d *Database) GetDeviceByID(
ctx context.Context, localpart, deviceID string, ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
deviceID string,
) (*api.Device, error) { ) (*api.Device, error) {
return d.Devices.SelectDeviceByID(ctx, localpart, deviceID) return d.Devices.SelectDeviceByID(ctx, localpart, serverName, deviceID)
} }
// GetDevicesByLocalpart returns the devices matching the given localpart. // GetDevicesByLocalpart returns the devices matching the given localpart.
func (d *Database) GetDevicesByLocalpart( func (d *Database) GetDevicesByLocalpart(
ctx context.Context, localpart string, ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
) ([]api.Device, error) { ) ([]api.Device, error) {
return d.Devices.SelectDevicesByLocalpart(ctx, nil, localpart, "") return d.Devices.SelectDevicesByLocalpart(ctx, nil, localpart, serverName, "")
} }
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
@ -562,18 +584,18 @@ func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]ap
// If no device ID is given one is generated. // If no device ID is given one is generated.
// Returns the device on success. // Returns the device on success.
func (d *Database) CreateDevice( func (d *Database) CreateDevice(
ctx context.Context, localpart string, deviceID *string, accessToken string, ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
displayName *string, ipAddr, userAgent string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string,
) (dev *api.Device, returnErr error) { ) (dev *api.Device, returnErr error) {
if deviceID != nil { if deviceID != nil {
returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
var err error var err error
// Revoke existing tokens for this device // Revoke existing tokens for this device
if err = d.Devices.DeleteDevice(ctx, txn, *deviceID, localpart); err != nil { if err = d.Devices.DeleteDevice(ctx, txn, *deviceID, localpart, serverName); err != nil {
return err return err
} }
dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent) dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, serverName, accessToken, displayName, ipAddr, userAgent)
return err return err
}) })
} else { } else {
@ -588,7 +610,7 @@ func (d *Database) CreateDevice(
returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
var err error var err error
dev, err = d.Devices.InsertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent) dev, err = d.Devices.InsertDevice(ctx, txn, newDeviceID, localpart, serverName, accessToken, displayName, ipAddr, userAgent)
return err return err
}) })
if returnErr == nil { if returnErr == nil {
@ -614,10 +636,12 @@ func generateDeviceID() (string, error) {
// UpdateDevice updates the given device with the display name. // UpdateDevice updates the given device with the display name.
// Returns SQL error if there are problems and nil on success. // Returns SQL error if there are problems and nil on success.
func (d *Database) UpdateDevice( func (d *Database) UpdateDevice(
ctx context.Context, localpart, deviceID string, displayName *string, ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
deviceID string, displayName *string,
) error { ) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Devices.UpdateDeviceName(ctx, txn, localpart, deviceID, displayName) return d.Devices.UpdateDeviceName(ctx, txn, localpart, serverName, deviceID, displayName)
}) })
} }
@ -626,10 +650,12 @@ func (d *Database) UpdateDevice(
// If the devices don't exist, it will not return an error // If the devices don't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error. // If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevices( func (d *Database) RemoveDevices(
ctx context.Context, localpart string, devices []string, ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
devices []string,
) error { ) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if err := d.Devices.DeleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows { if err := d.Devices.DeleteDevices(ctx, txn, localpart, serverName, devices); err != sql.ErrNoRows {
return err return err
} }
return nil return nil
@ -640,14 +666,16 @@ func (d *Database) RemoveDevices(
// database matching the given user ID localpart. // database matching the given user ID localpart.
// If something went wrong during the deletion, it will return the SQL error. // If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveAllDevices( func (d *Database) RemoveAllDevices(
ctx context.Context, localpart, exceptDeviceID string, ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
exceptDeviceID string,
) (devices []api.Device, err error) { ) (devices []api.Device, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
devices, err = d.Devices.SelectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID) devices, err = d.Devices.SelectDevicesByLocalpart(ctx, txn, localpart, serverName, exceptDeviceID)
if err != nil { if err != nil {
return err return err
} }
if err := d.Devices.DeleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows { if err := d.Devices.DeleteDevicesByLocalpart(ctx, txn, localpart, serverName, exceptDeviceID); err != sql.ErrNoRows {
return err return err
} }
return nil return nil
@ -656,9 +684,9 @@ func (d *Database) RemoveAllDevices(
} }
// UpdateDeviceLastSeen updates a last seen timestamp and the ip address. // UpdateDeviceLastSeen updates a last seen timestamp and the ip address.
func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr, userAgent string) error { func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Devices.UpdateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr, userAgent) return d.Devices.UpdateDeviceLastSeen(ctx, txn, localpart, serverName, deviceID, ipAddr, userAgent)
}) })
} }
@ -706,38 +734,38 @@ func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (
return d.LoginTokens.SelectLoginToken(ctx, token) return d.LoginTokens.SelectLoginToken(ctx, token)
} }
func (d *Database) InsertNotification(ctx context.Context, localpart, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error { func (d *Database) InsertNotification(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Notifications.Insert(ctx, txn, localpart, eventID, pos, pushrules.BoolTweakOr(tweaks, pushrules.HighlightTweak, false), n) return d.Notifications.Insert(ctx, txn, localpart, serverName, eventID, pos, pushrules.BoolTweakOr(tweaks, pushrules.HighlightTweak, false), n)
}) })
} }
func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos uint64) (affected bool, err error) { func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
affected, err = d.Notifications.DeleteUpTo(ctx, txn, localpart, roomID, pos) affected, err = d.Notifications.DeleteUpTo(ctx, txn, localpart, serverName, roomID, pos)
return err return err
}) })
return return
} }
func (d *Database) SetNotificationsRead(ctx context.Context, localpart, roomID string, pos uint64, b bool) (affected bool, err error) { func (d *Database) SetNotificationsRead(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, b bool) (affected bool, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
affected, err = d.Notifications.UpdateRead(ctx, txn, localpart, roomID, pos, b) affected, err = d.Notifications.UpdateRead(ctx, txn, localpart, serverName, roomID, pos, b)
return err return err
}) })
return return
} }
func (d *Database) GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) { func (d *Database) GetNotifications(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
return d.Notifications.Select(ctx, nil, localpart, fromID, limit, filter) return d.Notifications.Select(ctx, nil, localpart, serverName, fromID, limit, filter)
} }
func (d *Database) GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error) { func (d *Database) GetNotificationCount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, filter tables.NotificationFilter) (int64, error) {
return d.Notifications.SelectCount(ctx, nil, localpart, filter) return d.Notifications.SelectCount(ctx, nil, localpart, serverName, filter)
} }
func (d *Database) GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error) { func (d *Database) GetRoomNotificationCounts(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, _ error) {
return d.Notifications.SelectRoomCounts(ctx, nil, localpart, roomID) return d.Notifications.SelectRoomCounts(ctx, nil, localpart, serverName, roomID)
} }
func (d *Database) DeleteOldNotifications(ctx context.Context) error { func (d *Database) DeleteOldNotifications(ctx context.Context) error {
@ -747,7 +775,8 @@ func (d *Database) DeleteOldNotifications(ctx context.Context) error {
} }
func (d *Database) UpsertPusher( func (d *Database) UpsertPusher(
ctx context.Context, p api.Pusher, localpart string, ctx context.Context, p api.Pusher,
localpart string, serverName gomatrixserverlib.ServerName,
) error { ) error {
data, err := json.Marshal(p.Data) data, err := json.Marshal(p.Data)
if err != nil { if err != nil {
@ -766,25 +795,26 @@ func (d *Database) UpsertPusher(
p.ProfileTag, p.ProfileTag,
p.Language, p.Language,
string(data), string(data),
localpart) localpart,
serverName)
}) })
} }
// GetPushers returns the pushers matching the given localpart. // GetPushers returns the pushers matching the given localpart.
func (d *Database) GetPushers( func (d *Database) GetPushers(
ctx context.Context, localpart string, ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
) ([]api.Pusher, error) { ) ([]api.Pusher, error) {
return d.Pushers.SelectPushers(ctx, nil, localpart) return d.Pushers.SelectPushers(ctx, nil, localpart, serverName)
} }
// RemovePusher deletes one pusher // RemovePusher deletes one pusher
// Invoked when `append` is true and `kind` is null in // Invoked when `append` is true and `kind` is null in
// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-pushers-set // https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-pushers-set
func (d *Database) RemovePusher( func (d *Database) RemovePusher(
ctx context.Context, appid, pushkey, localpart string, ctx context.Context, appid, pushkey, localpart string, serverName gomatrixserverlib.ServerName,
) error { ) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
err := d.Pushers.DeletePusher(ctx, txn, appid, pushkey, localpart) err := d.Pushers.DeletePusher(ctx, txn, appid, pushkey, localpart, serverName)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil return nil
} }

View file

@ -21,6 +21,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
) )
const accountDataSchema = ` const accountDataSchema = `
@ -28,27 +29,28 @@ const accountDataSchema = `
CREATE TABLE IF NOT EXISTS userapi_account_datas ( CREATE TABLE IF NOT EXISTS userapi_account_datas (
-- The Matrix user ID localpart for this account -- The Matrix user ID localpart for this account
localpart TEXT NOT NULL, localpart TEXT NOT NULL,
server_name TEXT NOT NULL,
-- The room ID for this data (empty string if not specific to a room) -- The room ID for this data (empty string if not specific to a room)
room_id TEXT, room_id TEXT,
-- The account data type -- The account data type
type TEXT NOT NULL, type TEXT NOT NULL,
-- The account data content -- The account data content
content TEXT NOT NULL, content TEXT NOT NULL
PRIMARY KEY(localpart, room_id, type)
); );
CREATE UNIQUE INDEX IF NOT EXISTS userapi_account_datas_idx ON userapi_account_datas(localpart, server_name, room_id, type);
` `
const insertAccountDataSQL = ` const insertAccountDataSQL = `
INSERT INTO userapi_account_datas(localpart, room_id, type, content) VALUES($1, $2, $3, $4) INSERT INTO userapi_account_datas(localpart, server_name, room_id, type, content) VALUES($1, $2, $3, $4, $5)
ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = $4 ON CONFLICT (localpart, server_name, room_id, type) DO UPDATE SET content = $5
` `
const selectAccountDataSQL = "" + const selectAccountDataSQL = "" +
"SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1" "SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1 AND server_name = $2"
const selectAccountDataByTypeSQL = "" + const selectAccountDataByTypeSQL = "" +
"SELECT content FROM userapi_account_datas WHERE localpart = $1 AND room_id = $2 AND type = $3" "SELECT content FROM userapi_account_datas WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND type = $4"
type accountDataStatements struct { type accountDataStatements struct {
db *sql.DB db *sql.DB
@ -73,20 +75,23 @@ func NewSQLiteAccountDataTable(db *sql.DB) (tables.AccountDataTable, error) {
} }
func (s *accountDataStatements) InsertAccountData( func (s *accountDataStatements) InsertAccountData(
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage, ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
roomID, dataType string, content json.RawMessage,
) error { ) error {
_, err := sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content) _, err := sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, localpart, serverName, roomID, dataType, content)
return err return err
} }
func (s *accountDataStatements) SelectAccountData( func (s *accountDataStatements) SelectAccountData(
ctx context.Context, localpart string, ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
) ( ) (
/* global */ map[string]json.RawMessage, /* global */ map[string]json.RawMessage,
/* rooms */ map[string]map[string]json.RawMessage, /* rooms */ map[string]map[string]json.RawMessage,
error, error,
) { ) {
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart) rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart, serverName)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -117,11 +122,13 @@ func (s *accountDataStatements) SelectAccountData(
} }
func (s *accountDataStatements) SelectAccountDataByType( func (s *accountDataStatements) SelectAccountDataByType(
ctx context.Context, localpart, roomID, dataType string, ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
roomID, dataType string,
) (data json.RawMessage, err error) { ) (data json.RawMessage, err error) {
var bytes []byte var bytes []byte
stmt := s.selectAccountDataByTypeStmt stmt := s.selectAccountDataByTypeStmt
if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&bytes); err != nil { if err = stmt.QueryRowContext(ctx, localpart, serverName, roomID, dataType).Scan(&bytes); err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, nil return nil, nil
} }

View file

@ -34,7 +34,8 @@ const accountsSchema = `
-- Stores data about accounts. -- Stores data about accounts.
CREATE TABLE IF NOT EXISTS userapi_accounts ( CREATE TABLE IF NOT EXISTS userapi_accounts (
-- The Matrix user ID localpart for this account -- The Matrix user ID localpart for this account
localpart TEXT NOT NULL PRIMARY KEY, localpart TEXT NOT NULL,
server_name TEXT NOT NULL,
-- When this account was first created, as a unix timestamp (ms resolution). -- When this account was first created, as a unix timestamp (ms resolution).
created_ts BIGINT NOT NULL, created_ts BIGINT NOT NULL,
-- The password hash for this account. Can be NULL if this is a passwordless account. -- The password hash for this account. Can be NULL if this is a passwordless account.
@ -48,25 +49,27 @@ CREATE TABLE IF NOT EXISTS userapi_accounts (
-- TODO: -- TODO:
-- upgraded_ts, devices, any email reset stuff? -- upgraded_ts, devices, any email reset stuff?
); );
CREATE UNIQUE INDEX IF NOT EXISTS userapi_accounts_idx ON userapi_accounts(localpart, server_name);
` `
const insertAccountSQL = "" + const insertAccountSQL = "" +
"INSERT INTO userapi_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)" "INSERT INTO userapi_accounts(localpart, server_name, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5, $6)"
const updatePasswordSQL = "" + const updatePasswordSQL = "" +
"UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2" "UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2 AND server_name = $3"
const deactivateAccountSQL = "" + const deactivateAccountSQL = "" +
"UPDATE userapi_accounts SET is_deactivated = 1 WHERE localpart = $1" "UPDATE userapi_accounts SET is_deactivated = 1 WHERE localpart = $1 AND server_name = $2"
const selectAccountByLocalpartSQL = "" + const selectAccountByLocalpartSQL = "" +
"SELECT localpart, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1" "SELECT localpart, server_name, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1 AND server_name = $2"
const selectPasswordHashSQL = "" + const selectPasswordHashSQL = "" +
"SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND is_deactivated = 0" "SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND server_name = $2 AND is_deactivated = 0"
const selectNewNumericLocalpartSQL = "" + const selectNewNumericLocalpartSQL = "" +
"SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM userapi_accounts WHERE CAST(localpart AS INT) <> 0" "SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM userapi_accounts WHERE CAST(localpart AS INT) <> 0 AND server_name = $1"
type accountsStatements struct { type accountsStatements struct {
db *sql.DB db *sql.DB
@ -119,16 +122,17 @@ func NewSQLiteAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName)
// this account will be passwordless. Returns an error if this account already exists. Returns the account // this account will be passwordless. Returns an error if this account already exists. Returns the account
// on success. // on success.
func (s *accountsStatements) InsertAccount( func (s *accountsStatements) InsertAccount(
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType, ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName,
hash, appserviceID string, accountType api.AccountType,
) (*api.Account, error) { ) (*api.Account, error) {
createdTimeMS := time.Now().UnixNano() / 1000000 createdTimeMS := time.Now().UnixNano() / 1000000
stmt := s.insertAccountStmt stmt := s.insertAccountStmt
var err error var err error
if accountType != api.AccountTypeAppService { if accountType != api.AccountTypeAppService {
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType) _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, serverName, createdTimeMS, hash, nil, accountType)
} else { } else {
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType) _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, serverName, createdTimeMS, hash, appserviceID, accountType)
} }
if err != nil { if err != nil {
return nil, err return nil, err
@ -136,42 +140,43 @@ func (s *accountsStatements) InsertAccount(
return &api.Account{ return &api.Account{
Localpart: localpart, Localpart: localpart,
UserID: userutil.MakeUserID(localpart, s.serverName), UserID: userutil.MakeUserID(localpart, serverName),
ServerName: s.serverName, ServerName: serverName,
AppServiceID: appserviceID, AppServiceID: appserviceID,
AccountType: accountType, AccountType: accountType,
}, nil }, nil
} }
func (s *accountsStatements) UpdatePassword( func (s *accountsStatements) UpdatePassword(
ctx context.Context, localpart, passwordHash string, ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
passwordHash string,
) (err error) { ) (err error) {
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart) _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart, serverName)
return return
} }
func (s *accountsStatements) DeactivateAccount( func (s *accountsStatements) DeactivateAccount(
ctx context.Context, localpart string, ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
) (err error) { ) (err error) {
_, err = s.deactivateAccountStmt.ExecContext(ctx, localpart) _, err = s.deactivateAccountStmt.ExecContext(ctx, localpart, serverName)
return return
} }
func (s *accountsStatements) SelectPasswordHash( func (s *accountsStatements) SelectPasswordHash(
ctx context.Context, localpart string, ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
) (hash string, err error) { ) (hash string, err error) {
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash) err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart, serverName).Scan(&hash)
return return
} }
func (s *accountsStatements) SelectAccountByLocalpart( func (s *accountsStatements) SelectAccountByLocalpart(
ctx context.Context, localpart string, ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
) (*api.Account, error) { ) (*api.Account, error) {
var appserviceIDPtr sql.NullString var appserviceIDPtr sql.NullString
var acc api.Account var acc api.Account
stmt := s.selectAccountByLocalpartStmt stmt := s.selectAccountByLocalpartStmt
err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &acc.AccountType) err := stmt.QueryRowContext(ctx, localpart, serverName).Scan(&acc.Localpart, &acc.ServerName, &appserviceIDPtr, &acc.AccountType)
if err != nil { if err != nil {
if err != sql.ErrNoRows { if err != sql.ErrNoRows {
log.WithError(err).Error("Unable to retrieve user from the db") log.WithError(err).Error("Unable to retrieve user from the db")
@ -182,20 +187,18 @@ func (s *accountsStatements) SelectAccountByLocalpart(
acc.AppServiceID = appserviceIDPtr.String acc.AppServiceID = appserviceIDPtr.String
} }
acc.UserID = userutil.MakeUserID(localpart, s.serverName) acc.UserID = userutil.MakeUserID(acc.Localpart, acc.ServerName)
acc.ServerName = s.serverName
return &acc, nil return &acc, nil
} }
func (s *accountsStatements) SelectNewNumericLocalpart( func (s *accountsStatements) SelectNewNumericLocalpart(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (id int64, err error) { ) (id int64, err error) {
stmt := s.selectNewNumericLocalpartStmt stmt := s.selectNewNumericLocalpartStmt
if txn != nil { if txn != nil {
stmt = sqlutil.TxStmt(txn, stmt) stmt = sqlutil.TxStmt(txn, stmt)
} }
err = stmt.QueryRowContext(ctx).Scan(&id) err = stmt.QueryRowContext(ctx, serverName).Scan(&id)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return 1, nil return 1, nil
} }

View file

@ -11,6 +11,7 @@ func UpIsActive(ctx context.Context, tx *sql.Tx) error {
ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp; ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp;
CREATE TABLE userapi_accounts ( CREATE TABLE userapi_accounts (
localpart TEXT NOT NULL PRIMARY KEY, localpart TEXT NOT NULL PRIMARY KEY,
server_name TEXT NOT NULL,
created_ts BIGINT NOT NULL, created_ts BIGINT NOT NULL,
password_hash TEXT, password_hash TEXT,
appservice_id TEXT, appservice_id TEXT,

View file

@ -14,6 +14,7 @@ func UpLastSeenTSIP(ctx context.Context, tx *sql.Tx) error {
session_id INTEGER, session_id INTEGER,
device_id TEXT , device_id TEXT ,
localpart TEXT , localpart TEXT ,
server_name TEXT NOT NULL,
created_ts BIGINT, created_ts BIGINT,
display_name TEXT, display_name TEXT,
last_seen_ts BIGINT, last_seen_ts BIGINT,

View file

@ -12,6 +12,7 @@ func UpAddAccountType(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, `ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp; _, err := tx.ExecContext(ctx, `ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp;
CREATE TABLE userapi_accounts ( CREATE TABLE userapi_accounts (
localpart TEXT NOT NULL PRIMARY KEY, localpart TEXT NOT NULL PRIMARY KEY,
server_name TEXT NOT NULL,
created_ts BIGINT NOT NULL, created_ts BIGINT NOT NULL,
password_hash TEXT, password_hash TEXT,
appservice_id TEXT, appservice_id TEXT,

View file

@ -0,0 +1,108 @@
package deltas
import (
"context"
"database/sql"
"fmt"
"strings"
"github.com/lib/pq"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
)
var serverNamesTables = []string{
"userapi_accounts",
"userapi_account_datas",
"userapi_devices",
"userapi_notifications",
"userapi_openid_tokens",
"userapi_profiles",
"userapi_pushers",
"userapi_threepids",
}
// These tables have a PRIMARY KEY constraint which we need to drop so
// that we can recreate a new unique index that contains the server name.
var serverNamesDropPK = []string{
"userapi_accounts",
"userapi_account_datas",
"userapi_profiles",
}
// These indices are out of date so let's drop them. They will get recreated
// automatically.
var serverNamesDropIndex = []string{
"userapi_pusher_localpart_idx",
"userapi_pusher_app_id_pushkey_localpart_idx",
}
// I know what you're thinking: you're wondering "why doesn't this use $1
// and pass variadic parameters to ExecContext?" — the answer is because
// PostgreSQL doesn't expect the table name to be specified as a substituted
// argument in that way so it results in a syntax error in the query.
func UpServerNames(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error {
for _, table := range serverNamesTables {
q := fmt.Sprintf(
"SELECT COUNT(name) FROM sqlite_schema WHERE type='table' AND name=%s;",
pq.QuoteIdentifier(table),
)
var c int
if err := tx.QueryRowContext(ctx, q).Scan(&c); err != nil || c == 0 {
continue
}
q = fmt.Sprintf(
"SELECT COUNT(*) FROM pragma_table_info(%s) WHERE name='server_name'",
pq.QuoteIdentifier(table),
)
if err := tx.QueryRowContext(ctx, q).Scan(&c); err != nil || c == 1 {
logrus.Infof("Table %s already has column, skipping", table)
continue
}
if c == 0 {
q = fmt.Sprintf(
"ALTER TABLE %s ADD COLUMN server_name TEXT NOT NULL DEFAULT '';",
pq.QuoteIdentifier(table),
)
if _, err := tx.ExecContext(ctx, q); err != nil {
return fmt.Errorf("add server name to %q error: %w", table, err)
}
}
}
for _, table := range serverNamesDropPK {
q := fmt.Sprintf(
"SELECT COUNT(name), sql FROM sqlite_schema WHERE type='table' AND name=%s;",
pq.QuoteIdentifier(table),
)
var c int
var sql string
if err := tx.QueryRowContext(ctx, q).Scan(&c, &sql); err != nil || c == 0 {
continue
}
q = fmt.Sprintf(`
%s; -- create temporary table
INSERT INTO %s SELECT * FROM %s; -- copy data
DROP TABLE %s; -- drop original table
ALTER TABLE %s RENAME TO %s; -- rename new table
`,
strings.Replace(sql, table, table+"_tmp", 1), // create temporary table
table+"_tmp", table, // copy data
table, // drop original table
table+"_tmp", table, // rename new table
)
if _, err := tx.ExecContext(ctx, q); err != nil {
return fmt.Errorf("drop PK from %q error: %w", table, err)
}
}
for _, index := range serverNamesDropIndex {
q := fmt.Sprintf(
"DROP INDEX IF EXISTS %s;",
pq.QuoteIdentifier(index),
)
if _, err := tx.ExecContext(ctx, q); err != nil {
return fmt.Errorf("drop index %q error: %w", index, err)
}
}
return nil
}

View file

@ -0,0 +1,28 @@
package deltas
import (
"context"
"database/sql"
"fmt"
"github.com/lib/pq"
"github.com/matrix-org/gomatrixserverlib"
)
// I know what you're thinking: you're wondering "why doesn't this use $1
// and pass variadic parameters to ExecContext?" — the answer is because
// PostgreSQL doesn't expect the table name to be specified as a substituted
// argument in that way so it results in a syntax error in the query.
func UpServerNamesPopulate(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error {
for _, table := range serverNamesTables {
q := fmt.Sprintf(
"UPDATE %s SET server_name = %s WHERE server_name = '';",
pq.QuoteIdentifier(table), pq.QuoteLiteral(string(serverName)),
)
if _, err := tx.ExecContext(ctx, q); err != nil {
return fmt.Errorf("write server names to %q error: %w", table, err)
}
}
return nil
}

View file

@ -40,49 +40,50 @@ CREATE TABLE IF NOT EXISTS userapi_devices (
session_id INTEGER, session_id INTEGER,
device_id TEXT , device_id TEXT ,
localpart TEXT , localpart TEXT ,
server_name TEXT NOT NULL,
created_ts BIGINT, created_ts BIGINT,
display_name TEXT, display_name TEXT,
last_seen_ts BIGINT, last_seen_ts BIGINT,
ip TEXT, ip TEXT,
user_agent TEXT, user_agent TEXT,
UNIQUE (localpart, device_id) UNIQUE (localpart, server_name, device_id)
); );
` `
const insertDeviceSQL = "" + const insertDeviceSQL = "" +
"INSERT INTO userapi_devices (device_id, localpart, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" + "INSERT INTO userapi_devices (device_id, localpart, server_name, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" +
" VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)"
const selectDevicesCountSQL = "" + const selectDevicesCountSQL = "" +
"SELECT COUNT(access_token) FROM userapi_devices" "SELECT COUNT(access_token) FROM userapi_devices"
const selectDeviceByTokenSQL = "" + const selectDeviceByTokenSQL = "" +
"SELECT session_id, device_id, localpart FROM userapi_devices WHERE access_token = $1" "SELECT session_id, device_id, localpart, server_name FROM userapi_devices WHERE access_token = $1"
const selectDeviceByIDSQL = "" + const selectDeviceByIDSQL = "" +
"SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 and device_id = $2" "SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = $3"
const selectDevicesByLocalpartSQL = "" + const selectDevicesByLocalpartSQL = "" +
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC" "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3 ORDER BY last_seen_ts DESC"
const updateDeviceNameSQL = "" + const updateDeviceNameSQL = "" +
"UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" "UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND server_name = $3 AND device_id = $4"
const deleteDeviceSQL = "" + const deleteDeviceSQL = "" +
"DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2" "DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2 AND server_name = $3"
const deleteDevicesByLocalpartSQL = "" + const deleteDevicesByLocalpartSQL = "" +
"DELETE FROM userapi_devices WHERE localpart = $1 AND device_id != $2" "DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3"
const deleteDevicesSQL = "" + const deleteDevicesSQL = "" +
"DELETE FROM userapi_devices WHERE localpart = $1 AND device_id IN ($2)" "DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id IN ($3)"
const selectDevicesByIDSQL = "" + const selectDevicesByIDSQL = "" +
"SELECT device_id, localpart, display_name, last_seen_ts FROM userapi_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC" "SELECT device_id, localpart, server_name, display_name, last_seen_ts FROM userapi_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC"
const updateDeviceLastSeen = "" + const updateDeviceLastSeen = "" +
"UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5" "UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND server_name = $5 AND device_id = $6"
type devicesStatements struct { type devicesStatements struct {
db *sql.DB db *sql.DB
@ -135,8 +136,9 @@ func NewSQLiteDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName)
// Returns an error if the user already has a device with the given device ID. // Returns an error if the user already has a device with the given device ID.
// Returns the device on success. // Returns the device on success.
func (s *devicesStatements) InsertDevice( func (s *devicesStatements) InsertDevice(
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, ctx context.Context, txn *sql.Tx, id string,
displayName *string, ipAddr, userAgent string, localpart string, serverName gomatrixserverlib.ServerName,
accessToken string, displayName *string, ipAddr, userAgent string,
) (*api.Device, error) { ) (*api.Device, error) {
createdTimeMS := time.Now().UnixNano() / 1000000 createdTimeMS := time.Now().UnixNano() / 1000000
var sessionID int64 var sessionID int64
@ -146,12 +148,12 @@ func (s *devicesStatements) InsertDevice(
return nil, err return nil, err
} }
sessionID++ sessionID++
if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil { if _, err := insertStmt.ExecContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil {
return nil, err return nil, err
} }
return &api.Device{ return &api.Device{
ID: id, ID: id,
UserID: userutil.MakeUserID(localpart, s.serverName), UserID: userutil.MakeUserID(localpart, serverName),
AccessToken: accessToken, AccessToken: accessToken,
SessionID: sessionID, SessionID: sessionID,
LastSeenTS: createdTimeMS, LastSeenTS: createdTimeMS,
@ -161,44 +163,52 @@ func (s *devicesStatements) InsertDevice(
} }
func (s *devicesStatements) DeleteDevice( func (s *devicesStatements) DeleteDevice(
ctx context.Context, txn *sql.Tx, id, localpart string, ctx context.Context, txn *sql.Tx, id string,
localpart string, serverName gomatrixserverlib.ServerName,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt) stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
_, err := stmt.ExecContext(ctx, id, localpart) _, err := stmt.ExecContext(ctx, id, localpart, serverName)
return err return err
} }
func (s *devicesStatements) DeleteDevices( func (s *devicesStatements) DeleteDevices(
ctx context.Context, txn *sql.Tx, localpart string, devices []string, ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
devices []string,
) error { ) error {
orig := strings.Replace(deleteDevicesSQL, "($2)", sqlutil.QueryVariadicOffset(len(devices), 1), 1) orig := strings.Replace(deleteDevicesSQL, "($3)", sqlutil.QueryVariadicOffset(len(devices), 2), 1)
prep, err := s.db.Prepare(orig) prep, err := s.db.Prepare(orig)
if err != nil { if err != nil {
return err return err
} }
stmt := sqlutil.TxStmt(txn, prep) stmt := sqlutil.TxStmt(txn, prep)
params := make([]interface{}, len(devices)+1) params := make([]interface{}, len(devices)+2)
params[0] = localpart params[0] = localpart
params[1] = serverName
for i, v := range devices { for i, v := range devices {
params[i+1] = v params[i+2] = v
} }
_, err = stmt.ExecContext(ctx, params...) _, err = stmt.ExecContext(ctx, params...)
return err return err
} }
func (s *devicesStatements) DeleteDevicesByLocalpart( func (s *devicesStatements) DeleteDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
exceptDeviceID string,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
_, err := stmt.ExecContext(ctx, localpart, exceptDeviceID) _, err := stmt.ExecContext(ctx, localpart, serverName, exceptDeviceID)
return err return err
} }
func (s *devicesStatements) UpdateDeviceName( func (s *devicesStatements) UpdateDeviceName(
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string, ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
deviceID string, displayName *string,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt) stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
_, err := stmt.ExecContext(ctx, displayName, localpart, deviceID) _, err := stmt.ExecContext(ctx, displayName, localpart, serverName, deviceID)
return err return err
} }
@ -207,10 +217,11 @@ func (s *devicesStatements) SelectDeviceByToken(
) (*api.Device, error) { ) (*api.Device, error) {
var dev api.Device var dev api.Device
var localpart string var localpart string
var serverName gomatrixserverlib.ServerName
stmt := s.selectDeviceByTokenStmt stmt := s.selectDeviceByTokenStmt
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart) err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart, &serverName)
if err == nil { if err == nil {
dev.UserID = userutil.MakeUserID(localpart, s.serverName) dev.UserID = userutil.MakeUserID(localpart, serverName)
dev.AccessToken = accessToken dev.AccessToken = accessToken
} }
return &dev, err return &dev, err
@ -219,16 +230,18 @@ func (s *devicesStatements) SelectDeviceByToken(
// selectDeviceByID retrieves a device from the database with the given user // selectDeviceByID retrieves a device from the database with the given user
// localpart and deviceID // localpart and deviceID
func (s *devicesStatements) SelectDeviceByID( func (s *devicesStatements) SelectDeviceByID(
ctx context.Context, localpart, deviceID string, ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
deviceID string,
) (*api.Device, error) { ) (*api.Device, error) {
var dev api.Device var dev api.Device
var displayName, ip sql.NullString var displayName, ip sql.NullString
stmt := s.selectDeviceByIDStmt stmt := s.selectDeviceByIDStmt
var lastseenTS sql.NullInt64 var lastseenTS sql.NullInt64
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName, &lastseenTS, &ip) err := stmt.QueryRowContext(ctx, localpart, serverName, deviceID).Scan(&displayName, &lastseenTS, &ip)
if err == nil { if err == nil {
dev.ID = deviceID dev.ID = deviceID
dev.UserID = userutil.MakeUserID(localpart, s.serverName) dev.UserID = userutil.MakeUserID(localpart, serverName)
if displayName.Valid { if displayName.Valid {
dev.DisplayName = displayName.String dev.DisplayName = displayName.String
} }
@ -243,10 +256,12 @@ func (s *devicesStatements) SelectDeviceByID(
} }
func (s *devicesStatements) SelectDevicesByLocalpart( func (s *devicesStatements) SelectDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
exceptDeviceID string,
) ([]api.Device, error) { ) ([]api.Device, error) {
devices := []api.Device{} devices := []api.Device{}
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID) rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, serverName, exceptDeviceID)
if err != nil { if err != nil {
return devices, err return devices, err
@ -276,7 +291,7 @@ func (s *devicesStatements) SelectDevicesByLocalpart(
dev.UserAgent = useragent.String dev.UserAgent = useragent.String
} }
dev.UserID = userutil.MakeUserID(localpart, s.serverName) dev.UserID = userutil.MakeUserID(localpart, serverName)
devices = append(devices, dev) devices = append(devices, dev)
} }
@ -298,10 +313,11 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s
var devices []api.Device var devices []api.Device
var dev api.Device var dev api.Device
var localpart string var localpart string
var serverName gomatrixserverlib.ServerName
var displayName sql.NullString var displayName sql.NullString
var lastseents sql.NullInt64 var lastseents sql.NullInt64
for rows.Next() { for rows.Next() {
if err := rows.Scan(&dev.ID, &localpart, &displayName, &lastseents); err != nil { if err := rows.Scan(&dev.ID, &localpart, &serverName, &displayName, &lastseents); err != nil {
return nil, err return nil, err
} }
if displayName.Valid { if displayName.Valid {
@ -310,15 +326,15 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s
if lastseents.Valid { if lastseents.Valid {
dev.LastSeenTS = lastseents.Int64 dev.LastSeenTS = lastseents.Int64
} }
dev.UserID = userutil.MakeUserID(localpart, s.serverName) dev.UserID = userutil.MakeUserID(localpart, serverName)
devices = append(devices, dev) devices = append(devices, dev)
} }
return devices, rows.Err() return devices, rows.Err()
} }
func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr, userAgent string) error { func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error {
lastSeenTs := time.Now().UnixNano() / 1000000 lastSeenTs := time.Now().UnixNano() / 1000000
stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt) stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt)
_, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, deviceID) _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, serverName, deviceID)
return err return err
} }

View file

@ -43,6 +43,7 @@ const notificationSchema = `
CREATE TABLE IF NOT EXISTS userapi_notifications ( CREATE TABLE IF NOT EXISTS userapi_notifications (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
localpart TEXT NOT NULL, localpart TEXT NOT NULL,
server_name TEXT NOT NULL,
room_id TEXT NOT NULL, room_id TEXT NOT NULL,
event_id TEXT NOT NULL, event_id TEXT NOT NULL,
stream_pos BIGINT NOT NULL, stream_pos BIGINT NOT NULL,
@ -52,33 +53,33 @@ CREATE TABLE IF NOT EXISTS userapi_notifications (
read BOOLEAN NOT NULL DEFAULT FALSE read BOOLEAN NOT NULL DEFAULT FALSE
); );
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, room_id, event_id); CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, server_name, room_id, event_id);
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, room_id, id); CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, server_name, room_id, id);
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, id); CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, server_name, id);
` `
const insertNotificationSQL = "" + const insertNotificationSQL = "" +
"INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7)" "INSERT INTO userapi_notifications (localpart, server_name, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
const deleteNotificationsUpToSQL = "" + const deleteNotificationsUpToSQL = "" +
"DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND stream_pos <= $3" "DELETE FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND stream_pos <= $4"
const updateNotificationReadSQL = "" + const updateNotificationReadSQL = "" +
"UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND stream_pos <= $4 AND read <> $1" "UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND server_name = $3 AND room_id = $4 AND stream_pos <= $5 AND read <> $1"
const selectNotificationSQL = "" + const selectNotificationSQL = "" +
"SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" + "SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND id > $3 AND (" +
"(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" + "(($4 & 1) <> 0 AND highlight) OR (($4 & 2) <> 0 AND NOT highlight)" +
") AND NOT read ORDER BY localpart, id LIMIT $4" ") AND NOT read ORDER BY localpart, id LIMIT $5"
const selectNotificationCountSQL = "" + const selectNotificationCountSQL = "" +
"SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND (" + "SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND (" +
"(($2 & 1) <> 0 AND highlight) OR (($2 & 2) <> 0 AND NOT highlight)" + "(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" +
") AND NOT read" ") AND NOT read"
const selectRoomNotificationCountsSQL = "" + const selectRoomNotificationCountsSQL = "" +
"SELECT COUNT(*), COUNT(*) FILTER (WHERE highlight) FROM userapi_notifications " + "SELECT COUNT(*), COUNT(*) FILTER (WHERE highlight) FROM userapi_notifications " +
"WHERE localpart = $1 AND room_id = $2 AND NOT read" "WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND NOT read"
const cleanNotificationsSQL = "" + const cleanNotificationsSQL = "" +
"DELETE FROM userapi_notifications WHERE" + "DELETE FROM userapi_notifications WHERE" +
@ -111,7 +112,7 @@ func (s *notificationsStatements) Clean(ctx context.Context, txn *sql.Tx) error
} }
// Insert inserts a notification into the database. // Insert inserts a notification into the database.
func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error { func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, highlight bool, n *api.Notification) error {
roomID, tsMS := n.RoomID, n.TS roomID, tsMS := n.RoomID, n.TS
nn := *n nn := *n
// Clears out fields that have their own columns to (1) shrink the // Clears out fields that have their own columns to (1) shrink the
@ -122,13 +123,13 @@ func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, local
if err != nil { if err != nil {
return err return err
} }
_, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, pos, tsMS, highlight, string(bs)) _, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, serverName, roomID, eventID, pos, tsMS, highlight, string(bs))
return err return err
} }
// DeleteUpTo deletes all previous notifications, up to and including the event. // DeleteUpTo deletes all previous notifications, up to and including the event.
func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error) { func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, _ error) {
res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos) res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, serverName, roomID, pos)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -141,8 +142,8 @@ func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, l
} }
// UpdateRead updates the "read" value for an event. // UpdateRead updates the "read" value for an event.
func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error) { func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, v bool) (affected bool, _ error) {
res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos) res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, serverName, roomID, pos)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -154,8 +155,8 @@ func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, l
return nrows > 0, nil return nrows > 0, nil
} }
func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) { func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, fromID, uint32(filter), limit) rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, serverName, fromID, uint32(filter), limit)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
@ -197,12 +198,12 @@ func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, local
return notifs, maxID, rows.Err() return notifs, maxID, rows.Err()
} }
func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (count int64, err error) { func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, filter tables.NotificationFilter) (count int64, err error) {
err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, uint32(filter)).Scan(&count) err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, serverName, uint32(filter)).Scan(&count)
return return
} }
func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, err error) { func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, err error) {
err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, roomID).Scan(&total, &highlight) err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, serverName, roomID).Scan(&total, &highlight)
return return
} }

View file

@ -3,6 +3,7 @@ package sqlite3
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
@ -18,16 +19,17 @@ CREATE TABLE IF NOT EXISTS userapi_openid_tokens (
token TEXT NOT NULL PRIMARY KEY, token TEXT NOT NULL PRIMARY KEY,
-- The Matrix user ID for this account -- The Matrix user ID for this account
localpart TEXT NOT NULL, localpart TEXT NOT NULL,
server_name TEXT NOT NULL,
-- When the token expires, as a unix timestamp (ms resolution). -- When the token expires, as a unix timestamp (ms resolution).
token_expires_at_ms BIGINT NOT NULL token_expires_at_ms BIGINT NOT NULL
); );
` `
const insertOpenIDTokenSQL = "" + const insertOpenIDTokenSQL = "" +
"INSERT INTO userapi_openid_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)" "INSERT INTO userapi_openid_tokens(token, localpart, server_name, token_expires_at_ms) VALUES ($1, $2, $3, $4)"
const selectOpenIDTokenSQL = "" + const selectOpenIDTokenSQL = "" +
"SELECT localpart, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1" "SELECT localpart, server_name, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1"
type openIDTokenStatements struct { type openIDTokenStatements struct {
db *sql.DB db *sql.DB
@ -56,11 +58,11 @@ func NewSQLiteOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (
func (s *openIDTokenStatements) InsertOpenIDToken( func (s *openIDTokenStatements) InsertOpenIDToken(
ctx context.Context, ctx context.Context,
txn *sql.Tx, txn *sql.Tx,
token, localpart string, token, localpart string, serverName gomatrixserverlib.ServerName,
expiresAtMS int64, expiresAtMS int64,
) (err error) { ) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertTokenStmt) stmt := sqlutil.TxStmt(txn, s.insertTokenStmt)
_, err = stmt.ExecContext(ctx, token, localpart, expiresAtMS) _, err = stmt.ExecContext(ctx, token, localpart, serverName, expiresAtMS)
return return
} }
@ -71,10 +73,13 @@ func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes(
token string, token string,
) (*api.OpenIDTokenAttributes, error) { ) (*api.OpenIDTokenAttributes, error) {
var openIDTokenAttrs api.OpenIDTokenAttributes var openIDTokenAttrs api.OpenIDTokenAttributes
var localpart string
var serverName gomatrixserverlib.ServerName
err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan( err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan(
&openIDTokenAttrs.UserID, &localpart, &serverName,
&openIDTokenAttrs.ExpiresAtMS, &openIDTokenAttrs.ExpiresAtMS,
) )
openIDTokenAttrs.UserID = fmt.Sprintf("@%s:%s", localpart, serverName)
if err != nil { if err != nil {
if err != sql.ErrNoRows { if err != sql.ErrNoRows {
log.WithError(err).Error("Unable to retrieve token from the db") log.WithError(err).Error("Unable to retrieve token from the db")

View file

@ -23,36 +23,40 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
) )
const profilesSchema = ` const profilesSchema = `
-- Stores data about accounts profiles. -- Stores data about accounts profiles.
CREATE TABLE IF NOT EXISTS userapi_profiles ( CREATE TABLE IF NOT EXISTS userapi_profiles (
-- The Matrix user ID localpart for this account -- The Matrix user ID localpart for this account
localpart TEXT NOT NULL PRIMARY KEY, localpart TEXT NOT NULL,
server_name TEXT NOT NULL,
-- The display name for this account -- The display name for this account
display_name TEXT, display_name TEXT,
-- The URL of the avatar for this account -- The URL of the avatar for this account
avatar_url TEXT avatar_url TEXT
); );
CREATE UNIQUE INDEX IF NOT EXISTS userapi_profiles_idx ON userapi_profiles(localpart, server_name);
` `
const insertProfileSQL = "" + const insertProfileSQL = "" +
"INSERT INTO userapi_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)" "INSERT INTO userapi_profiles(localpart, server_name, display_name, avatar_url) VALUES ($1, $2, $3, $4)"
const selectProfileByLocalpartSQL = "" + const selectProfileByLocalpartSQL = "" +
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1" "SELECT localpart, server_name, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1 AND server_name = $2"
const setAvatarURLSQL = "" + const setAvatarURLSQL = "" +
"UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2" + "UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2 AND server_name = $3" +
" RETURNING display_name" " RETURNING display_name"
const setDisplayNameSQL = "" + const setDisplayNameSQL = "" +
"UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2" + "UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2 AND server_name = $3" +
" RETURNING avatar_url" " RETURNING avatar_url"
const selectProfilesBySearchSQL = "" + const selectProfilesBySearchSQL = "" +
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" "SELECT localpart, server_name, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
type profilesStatements struct { type profilesStatements struct {
db *sql.DB db *sql.DB
@ -83,18 +87,20 @@ func NewSQLiteProfilesTable(db *sql.DB, serverNoticesLocalpart string) (tables.P
} }
func (s *profilesStatements) InsertProfile( func (s *profilesStatements) InsertProfile(
ctx context.Context, txn *sql.Tx, localpart string, ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
) error { ) error {
_, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "") _, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, serverName, "", "")
return err return err
} }
func (s *profilesStatements) SelectProfileByLocalpart( func (s *profilesStatements) SelectProfileByLocalpart(
ctx context.Context, localpart string, ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
) (*authtypes.Profile, error) { ) (*authtypes.Profile, error) {
var profile authtypes.Profile var profile authtypes.Profile
err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart).Scan( err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart, serverName).Scan(
&profile.Localpart, &profile.DisplayName, &profile.AvatarURL, &profile.Localpart, &profile.ServerName, &profile.DisplayName, &profile.AvatarURL,
) )
if err != nil { if err != nil {
return nil, err return nil, err
@ -103,13 +109,16 @@ func (s *profilesStatements) SelectProfileByLocalpart(
} }
func (s *profilesStatements) SetAvatarURL( func (s *profilesStatements) SetAvatarURL(
ctx context.Context, txn *sql.Tx, localpart string, avatarURL string, ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
avatarURL string,
) (*authtypes.Profile, bool, error) { ) (*authtypes.Profile, bool, error) {
profile := &authtypes.Profile{ profile := &authtypes.Profile{
Localpart: localpart, Localpart: localpart,
ServerName: string(serverName),
AvatarURL: avatarURL, AvatarURL: avatarURL,
} }
old, err := s.SelectProfileByLocalpart(ctx, localpart) old, err := s.SelectProfileByLocalpart(ctx, localpart, serverName)
if err != nil { if err != nil {
return old, false, err return old, false, err
} }
@ -117,18 +126,21 @@ func (s *profilesStatements) SetAvatarURL(
return old, false, nil return old, false, nil
} }
stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt) stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
err = stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName) err = stmt.QueryRowContext(ctx, avatarURL, localpart, serverName).Scan(&profile.DisplayName)
return profile, true, err return profile, true, err
} }
func (s *profilesStatements) SetDisplayName( func (s *profilesStatements) SetDisplayName(
ctx context.Context, txn *sql.Tx, localpart string, displayName string, ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
displayName string,
) (*authtypes.Profile, bool, error) { ) (*authtypes.Profile, bool, error) {
profile := &authtypes.Profile{ profile := &authtypes.Profile{
Localpart: localpart, Localpart: localpart,
ServerName: string(serverName),
DisplayName: displayName, DisplayName: displayName,
} }
old, err := s.SelectProfileByLocalpart(ctx, localpart) old, err := s.SelectProfileByLocalpart(ctx, localpart, serverName)
if err != nil { if err != nil {
return old, false, err return old, false, err
} }
@ -136,7 +148,7 @@ func (s *profilesStatements) SetDisplayName(
return old, false, nil return old, false, nil
} }
stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt) stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
err = stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL) err = stmt.QueryRowContext(ctx, displayName, localpart, serverName).Scan(&profile.AvatarURL)
return profile, true, err return profile, true, err
} }
@ -154,7 +166,7 @@ func (s *profilesStatements) SelectProfilesBySearch(
defer internal.CloseAndLogIfError(ctx, rows, "selectProfilesBySearch: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "selectProfilesBySearch: rows.close() failed")
for rows.Next() { for rows.Next() {
var profile authtypes.Profile var profile authtypes.Profile
if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL); err != nil { if err := rows.Scan(&profile.Localpart, &profile.ServerName, &profile.DisplayName, &profile.AvatarURL); err != nil {
return nil, err return nil, err
} }
if profile.Localpart != s.serverNoticesLocalpart { if profile.Localpart != s.serverNoticesLocalpart {

View file

@ -25,6 +25,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
) )
// See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers // See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers
@ -33,6 +34,7 @@ CREATE TABLE IF NOT EXISTS userapi_pushers (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
-- The Matrix user ID localpart for this pusher -- The Matrix user ID localpart for this pusher
localpart TEXT NOT NULL, localpart TEXT NOT NULL,
server_name TEXT NOT NULL,
session_id BIGINT DEFAULT NULL, session_id BIGINT DEFAULT NULL,
profile_tag TEXT, profile_tag TEXT,
kind TEXT NOT NULL, kind TEXT NOT NULL,
@ -49,22 +51,22 @@ CREATE TABLE IF NOT EXISTS userapi_pushers (
CREATE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey); CREATE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey);
-- For faster retrieving by localpart. -- For faster retrieving by localpart.
CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart); CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart, server_name);
-- Pushkey must be unique for a given user and app. -- Pushkey must be unique for a given user and app.
CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart); CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart, server_name);
` `
const insertPusherSQL = "" + const insertPusherSQL = "" +
"INSERT INTO userapi_pushers (localpart, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" + "INSERT INTO userapi_pushers (localpart, server_name, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" +
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" + "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)" +
"ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11" "ON CONFLICT (app_id, pushkey, localpart, server_name) DO UPDATE SET session_id = $3, pushkey_ts_ms = $5, kind = $6, app_display_name = $8, device_display_name = $9, profile_tag = $10, lang = $11, data = $12"
const selectPushersSQL = "" + const selectPushersSQL = "" +
"SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1" "SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1 AND server_name = $2"
const deletePusherSQL = "" + const deletePusherSQL = "" +
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3" "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3 AND server_name = $4"
const deletePushersByAppIdAndPushKeySQL = "" + const deletePushersByAppIdAndPushKeySQL = "" +
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2" "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2"
@ -95,18 +97,19 @@ type pushersStatements struct {
// Returns nil error success. // Returns nil error success.
func (s *pushersStatements) InsertPusher( func (s *pushersStatements) InsertPusher(
ctx context.Context, txn *sql.Tx, session_id int64, ctx context.Context, txn *sql.Tx, session_id int64,
pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string, pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data,
localpart string, serverName gomatrixserverlib.ServerName,
) error { ) error {
_, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data) _, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, serverName, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
logrus.Debugf("Created pusher %d", session_id)
return err return err
} }
func (s *pushersStatements) SelectPushers( func (s *pushersStatements) SelectPushers(
ctx context.Context, txn *sql.Tx, localpart string, ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
) ([]api.Pusher, error) { ) ([]api.Pusher, error) {
pushers := []api.Pusher{} pushers := []api.Pusher{}
rows, err := s.selectPushersStmt.QueryContext(ctx, localpart) rows, err := s.selectPushersStmt.QueryContext(ctx, localpart, serverName)
if err != nil { if err != nil {
return pushers, err return pushers, err
@ -143,9 +146,10 @@ func (s *pushersStatements) SelectPushers(
// deletePusher removes a single pusher by pushkey and user localpart. // deletePusher removes a single pusher by pushkey and user localpart.
func (s *pushersStatements) DeletePusher( func (s *pushersStatements) DeletePusher(
ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string, ctx context.Context, txn *sql.Tx, appid, pushkey,
localpart string, serverName gomatrixserverlib.ServerName,
) error { ) error {
_, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart) _, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart, serverName)
return err return err
} }

View file

@ -15,6 +15,8 @@
package sqlite3 package sqlite3
import ( import (
"context"
"database/sql"
"fmt" "fmt"
"time" "time"
@ -41,18 +43,24 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
Up: deltas.UpRenameTables, Up: deltas.UpRenameTables,
Down: deltas.DownRenameTables, Down: deltas.DownRenameTables,
}) })
m.AddMigrations(sqlutil.Migration{
Version: "userapi: server names",
Up: func(ctx context.Context, txn *sql.Tx) error {
return deltas.UpServerNames(ctx, txn, serverName)
},
})
if err = m.Up(base.Context()); err != nil { if err = m.Up(base.Context()); err != nil {
return nil, err return nil, err
} }
accountDataTable, err := NewSQLiteAccountDataTable(db)
if err != nil {
return nil, fmt.Errorf("NewSQLiteAccountDataTable: %w", err)
}
accountsTable, err := NewSQLiteAccountsTable(db, serverName) accountsTable, err := NewSQLiteAccountsTable(db, serverName)
if err != nil { if err != nil {
return nil, fmt.Errorf("NewSQLiteAccountsTable: %w", err) return nil, fmt.Errorf("NewSQLiteAccountsTable: %w", err)
} }
accountDataTable, err := NewSQLiteAccountDataTable(db)
if err != nil {
return nil, fmt.Errorf("NewSQLiteAccountDataTable: %w", err)
}
devicesTable, err := NewSQLiteDevicesTable(db, serverName) devicesTable, err := NewSQLiteDevicesTable(db, serverName)
if err != nil { if err != nil {
return nil, fmt.Errorf("NewSQLiteDevicesTable: %w", err) return nil, fmt.Errorf("NewSQLiteDevicesTable: %w", err)
@ -93,6 +101,18 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
if err != nil { if err != nil {
return nil, fmt.Errorf("NewSQLiteStatsTable: %w", err) return nil, fmt.Errorf("NewSQLiteStatsTable: %w", err)
} }
m = sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "userapi: server names populate",
Up: func(ctx context.Context, txn *sql.Tx) error {
return deltas.UpServerNamesPopulate(ctx, txn, serverName)
},
})
if err = m.Up(base.Context()); err != nil {
return nil, err
}
return &shared.Database{ return &shared.Database{
AccountDatas: accountDataTable, AccountDatas: accountDataTable,
Accounts: accountsTable, Accounts: accountsTable,

View file

@ -21,6 +21,7 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
) )
@ -34,21 +35,22 @@ CREATE TABLE IF NOT EXISTS userapi_threepids (
medium TEXT NOT NULL DEFAULT 'email', medium TEXT NOT NULL DEFAULT 'email',
-- The localpart of the Matrix user ID associated to this 3PID -- The localpart of the Matrix user ID associated to this 3PID
localpart TEXT NOT NULL, localpart TEXT NOT NULL,
server_name TEXT NOT NULL,
PRIMARY KEY(threepid, medium) PRIMARY KEY(threepid, medium)
); );
CREATE INDEX IF NOT EXISTS account_threepid_localpart ON userapi_threepids(localpart); CREATE INDEX IF NOT EXISTS account_threepid_localpart ON userapi_threepids(localpart, server_name);
` `
const selectLocalpartForThreePIDSQL = "" + const selectLocalpartForThreePIDSQL = "" +
"SELECT localpart FROM userapi_threepids WHERE threepid = $1 AND medium = $2" "SELECT localpart, server_name FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
const selectThreePIDsForLocalpartSQL = "" + const selectThreePIDsForLocalpartSQL = "" +
"SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1" "SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1 AND server_name = $2"
const insertThreePIDSQL = "" + const insertThreePIDSQL = "" +
"INSERT INTO userapi_threepids (threepid, medium, localpart) VALUES ($1, $2, $3)" "INSERT INTO userapi_threepids (threepid, medium, localpart, server_name) VALUES ($1, $2, $3, $4)"
const deleteThreePIDSQL = "" + const deleteThreePIDSQL = "" +
"DELETE FROM userapi_threepids WHERE threepid = $1 AND medium = $2" "DELETE FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
@ -79,19 +81,20 @@ func NewSQLiteThreePIDTable(db *sql.DB) (tables.ThreePIDTable, error) {
func (s *threepidStatements) SelectLocalpartForThreePID( func (s *threepidStatements) SelectLocalpartForThreePID(
ctx context.Context, txn *sql.Tx, threepid string, medium string, ctx context.Context, txn *sql.Tx, threepid string, medium string,
) (localpart string, err error) { ) (localpart string, serverName gomatrixserverlib.ServerName, err error) {
stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt) stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt)
err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart) err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart, &serverName)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return "", nil return "", "", nil
} }
return return
} }
func (s *threepidStatements) SelectThreePIDsForLocalpart( func (s *threepidStatements) SelectThreePIDsForLocalpart(
ctx context.Context, localpart string, ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
) (threepids []authtypes.ThreePID, err error) { ) (threepids []authtypes.ThreePID, err error) {
rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart) rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart, serverName)
if err != nil { if err != nil {
return return
} }
@ -113,10 +116,11 @@ func (s *threepidStatements) SelectThreePIDsForLocalpart(
} }
func (s *threepidStatements) InsertThreePID( func (s *threepidStatements) InsertThreePID(
ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, ctx context.Context, txn *sql.Tx, threepid, medium,
localpart string, serverName gomatrixserverlib.ServerName,
) (err error) { ) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt) stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt)
_, err = stmt.ExecContext(ctx, threepid, medium, localpart) _, err = stmt.ExecContext(ctx, threepid, medium, localpart, serverName)
return err return err
} }

View file

@ -50,25 +50,25 @@ func Test_AccountData(t *testing.T) {
db, close := mustCreateDatabase(t, dbType) db, close := mustCreateDatabase(t, dbType)
defer close() defer close()
alice := test.NewUser(t) alice := test.NewUser(t)
localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) localpart, domain, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err) assert.NoError(t, err)
room := test.NewRoom(t, alice) room := test.NewRoom(t, alice)
events := room.Events() events := room.Events()
contentRoom := json.RawMessage(fmt.Sprintf(`{"event_id":"%s"}`, events[len(events)-1].EventID())) contentRoom := json.RawMessage(fmt.Sprintf(`{"event_id":"%s"}`, events[len(events)-1].EventID()))
err = db.SaveAccountData(ctx, localpart, room.ID, "m.fully_read", contentRoom) err = db.SaveAccountData(ctx, localpart, domain, room.ID, "m.fully_read", contentRoom)
assert.NoError(t, err, "unable to save account data") assert.NoError(t, err, "unable to save account data")
contentGlobal := json.RawMessage(fmt.Sprintf(`{"recent_rooms":["%s"]}`, room.ID)) contentGlobal := json.RawMessage(fmt.Sprintf(`{"recent_rooms":["%s"]}`, room.ID))
err = db.SaveAccountData(ctx, localpart, "", "im.vector.setting.breadcrumbs", contentGlobal) err = db.SaveAccountData(ctx, localpart, domain, "", "im.vector.setting.breadcrumbs", contentGlobal)
assert.NoError(t, err, "unable to save account data") assert.NoError(t, err, "unable to save account data")
accountData, err := db.GetAccountDataByType(ctx, localpart, room.ID, "m.fully_read") accountData, err := db.GetAccountDataByType(ctx, localpart, domain, room.ID, "m.fully_read")
assert.NoError(t, err, "unable to get account data by type") assert.NoError(t, err, "unable to get account data by type")
assert.Equal(t, contentRoom, accountData) assert.Equal(t, contentRoom, accountData)
globalData, roomData, err := db.GetAccountData(ctx, localpart) globalData, roomData, err := db.GetAccountData(ctx, localpart, domain)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, contentRoom, roomData[room.ID]["m.fully_read"]) assert.Equal(t, contentRoom, roomData[room.ID]["m.fully_read"])
assert.Equal(t, contentGlobal, globalData["im.vector.setting.breadcrumbs"]) assert.Equal(t, contentGlobal, globalData["im.vector.setting.breadcrumbs"])
@ -81,78 +81,78 @@ func Test_Accounts(t *testing.T) {
db, close := mustCreateDatabase(t, dbType) db, close := mustCreateDatabase(t, dbType)
defer close() defer close()
alice := test.NewUser(t) alice := test.NewUser(t)
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err) assert.NoError(t, err)
accAlice, err := db.CreateAccount(ctx, aliceLocalpart, "testing", "", api.AccountTypeAdmin) accAlice, err := db.CreateAccount(ctx, aliceLocalpart, aliceDomain, "testing", "", api.AccountTypeAdmin)
assert.NoError(t, err, "failed to create account") assert.NoError(t, err, "failed to create account")
// verify the newly create account is the same as returned by CreateAccount // verify the newly create account is the same as returned by CreateAccount
var accGet *api.Account var accGet *api.Account
accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, "testing") accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, aliceDomain, "testing")
assert.NoError(t, err, "failed to get account by password") assert.NoError(t, err, "failed to get account by password")
assert.Equal(t, accAlice, accGet) assert.Equal(t, accAlice, accGet)
accGet, err = db.GetAccountByLocalpart(ctx, aliceLocalpart) accGet, err = db.GetAccountByLocalpart(ctx, aliceLocalpart, aliceDomain)
assert.NoError(t, err, "failed to get account by localpart") assert.NoError(t, err, "failed to get account by localpart")
assert.Equal(t, accAlice, accGet) assert.Equal(t, accAlice, accGet)
// check account availability // check account availability
available, err := db.CheckAccountAvailability(ctx, aliceLocalpart) available, err := db.CheckAccountAvailability(ctx, aliceLocalpart, aliceDomain)
assert.NoError(t, err, "failed to checkout account availability") assert.NoError(t, err, "failed to checkout account availability")
assert.Equal(t, false, available) assert.Equal(t, false, available)
available, err = db.CheckAccountAvailability(ctx, "unusedname") available, err = db.CheckAccountAvailability(ctx, "unusedname", aliceDomain)
assert.NoError(t, err, "failed to checkout account availability") assert.NoError(t, err, "failed to checkout account availability")
assert.Equal(t, true, available) assert.Equal(t, true, available)
// get guest account numeric aliceLocalpart // get guest account numeric aliceLocalpart
first, err := db.GetNewNumericLocalpart(ctx) first, err := db.GetNewNumericLocalpart(ctx, aliceDomain)
assert.NoError(t, err, "failed to get new numeric localpart") assert.NoError(t, err, "failed to get new numeric localpart")
// Create a new account to verify the numeric localpart is updated // Create a new account to verify the numeric localpart is updated
_, err = db.CreateAccount(ctx, "", "testing", "", api.AccountTypeGuest) _, err = db.CreateAccount(ctx, "", aliceDomain, "testing", "", api.AccountTypeGuest)
assert.NoError(t, err, "failed to create account") assert.NoError(t, err, "failed to create account")
second, err := db.GetNewNumericLocalpart(ctx) second, err := db.GetNewNumericLocalpart(ctx, aliceDomain)
assert.NoError(t, err) assert.NoError(t, err)
assert.Greater(t, second, first) assert.Greater(t, second, first)
// update password for alice // update password for alice
err = db.SetPassword(ctx, aliceLocalpart, "newPassword") err = db.SetPassword(ctx, aliceLocalpart, aliceDomain, "newPassword")
assert.NoError(t, err, "failed to update password") assert.NoError(t, err, "failed to update password")
accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, "newPassword") accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, aliceDomain, "newPassword")
assert.NoError(t, err, "failed to get account by new password") assert.NoError(t, err, "failed to get account by new password")
assert.Equal(t, accAlice, accGet) assert.Equal(t, accAlice, accGet)
// deactivate account // deactivate account
err = db.DeactivateAccount(ctx, aliceLocalpart) err = db.DeactivateAccount(ctx, aliceLocalpart, aliceDomain)
assert.NoError(t, err, "failed to deactivate account") assert.NoError(t, err, "failed to deactivate account")
// This should fail now, as the account is deactivated // This should fail now, as the account is deactivated
_, err = db.GetAccountByPassword(ctx, aliceLocalpart, "newPassword") _, err = db.GetAccountByPassword(ctx, aliceLocalpart, aliceDomain, "newPassword")
assert.Error(t, err, "expected an error, got none") assert.Error(t, err, "expected an error, got none")
_, err = db.GetAccountByLocalpart(ctx, "unusename") _, err = db.GetAccountByLocalpart(ctx, "unusename", aliceDomain)
assert.Error(t, err, "expected an error for non existent localpart") assert.Error(t, err, "expected an error for non existent localpart")
// create an empty localpart; this should never happen, but is required to test getting a numeric localpart // create an empty localpart; this should never happen, but is required to test getting a numeric localpart
// if there's already a user without a localpart in the database // if there's already a user without a localpart in the database
_, err = db.CreateAccount(ctx, "", "", "", api.AccountTypeUser) _, err = db.CreateAccount(ctx, "", aliceDomain, "", "", api.AccountTypeUser)
assert.NoError(t, err) assert.NoError(t, err)
// test getting a numeric localpart, with an existing user without a localpart // test getting a numeric localpart, with an existing user without a localpart
_, err = db.CreateAccount(ctx, "", "", "", api.AccountTypeGuest) _, err = db.CreateAccount(ctx, "", aliceDomain, "", "", api.AccountTypeGuest)
assert.NoError(t, err) assert.NoError(t, err)
// Create a user with a high numeric localpart, out of range for the Postgres integer (2147483647) type // Create a user with a high numeric localpart, out of range for the Postgres integer (2147483647) type
_, err = db.CreateAccount(ctx, "2147483650", "", "", api.AccountTypeUser) _, err = db.CreateAccount(ctx, "2147483650", aliceDomain, "", "", api.AccountTypeUser)
assert.NoError(t, err) assert.NoError(t, err)
// Now try to create a new guest user // Now try to create a new guest user
_, err = db.CreateAccount(ctx, "", "", "", api.AccountTypeGuest) _, err = db.CreateAccount(ctx, "", aliceDomain, "", "", api.AccountTypeGuest)
assert.NoError(t, err) assert.NoError(t, err)
}) })
} }
func Test_Devices(t *testing.T) { func Test_Devices(t *testing.T) {
alice := test.NewUser(t) alice := test.NewUser(t)
localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) localpart, domain, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err) assert.NoError(t, err)
deviceID := util.RandomString(8) deviceID := util.RandomString(8)
accessToken := util.RandomString(16) accessToken := util.RandomString(16)
@ -161,10 +161,10 @@ func Test_Devices(t *testing.T) {
db, close := mustCreateDatabase(t, dbType) db, close := mustCreateDatabase(t, dbType)
defer close() defer close()
deviceWithID, err := db.CreateDevice(ctx, localpart, &deviceID, accessToken, nil, "", "") deviceWithID, err := db.CreateDevice(ctx, localpart, domain, &deviceID, accessToken, nil, "", "")
assert.NoError(t, err, "unable to create deviceWithoutID") assert.NoError(t, err, "unable to create deviceWithoutID")
gotDevice, err := db.GetDeviceByID(ctx, localpart, deviceID) gotDevice, err := db.GetDeviceByID(ctx, localpart, domain, deviceID)
assert.NoError(t, err, "unable to get device by id") assert.NoError(t, err, "unable to get device by id")
assert.Equal(t, deviceWithID.ID, gotDevice.ID) // GetDeviceByID doesn't populate all fields assert.Equal(t, deviceWithID.ID, gotDevice.ID) // GetDeviceByID doesn't populate all fields
@ -174,14 +174,14 @@ func Test_Devices(t *testing.T) {
// create a device without existing device ID // create a device without existing device ID
accessToken = util.RandomString(16) accessToken = util.RandomString(16)
deviceWithoutID, err := db.CreateDevice(ctx, localpart, nil, accessToken, nil, "", "") deviceWithoutID, err := db.CreateDevice(ctx, localpart, domain, nil, accessToken, nil, "", "")
assert.NoError(t, err, "unable to create deviceWithoutID") assert.NoError(t, err, "unable to create deviceWithoutID")
gotDeviceWithoutID, err := db.GetDeviceByID(ctx, localpart, deviceWithoutID.ID) gotDeviceWithoutID, err := db.GetDeviceByID(ctx, localpart, domain, deviceWithoutID.ID)
assert.NoError(t, err, "unable to get device by id") assert.NoError(t, err, "unable to get device by id")
assert.Equal(t, deviceWithoutID.ID, gotDeviceWithoutID.ID) // GetDeviceByID doesn't populate all fields assert.Equal(t, deviceWithoutID.ID, gotDeviceWithoutID.ID) // GetDeviceByID doesn't populate all fields
// Get devices // Get devices
devices, err := db.GetDevicesByLocalpart(ctx, localpart) devices, err := db.GetDevicesByLocalpart(ctx, localpart, domain)
assert.NoError(t, err, "unable to get devices by localpart") assert.NoError(t, err, "unable to get devices by localpart")
assert.Equal(t, 2, len(devices)) assert.Equal(t, 2, len(devices))
deviceIDs := make([]string, 0, len(devices)) deviceIDs := make([]string, 0, len(devices))
@ -195,15 +195,15 @@ func Test_Devices(t *testing.T) {
// Update device // Update device
newName := "new display name" newName := "new display name"
err = db.UpdateDevice(ctx, localpart, deviceWithID.ID, &newName) err = db.UpdateDevice(ctx, localpart, domain, deviceWithID.ID, &newName)
assert.NoError(t, err, "unable to update device displayname") assert.NoError(t, err, "unable to update device displayname")
updatedAfterTimestamp := time.Now().Unix() updatedAfterTimestamp := time.Now().Unix()
err = db.UpdateDeviceLastSeen(ctx, localpart, deviceWithID.ID, "127.0.0.1", "Element Web") err = db.UpdateDeviceLastSeen(ctx, localpart, domain, deviceWithID.ID, "127.0.0.1", "Element Web")
assert.NoError(t, err, "unable to update device last seen") assert.NoError(t, err, "unable to update device last seen")
deviceWithID.DisplayName = newName deviceWithID.DisplayName = newName
deviceWithID.LastSeenIP = "127.0.0.1" deviceWithID.LastSeenIP = "127.0.0.1"
gotDevice, err = db.GetDeviceByID(ctx, localpart, deviceWithID.ID) gotDevice, err = db.GetDeviceByID(ctx, localpart, domain, deviceWithID.ID)
assert.NoError(t, err, "unable to get device by id") assert.NoError(t, err, "unable to get device by id")
assert.Equal(t, 2, len(devices)) assert.Equal(t, 2, len(devices))
assert.Equal(t, deviceWithID.DisplayName, gotDevice.DisplayName) assert.Equal(t, deviceWithID.DisplayName, gotDevice.DisplayName)
@ -213,20 +213,20 @@ func Test_Devices(t *testing.T) {
// create one more device and remove the devices step by step // create one more device and remove the devices step by step
newDeviceID := util.RandomString(16) newDeviceID := util.RandomString(16)
accessToken = util.RandomString(16) accessToken = util.RandomString(16)
_, err = db.CreateDevice(ctx, localpart, &newDeviceID, accessToken, nil, "", "") _, err = db.CreateDevice(ctx, localpart, domain, &newDeviceID, accessToken, nil, "", "")
assert.NoError(t, err, "unable to create new device") assert.NoError(t, err, "unable to create new device")
devices, err = db.GetDevicesByLocalpart(ctx, localpart) devices, err = db.GetDevicesByLocalpart(ctx, localpart, domain)
assert.NoError(t, err, "unable to get device by id") assert.NoError(t, err, "unable to get device by id")
assert.Equal(t, 3, len(devices)) assert.Equal(t, 3, len(devices))
err = db.RemoveDevices(ctx, localpart, deviceIDs) err = db.RemoveDevices(ctx, localpart, domain, deviceIDs)
assert.NoError(t, err, "unable to remove devices") assert.NoError(t, err, "unable to remove devices")
devices, err = db.GetDevicesByLocalpart(ctx, localpart) devices, err = db.GetDevicesByLocalpart(ctx, localpart, domain)
assert.NoError(t, err, "unable to get device by id") assert.NoError(t, err, "unable to get device by id")
assert.Equal(t, 1, len(devices)) assert.Equal(t, 1, len(devices))
deleted, err := db.RemoveAllDevices(ctx, localpart, "") deleted, err := db.RemoveAllDevices(ctx, localpart, domain, "")
assert.NoError(t, err, "unable to remove all devices") assert.NoError(t, err, "unable to remove all devices")
assert.Equal(t, 1, len(deleted)) assert.Equal(t, 1, len(deleted))
assert.Equal(t, newDeviceID, deleted[0].ID) assert.Equal(t, newDeviceID, deleted[0].ID)
@ -364,7 +364,7 @@ func Test_OpenID(t *testing.T) {
func Test_Profile(t *testing.T) { func Test_Profile(t *testing.T) {
alice := test.NewUser(t) alice := test.NewUser(t)
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err) assert.NoError(t, err)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
@ -372,30 +372,33 @@ func Test_Profile(t *testing.T) {
defer close() defer close()
// create account, which also creates a profile // create account, which also creates a profile
_, err = db.CreateAccount(ctx, aliceLocalpart, "testing", "", api.AccountTypeAdmin) _, err = db.CreateAccount(ctx, aliceLocalpart, aliceDomain, "testing", "", api.AccountTypeAdmin)
assert.NoError(t, err, "failed to create account") assert.NoError(t, err, "failed to create account")
gotProfile, err := db.GetProfileByLocalpart(ctx, aliceLocalpart) gotProfile, err := db.GetProfileByLocalpart(ctx, aliceLocalpart, aliceDomain)
assert.NoError(t, err, "unable to get profile by localpart") assert.NoError(t, err, "unable to get profile by localpart")
wantProfile := &authtypes.Profile{Localpart: aliceLocalpart} wantProfile := &authtypes.Profile{
Localpart: aliceLocalpart,
ServerName: string(aliceDomain),
}
assert.Equal(t, wantProfile, gotProfile) assert.Equal(t, wantProfile, gotProfile)
// set avatar & displayname // set avatar & displayname
wantProfile.DisplayName = "Alice" wantProfile.DisplayName = "Alice"
gotProfile, changed, err := db.SetDisplayName(ctx, aliceLocalpart, "Alice") gotProfile, changed, err := db.SetDisplayName(ctx, aliceLocalpart, aliceDomain, "Alice")
assert.Equal(t, wantProfile, gotProfile) assert.Equal(t, wantProfile, gotProfile)
assert.NoError(t, err, "unable to set displayname") assert.NoError(t, err, "unable to set displayname")
assert.True(t, changed) assert.True(t, changed)
wantProfile.AvatarURL = "mxc://aliceAvatar" wantProfile.AvatarURL = "mxc://aliceAvatar"
gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar") gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, aliceDomain, "mxc://aliceAvatar")
assert.NoError(t, err, "unable to set avatar url") assert.NoError(t, err, "unable to set avatar url")
assert.Equal(t, wantProfile, gotProfile) assert.Equal(t, wantProfile, gotProfile)
assert.True(t, changed) assert.True(t, changed)
// Setting the same avatar again doesn't change anything // Setting the same avatar again doesn't change anything
wantProfile.AvatarURL = "mxc://aliceAvatar" wantProfile.AvatarURL = "mxc://aliceAvatar"
gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar") gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, aliceDomain, "mxc://aliceAvatar")
assert.NoError(t, err, "unable to set avatar url") assert.NoError(t, err, "unable to set avatar url")
assert.Equal(t, wantProfile, gotProfile) assert.Equal(t, wantProfile, gotProfile)
assert.False(t, changed) assert.False(t, changed)
@ -410,7 +413,7 @@ func Test_Profile(t *testing.T) {
func Test_Pusher(t *testing.T) { func Test_Pusher(t *testing.T) {
alice := test.NewUser(t) alice := test.NewUser(t)
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err) assert.NoError(t, err)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
@ -432,11 +435,11 @@ func Test_Pusher(t *testing.T) {
ProfileTag: util.RandomString(8), ProfileTag: util.RandomString(8),
Language: util.RandomString(2), Language: util.RandomString(2),
} }
err = db.UpsertPusher(ctx, wantPusher, aliceLocalpart) err = db.UpsertPusher(ctx, wantPusher, aliceLocalpart, aliceDomain)
assert.NoError(t, err, "unable to upsert pusher") assert.NoError(t, err, "unable to upsert pusher")
// check it was actually persisted // check it was actually persisted
gotPushers, err = db.GetPushers(ctx, aliceLocalpart) gotPushers, err = db.GetPushers(ctx, aliceLocalpart, aliceDomain)
assert.NoError(t, err, "unable to get pushers") assert.NoError(t, err, "unable to get pushers")
assert.Equal(t, i+1, len(gotPushers)) assert.Equal(t, i+1, len(gotPushers))
assert.Equal(t, wantPusher, gotPushers[i]) assert.Equal(t, wantPusher, gotPushers[i])
@ -444,16 +447,16 @@ func Test_Pusher(t *testing.T) {
} }
// remove single pusher // remove single pusher
err = db.RemovePusher(ctx, appID, pushKeys[0], aliceLocalpart) err = db.RemovePusher(ctx, appID, pushKeys[0], aliceLocalpart, aliceDomain)
assert.NoError(t, err, "unable to remove pusher") assert.NoError(t, err, "unable to remove pusher")
gotPushers, err := db.GetPushers(ctx, aliceLocalpart) gotPushers, err := db.GetPushers(ctx, aliceLocalpart, aliceDomain)
assert.NoError(t, err, "unable to get pushers") assert.NoError(t, err, "unable to get pushers")
assert.Equal(t, 1, len(gotPushers)) assert.Equal(t, 1, len(gotPushers))
// remove last pusher // remove last pusher
err = db.RemovePushers(ctx, appID, pushKeys[1]) err = db.RemovePushers(ctx, appID, pushKeys[1])
assert.NoError(t, err, "unable to remove pusher") assert.NoError(t, err, "unable to remove pusher")
gotPushers, err = db.GetPushers(ctx, aliceLocalpart) gotPushers, err = db.GetPushers(ctx, aliceLocalpart, aliceDomain)
assert.NoError(t, err, "unable to get pushers") assert.NoError(t, err, "unable to get pushers")
assert.Equal(t, 0, len(gotPushers)) assert.Equal(t, 0, len(gotPushers))
}) })
@ -461,7 +464,7 @@ func Test_Pusher(t *testing.T) {
func Test_ThreePID(t *testing.T) { func Test_ThreePID(t *testing.T) {
alice := test.NewUser(t) alice := test.NewUser(t)
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err) assert.NoError(t, err)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
@ -469,15 +472,16 @@ func Test_ThreePID(t *testing.T) {
defer close() defer close()
threePID := util.RandomString(8) threePID := util.RandomString(8)
medium := util.RandomString(8) medium := util.RandomString(8)
err = db.SaveThreePIDAssociation(ctx, threePID, aliceLocalpart, medium) err = db.SaveThreePIDAssociation(ctx, threePID, aliceLocalpart, aliceDomain, medium)
assert.NoError(t, err, "unable to save threepid association") assert.NoError(t, err, "unable to save threepid association")
// get the stored threepid // get the stored threepid
gotLocalpart, err := db.GetLocalpartForThreePID(ctx, threePID, medium) gotLocalpart, gotDomain, err := db.GetLocalpartForThreePID(ctx, threePID, medium)
assert.NoError(t, err, "unable to get localpart for threepid") assert.NoError(t, err, "unable to get localpart for threepid")
assert.Equal(t, aliceLocalpart, gotLocalpart) assert.Equal(t, aliceLocalpart, gotLocalpart)
assert.Equal(t, aliceDomain, gotDomain)
threepids, err := db.GetThreePIDsForLocalpart(ctx, aliceLocalpart) threepids, err := db.GetThreePIDsForLocalpart(ctx, aliceLocalpart, aliceDomain)
assert.NoError(t, err, "unable to get threepids for localpart") assert.NoError(t, err, "unable to get threepids for localpart")
assert.Equal(t, 1, len(threepids)) assert.Equal(t, 1, len(threepids))
assert.Equal(t, authtypes.ThreePID{ assert.Equal(t, authtypes.ThreePID{
@ -490,7 +494,7 @@ func Test_ThreePID(t *testing.T) {
assert.NoError(t, err, "unexpected error") assert.NoError(t, err, "unexpected error")
// verify it was deleted // verify it was deleted
threepids, err = db.GetThreePIDsForLocalpart(ctx, aliceLocalpart) threepids, err = db.GetThreePIDsForLocalpart(ctx, aliceLocalpart, aliceDomain)
assert.NoError(t, err, "unable to get threepids for localpart") assert.NoError(t, err, "unable to get threepids for localpart")
assert.Equal(t, 0, len(threepids)) assert.Equal(t, 0, len(threepids))
}) })
@ -498,7 +502,7 @@ func Test_ThreePID(t *testing.T) {
func Test_Notification(t *testing.T) { func Test_Notification(t *testing.T) {
alice := test.NewUser(t) alice := test.NewUser(t)
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err) assert.NoError(t, err)
room := test.NewRoom(t, alice) room := test.NewRoom(t, alice)
room2 := test.NewRoom(t, alice) room2 := test.NewRoom(t, alice)
@ -526,34 +530,34 @@ func Test_Notification(t *testing.T) {
RoomID: roomID, RoomID: roomID,
TS: gomatrixserverlib.AsTimestamp(ts), TS: gomatrixserverlib.AsTimestamp(ts),
} }
err = db.InsertNotification(ctx, aliceLocalpart, eventID, uint64(i+1), nil, notification) err = db.InsertNotification(ctx, aliceLocalpart, aliceDomain, eventID, uint64(i+1), nil, notification)
assert.NoError(t, err, "unable to insert notification") assert.NoError(t, err, "unable to insert notification")
} }
// get notifications // get notifications
count, err := db.GetNotificationCount(ctx, aliceLocalpart, tables.AllNotifications) count, err := db.GetNotificationCount(ctx, aliceLocalpart, aliceDomain, tables.AllNotifications)
assert.NoError(t, err, "unable to get notification count") assert.NoError(t, err, "unable to get notification count")
assert.Equal(t, int64(10), count) assert.Equal(t, int64(10), count)
notifs, count, err := db.GetNotifications(ctx, aliceLocalpart, 0, 15, tables.AllNotifications) notifs, count, err := db.GetNotifications(ctx, aliceLocalpart, aliceDomain, 0, 15, tables.AllNotifications)
assert.NoError(t, err, "unable to get notifications") assert.NoError(t, err, "unable to get notifications")
assert.Equal(t, int64(10), count) assert.Equal(t, int64(10), count)
assert.Equal(t, 10, len(notifs)) assert.Equal(t, 10, len(notifs))
// ... for a specific room // ... for a specific room
total, _, err := db.GetRoomNotificationCounts(ctx, aliceLocalpart, room2.ID) total, _, err := db.GetRoomNotificationCounts(ctx, aliceLocalpart, aliceDomain, room2.ID)
assert.NoError(t, err, "unable to get notifications for room") assert.NoError(t, err, "unable to get notifications for room")
assert.Equal(t, int64(4), total) assert.Equal(t, int64(4), total)
// mark notification as read // mark notification as read
affected, err := db.SetNotificationsRead(ctx, aliceLocalpart, room2.ID, 7, true) affected, err := db.SetNotificationsRead(ctx, aliceLocalpart, aliceDomain, room2.ID, 7, true)
assert.NoError(t, err, "unable to set notifications read") assert.NoError(t, err, "unable to set notifications read")
assert.True(t, affected) assert.True(t, affected)
// this should delete 2 notifications // this should delete 2 notifications
affected, err = db.DeleteNotificationsUpTo(ctx, aliceLocalpart, room2.ID, 8) affected, err = db.DeleteNotificationsUpTo(ctx, aliceLocalpart, aliceDomain, room2.ID, 8)
assert.NoError(t, err, "unable to set notifications read") assert.NoError(t, err, "unable to set notifications read")
assert.True(t, affected) assert.True(t, affected)
total, _, err = db.GetRoomNotificationCounts(ctx, aliceLocalpart, room2.ID) total, _, err = db.GetRoomNotificationCounts(ctx, aliceLocalpart, aliceDomain, room2.ID)
assert.NoError(t, err, "unable to get notifications for room") assert.NoError(t, err, "unable to get notifications for room")
assert.Equal(t, int64(2), total) assert.Equal(t, int64(2), total)
@ -562,7 +566,7 @@ func Test_Notification(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
// this should now return 0 notifications // this should now return 0 notifications
total, _, err = db.GetRoomNotificationCounts(ctx, aliceLocalpart, room2.ID) total, _, err = db.GetRoomNotificationCounts(ctx, aliceLocalpart, aliceDomain, room2.ID)
assert.NoError(t, err, "unable to get notifications for room") assert.NoError(t, err, "unable to get notifications for room")
assert.Equal(t, int64(0), total) assert.Equal(t, int64(0), total)
}) })

View file

@ -28,31 +28,31 @@ import (
) )
type AccountDataTable interface { type AccountDataTable interface {
InsertAccountData(ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage) error InsertAccountData(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID, dataType string, content json.RawMessage) error
SelectAccountData(ctx context.Context, localpart string) (map[string]json.RawMessage, map[string]map[string]json.RawMessage, error) SelectAccountData(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (map[string]json.RawMessage, map[string]map[string]json.RawMessage, error)
SelectAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data json.RawMessage, err error) SelectAccountDataByType(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID, dataType string) (data json.RawMessage, err error)
} }
type AccountsTable interface { type AccountsTable interface {
InsertAccount(ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType) (*api.Account, error) InsertAccount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, hash, appserviceID string, accountType api.AccountType) (*api.Account, error)
UpdatePassword(ctx context.Context, localpart, passwordHash string) (err error) UpdatePassword(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, passwordHash string) (err error)
DeactivateAccount(ctx context.Context, localpart string) (err error) DeactivateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (err error)
SelectPasswordHash(ctx context.Context, localpart string) (hash string, err error) SelectPasswordHash(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (hash string, err error)
SelectAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error) SelectAccountByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*api.Account, error)
SelectNewNumericLocalpart(ctx context.Context, txn *sql.Tx) (id int64, err error) SelectNewNumericLocalpart(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (id int64, err error)
} }
type DevicesTable interface { type DevicesTable interface {
InsertDevice(ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, displayName *string, ipAddr, userAgent string) (*api.Device, error) InsertDevice(ctx context.Context, txn *sql.Tx, id, localpart string, serverName gomatrixserverlib.ServerName, accessToken string, displayName *string, ipAddr, userAgent string) (*api.Device, error)
DeleteDevice(ctx context.Context, txn *sql.Tx, id, localpart string) error DeleteDevice(ctx context.Context, txn *sql.Tx, id, localpart string, serverName gomatrixserverlib.ServerName) error
DeleteDevices(ctx context.Context, txn *sql.Tx, localpart string, devices []string) error DeleteDevices(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, devices []string) error
DeleteDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string) error DeleteDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, exceptDeviceID string) error
UpdateDeviceName(ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string) error UpdateDeviceName(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID string, displayName *string) error
SelectDeviceByToken(ctx context.Context, accessToken string) (*api.Device, error) SelectDeviceByToken(ctx context.Context, accessToken string) (*api.Device, error)
SelectDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error) SelectDeviceByID(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID string) (*api.Device, error)
SelectDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string) ([]api.Device, error) SelectDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, exceptDeviceID string) ([]api.Device, error)
SelectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) SelectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error)
UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr, userAgent string) error UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error
} }
type KeyBackupTable interface { type KeyBackupTable interface {
@ -79,40 +79,40 @@ type LoginTokenTable interface {
} }
type OpenIDTable interface { type OpenIDTable interface {
InsertOpenIDToken(ctx context.Context, txn *sql.Tx, token, localpart string, expiresAtMS int64) (err error) InsertOpenIDToken(ctx context.Context, txn *sql.Tx, token, localpart string, serverName gomatrixserverlib.ServerName, expiresAtMS int64) (err error)
SelectOpenIDTokenAtrributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error) SelectOpenIDTokenAtrributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error)
} }
type ProfileTable interface { type ProfileTable interface {
InsertProfile(ctx context.Context, txn *sql.Tx, localpart string) error InsertProfile(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName) error
SelectProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error) SelectProfileByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*authtypes.Profile, error)
SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, avatarURL string) (*authtypes.Profile, bool, error) SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, avatarURL string) (*authtypes.Profile, bool, error)
SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, displayName string) (*authtypes.Profile, bool, error) SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, displayName string) (*authtypes.Profile, bool, error)
SelectProfilesBySearch(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error) SelectProfilesBySearch(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
} }
type ThreePIDTable interface { type ThreePIDTable interface {
SelectLocalpartForThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (localpart string, err error) SelectLocalpartForThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (localpart string, serverName gomatrixserverlib.ServerName, err error)
SelectThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error) SelectThreePIDsForLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (threepids []authtypes.ThreePID, err error)
InsertThreePID(ctx context.Context, txn *sql.Tx, threepid, medium, localpart string) (err error) InsertThreePID(ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, serverName gomatrixserverlib.ServerName) (err error)
DeleteThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error) DeleteThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error)
} }
type PusherTable interface { type PusherTable interface {
InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string) error InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string, serverName gomatrixserverlib.ServerName) error
SelectPushers(ctx context.Context, txn *sql.Tx, localpart string) ([]api.Pusher, error) SelectPushers(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName) ([]api.Pusher, error)
DeletePusher(ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string) error DeletePusher(ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string, serverName gomatrixserverlib.ServerName) error
DeletePushers(ctx context.Context, txn *sql.Tx, appid, pushkey string) error DeletePushers(ctx context.Context, txn *sql.Tx, appid, pushkey string) error
} }
type NotificationTable interface { type NotificationTable interface {
Clean(ctx context.Context, txn *sql.Tx) error Clean(ctx context.Context, txn *sql.Tx) error
Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error Insert(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, highlight bool, n *api.Notification) error
DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, _ error)
UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error) UpdateRead(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, v bool) (affected bool, _ error)
Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter NotificationFilter) ([]*api.Notification, int64, error) Select(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter NotificationFilter) ([]*api.Notification, int64, error)
SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter NotificationFilter) (int64, error) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, filter NotificationFilter) (int64, error)
SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, _ error)
} }
type StatsTable interface { type StatsTable interface {

View file

@ -79,6 +79,7 @@ func mustMakeAccountAndDevice(
accDB tables.AccountsTable, accDB tables.AccountsTable,
devDB tables.DevicesTable, devDB tables.DevicesTable,
localpart string, localpart string,
serverName gomatrixserverlib.ServerName, // nolint:unparam
accType api.AccountType, accType api.AccountType,
userAgent string, userAgent string,
) { ) {
@ -89,11 +90,11 @@ func mustMakeAccountAndDevice(
appServiceID = util.RandomString(16) appServiceID = util.RandomString(16)
} }
_, err := accDB.InsertAccount(ctx, nil, localpart, "", appServiceID, accType) _, err := accDB.InsertAccount(ctx, nil, localpart, serverName, "", appServiceID, accType)
if err != nil { if err != nil {
t.Fatalf("unable to create account: %v", err) t.Fatalf("unable to create account: %v", err)
} }
_, err = devDB.InsertDevice(ctx, nil, "deviceID", localpart, util.RandomString(16), nil, "", userAgent) _, err = devDB.InsertDevice(ctx, nil, "deviceID", localpart, serverName, util.RandomString(16), nil, "", userAgent)
if err != nil { if err != nil {
t.Fatalf("unable to create device: %v", err) t.Fatalf("unable to create device: %v", err)
} }
@ -150,12 +151,12 @@ func Test_UserStatistics(t *testing.T) {
}) })
t.Run("Want Users", func(t *testing.T) { t.Run("Want Users", func(t *testing.T) {
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user1", api.AccountTypeUser, "Element Android") mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user1", "localhost", api.AccountTypeUser, "Element Android")
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user2", api.AccountTypeUser, "Element iOS") mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user2", "localhost", api.AccountTypeUser, "Element iOS")
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user3", api.AccountTypeUser, "Element web") mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user3", "localhost", api.AccountTypeUser, "Element web")
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user4", api.AccountTypeGuest, "Element Electron") mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user4", "localhost", api.AccountTypeGuest, "Element Electron")
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user5", api.AccountTypeAdmin, "gecko") mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user5", "localhost", api.AccountTypeAdmin, "gecko")
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user6", api.AccountTypeAppService, "gecko") mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user6", "localhost", api.AccountTypeAppService, "gecko")
gotStats, _, err := statsDB.UserStatistics(ctx, nil) gotStats, _, err := statsDB.UserStatistics(ctx, nil)
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)

View file

@ -80,14 +80,14 @@ func TestQueryProfile(t *testing.T) {
// only one DBType, since userapi.AddInternalRoutes complains about multiple prometheus counters added // only one DBType, since userapi.AddInternalRoutes complains about multiple prometheus counters added
userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, test.DBTypeSQLite) userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, test.DBTypeSQLite)
defer close() defer close()
_, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "", api.AccountTypeUser) _, err := accountDB.CreateAccount(context.TODO(), "alice", serverName, "foobar", "", api.AccountTypeUser)
if err != nil { if err != nil {
t.Fatalf("failed to make account: %s", err) t.Fatalf("failed to make account: %s", err)
} }
if _, _, err := accountDB.SetAvatarURL(context.TODO(), "alice", aliceAvatarURL); err != nil { if _, _, err := accountDB.SetAvatarURL(context.TODO(), "alice", serverName, aliceAvatarURL); err != nil {
t.Fatalf("failed to set avatar url: %s", err) t.Fatalf("failed to set avatar url: %s", err)
} }
if _, _, err := accountDB.SetDisplayName(context.TODO(), "alice", aliceDisplayName); err != nil { if _, _, err := accountDB.SetDisplayName(context.TODO(), "alice", serverName, aliceDisplayName); err != nil {
t.Fatalf("failed to set display name: %s", err) t.Fatalf("failed to set display name: %s", err)
} }
@ -164,7 +164,7 @@ func TestPasswordlessLoginFails(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType) userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType)
defer close() defer close()
_, err := accountDB.CreateAccount(ctx, "auser", "", "", api.AccountTypeAppService) _, err := accountDB.CreateAccount(ctx, "auser", serverName, "", "", api.AccountTypeAppService)
if err != nil { if err != nil {
t.Fatalf("failed to make account: %s", err) t.Fatalf("failed to make account: %s", err)
} }
@ -190,7 +190,7 @@ func TestLoginToken(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType) userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType)
defer close() defer close()
_, err := accountDB.CreateAccount(ctx, "auser", "apassword", "", api.AccountTypeUser) _, err := accountDB.CreateAccount(ctx, "auser", serverName, "apassword", "", api.AccountTypeUser)
if err != nil { if err != nil {
t.Fatalf("failed to make account: %s", err) t.Fatalf("failed to make account: %s", err)
} }

View file

@ -2,10 +2,12 @@ package util
import ( import (
"context" "context"
"fmt"
"github.com/matrix-org/dendrite/internal/pushgateway" "github.com/matrix-org/dendrite/internal/pushgateway"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@ -17,10 +19,10 @@ type PusherDevice struct {
} }
// GetPushDevices pushes to the configured devices of a local user. // GetPushDevices pushes to the configured devices of a local user.
func GetPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}, db storage.Database) ([]*PusherDevice, error) { func GetPushDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, tweaks map[string]interface{}, db storage.Database) ([]*PusherDevice, error) {
pushers, err := db.GetPushers(ctx, localpart) pushers, err := db.GetPushers(ctx, localpart, serverName)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("db.GetPushers: %w", err)
} }
devices := make([]*PusherDevice, 0, len(pushers)) devices := make([]*PusherDevice, 0, len(pushers))

View file

@ -8,6 +8,7 @@ import (
"github.com/matrix-org/dendrite/internal/pushgateway" "github.com/matrix-org/dendrite/internal/pushgateway"
"github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@ -16,8 +17,8 @@ import (
// a single goroutine is started when talking to the Push // a single goroutine is started when talking to the Push
// gateways. There is no way to know when the background goroutine has // gateways. There is no way to know when the background goroutine has
// finished. // finished.
func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, db storage.Database) error { func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, serverName gomatrixserverlib.ServerName, db storage.Database) error {
pusherDevices, err := GetPushDevices(ctx, localpart, nil, db) pusherDevices, err := GetPushDevices(ctx, localpart, serverName, nil, db)
if err != nil { if err != nil {
return err return err
} }
@ -26,7 +27,7 @@ func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, loc
return nil return nil
} }
userNumUnreadNotifs, err := db.GetNotificationCount(ctx, localpart, tables.AllNotifications) userNumUnreadNotifs, err := db.GetNotificationCount(ctx, localpart, serverName, tables.AllNotifications)
if err != nil { if err != nil {
return err return err
} }