mirror of
https://github.com/hoernschen/dendrite.git
synced 2024-12-27 07:28:27 +00:00
Virtual hosting schema and logic changes (#2876)
Note that virtual users cannot federate correctly yet.
This commit is contained in:
parent
e177e0ae73
commit
529df30b56
62 changed files with 1250 additions and 732 deletions
|
@ -32,6 +32,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/setup/base"
|
"github.com/matrix-org/dendrite/setup/base"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
// AddInternalRoutes registers HTTP handlers for internal API calls
|
// AddInternalRoutes registers HTTP handlers for internal API calls
|
||||||
|
@ -74,7 +75,7 @@ func NewInternalAPI(
|
||||||
// events to be sent out.
|
// events to be sent out.
|
||||||
for _, appservice := range base.Cfg.Derived.ApplicationServices {
|
for _, appservice := range base.Cfg.Derived.ApplicationServices {
|
||||||
// Create bot account for this AS if it doesn't already exist
|
// Create bot account for this AS if it doesn't already exist
|
||||||
if err := generateAppServiceAccount(userAPI, appservice); err != nil {
|
if err := generateAppServiceAccount(userAPI, appservice, base.Cfg.Global.ServerName); err != nil {
|
||||||
logrus.WithFields(logrus.Fields{
|
logrus.WithFields(logrus.Fields{
|
||||||
"appservice": appservice.ID,
|
"appservice": appservice.ID,
|
||||||
}).WithError(err).Panicf("failed to generate bot account for appservice")
|
}).WithError(err).Panicf("failed to generate bot account for appservice")
|
||||||
|
@ -101,11 +102,13 @@ func NewInternalAPI(
|
||||||
func generateAppServiceAccount(
|
func generateAppServiceAccount(
|
||||||
userAPI userapi.AppserviceUserAPI,
|
userAPI userapi.AppserviceUserAPI,
|
||||||
as config.ApplicationService,
|
as config.ApplicationService,
|
||||||
|
serverName gomatrixserverlib.ServerName,
|
||||||
) error {
|
) error {
|
||||||
var accRes userapi.PerformAccountCreationResponse
|
var accRes userapi.PerformAccountCreationResponse
|
||||||
err := userAPI.PerformAccountCreation(context.Background(), &userapi.PerformAccountCreationRequest{
|
err := userAPI.PerformAccountCreation(context.Background(), &userapi.PerformAccountCreationRequest{
|
||||||
AccountType: userapi.AccountTypeAppService,
|
AccountType: userapi.AccountTypeAppService,
|
||||||
Localpart: as.SenderLocalpart,
|
Localpart: as.SenderLocalpart,
|
||||||
|
ServerName: serverName,
|
||||||
AppServiceID: as.ID,
|
AppServiceID: as.ID,
|
||||||
OnConflict: userapi.ConflictUpdate,
|
OnConflict: userapi.ConflictUpdate,
|
||||||
}, &accRes)
|
}, &accRes)
|
||||||
|
@ -115,6 +118,7 @@ func generateAppServiceAccount(
|
||||||
var devRes userapi.PerformDeviceCreationResponse
|
var devRes userapi.PerformDeviceCreationResponse
|
||||||
err = userAPI.PerformDeviceCreation(context.Background(), &userapi.PerformDeviceCreationRequest{
|
err = userAPI.PerformDeviceCreation(context.Background(), &userapi.PerformDeviceCreationRequest{
|
||||||
Localpart: as.SenderLocalpart,
|
Localpart: as.SenderLocalpart,
|
||||||
|
ServerName: serverName,
|
||||||
AccessToken: as.ASToken,
|
AccessToken: as.ASToken,
|
||||||
DeviceID: &as.SenderLocalpart,
|
DeviceID: &as.SenderLocalpart,
|
||||||
DeviceDisplayName: &as.SenderLocalpart,
|
DeviceDisplayName: &as.SenderLocalpart,
|
||||||
|
|
|
@ -61,7 +61,7 @@ func (t *LoginTypePassword) LoginFromJSON(ctx context.Context, reqBytes []byte)
|
||||||
|
|
||||||
func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, *util.JSONResponse) {
|
func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, *util.JSONResponse) {
|
||||||
r := req.(*PasswordRequest)
|
r := req.(*PasswordRequest)
|
||||||
username := strings.ToLower(r.Username())
|
username := r.Username()
|
||||||
if username == "" {
|
if username == "" {
|
||||||
return nil, &util.JSONResponse{
|
return nil, &util.JSONResponse{
|
||||||
Code: http.StatusUnauthorized,
|
Code: http.StatusUnauthorized,
|
||||||
|
@ -74,32 +74,43 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login,
|
||||||
JSON: jsonerror.BadJSON("A password must be supplied."),
|
JSON: jsonerror.BadJSON("A password must be supplied."),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
localpart, _, err := userutil.ParseUsernameParam(username, t.Config.Matrix)
|
localpart, domain, err := userutil.ParseUsernameParam(username, t.Config.Matrix)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, &util.JSONResponse{
|
return nil, &util.JSONResponse{
|
||||||
Code: http.StatusUnauthorized,
|
Code: http.StatusUnauthorized,
|
||||||
JSON: jsonerror.InvalidUsername(err.Error()),
|
JSON: jsonerror.InvalidUsername(err.Error()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if !t.Config.Matrix.IsLocalServerName(domain) {
|
||||||
|
return nil, &util.JSONResponse{
|
||||||
|
Code: http.StatusUnauthorized,
|
||||||
|
JSON: jsonerror.InvalidUsername("The server name is not known."),
|
||||||
|
}
|
||||||
|
}
|
||||||
// Squash username to all lowercase letters
|
// Squash username to all lowercase letters
|
||||||
res := &api.QueryAccountByPasswordResponse{}
|
res := &api.QueryAccountByPasswordResponse{}
|
||||||
err = t.GetAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{Localpart: strings.ToLower(localpart), PlaintextPassword: r.Password}, res)
|
err = t.GetAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{
|
||||||
|
Localpart: strings.ToLower(localpart),
|
||||||
|
ServerName: domain,
|
||||||
|
PlaintextPassword: r.Password,
|
||||||
|
}, res)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, &util.JSONResponse{
|
return nil, &util.JSONResponse{
|
||||||
Code: http.StatusInternalServerError,
|
Code: http.StatusInternalServerError,
|
||||||
JSON: jsonerror.Unknown("unable to fetch account by password"),
|
JSON: jsonerror.Unknown("Unable to fetch account by password."),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !res.Exists {
|
if !res.Exists {
|
||||||
err = t.GetAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{
|
err = t.GetAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{
|
||||||
Localpart: localpart,
|
Localpart: localpart,
|
||||||
|
ServerName: domain,
|
||||||
PlaintextPassword: r.Password,
|
PlaintextPassword: r.Password,
|
||||||
}, res)
|
}, res)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, &util.JSONResponse{
|
return nil, &util.JSONResponse{
|
||||||
Code: http.StatusInternalServerError,
|
Code: http.StatusInternalServerError,
|
||||||
JSON: jsonerror.Unknown("unable to fetch account by password"),
|
JSON: jsonerror.Unknown("Unable to fetch account by password."),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Technically we could tell them if the user does not exist by checking if err == sql.ErrNoRows
|
// Technically we could tell them if the user does not exist by checking if err == sql.ErrNoRows
|
||||||
|
|
|
@ -102,6 +102,7 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
}
|
}
|
||||||
|
serverName := cfg.Matrix.ServerName
|
||||||
localpart, ok := vars["localpart"]
|
localpart, ok := vars["localpart"]
|
||||||
if !ok {
|
if !ok {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
|
@ -109,6 +110,9 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap
|
||||||
JSON: jsonerror.MissingArgument("Expecting user localpart."),
|
JSON: jsonerror.MissingArgument("Expecting user localpart."),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if l, s, err := gomatrixserverlib.SplitID('@', localpart); err == nil {
|
||||||
|
localpart, serverName = l, s
|
||||||
|
}
|
||||||
request := struct {
|
request := struct {
|
||||||
Password string `json:"password"`
|
Password string `json:"password"`
|
||||||
}{}
|
}{}
|
||||||
|
@ -126,6 +130,7 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap
|
||||||
}
|
}
|
||||||
updateReq := &userapi.PerformPasswordUpdateRequest{
|
updateReq := &userapi.PerformPasswordUpdateRequest{
|
||||||
Localpart: localpart,
|
Localpart: localpart,
|
||||||
|
ServerName: serverName,
|
||||||
Password: request.Password,
|
Password: request.Password,
|
||||||
LogoutDevices: true,
|
LogoutDevices: true,
|
||||||
}
|
}
|
||||||
|
|
|
@ -100,6 +100,7 @@ func completeAuth(
|
||||||
DeviceID: login.DeviceID,
|
DeviceID: login.DeviceID,
|
||||||
AccessToken: token,
|
AccessToken: token,
|
||||||
Localpart: localpart,
|
Localpart: localpart,
|
||||||
|
ServerName: serverName,
|
||||||
IPAddr: ipAddr,
|
IPAddr: ipAddr,
|
||||||
UserAgent: userAgent,
|
UserAgent: userAgent,
|
||||||
}, &performRes)
|
}, &performRes)
|
||||||
|
|
|
@ -40,16 +40,17 @@ func GetNotifications(
|
||||||
}
|
}
|
||||||
|
|
||||||
var queryRes userapi.QueryNotificationsResponse
|
var queryRes userapi.QueryNotificationsResponse
|
||||||
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("SplitID failed")
|
util.GetLogger(req.Context()).WithError(err).Error("SplitID failed")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
err = userAPI.QueryNotifications(req.Context(), &userapi.QueryNotificationsRequest{
|
err = userAPI.QueryNotifications(req.Context(), &userapi.QueryNotificationsRequest{
|
||||||
Localpart: localpart,
|
Localpart: localpart,
|
||||||
From: req.URL.Query().Get("from"),
|
ServerName: domain,
|
||||||
Limit: int(limit),
|
From: req.URL.Query().Get("from"),
|
||||||
Only: req.URL.Query().Get("only"),
|
Limit: int(limit),
|
||||||
|
Only: req.URL.Query().Get("only"),
|
||||||
}, &queryRes)
|
}, &queryRes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("QueryNotifications failed")
|
util.GetLogger(req.Context()).WithError(err).Error("QueryNotifications failed")
|
||||||
|
|
|
@ -86,7 +86,7 @@ func Password(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the local part.
|
// Get the local part.
|
||||||
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
|
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
|
@ -94,8 +94,9 @@ func Password(
|
||||||
|
|
||||||
// Ask the user API to perform the password change.
|
// Ask the user API to perform the password change.
|
||||||
passwordReq := &api.PerformPasswordUpdateRequest{
|
passwordReq := &api.PerformPasswordUpdateRequest{
|
||||||
Localpart: localpart,
|
Localpart: localpart,
|
||||||
Password: r.NewPassword,
|
ServerName: domain,
|
||||||
|
Password: r.NewPassword,
|
||||||
}
|
}
|
||||||
passwordRes := &api.PerformPasswordUpdateResponse{}
|
passwordRes := &api.PerformPasswordUpdateResponse{}
|
||||||
if err := userAPI.PerformPasswordUpdate(req.Context(), passwordReq, passwordRes); err != nil {
|
if err := userAPI.PerformPasswordUpdate(req.Context(), passwordReq, passwordRes); err != nil {
|
||||||
|
@ -122,8 +123,9 @@ func Password(
|
||||||
}
|
}
|
||||||
|
|
||||||
pushersReq := &api.PerformPusherDeletionRequest{
|
pushersReq := &api.PerformPusherDeletionRequest{
|
||||||
Localpart: localpart,
|
Localpart: localpart,
|
||||||
SessionID: device.SessionID,
|
ServerName: domain,
|
||||||
|
SessionID: device.SessionID,
|
||||||
}
|
}
|
||||||
if err := userAPI.PerformPusherDeletion(req.Context(), pushersReq, &struct{}{}); err != nil {
|
if err := userAPI.PerformPusherDeletion(req.Context(), pushersReq, &struct{}{}); err != nil {
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("PerformPusherDeletion failed")
|
util.GetLogger(req.Context()).WithError(err).Error("PerformPusherDeletion failed")
|
||||||
|
|
|
@ -31,13 +31,14 @@ func GetPushers(
|
||||||
userAPI userapi.ClientUserAPI,
|
userAPI userapi.ClientUserAPI,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
var queryRes userapi.QueryPushersResponse
|
var queryRes userapi.QueryPushersResponse
|
||||||
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("SplitID failed")
|
util.GetLogger(req.Context()).WithError(err).Error("SplitID failed")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
err = userAPI.QueryPushers(req.Context(), &userapi.QueryPushersRequest{
|
err = userAPI.QueryPushers(req.Context(), &userapi.QueryPushersRequest{
|
||||||
Localpart: localpart,
|
Localpart: localpart,
|
||||||
|
ServerName: domain,
|
||||||
}, &queryRes)
|
}, &queryRes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("QueryPushers failed")
|
util.GetLogger(req.Context()).WithError(err).Error("QueryPushers failed")
|
||||||
|
@ -59,7 +60,7 @@ func SetPusher(
|
||||||
req *http.Request, device *userapi.Device,
|
req *http.Request, device *userapi.Device,
|
||||||
userAPI userapi.ClientUserAPI,
|
userAPI userapi.ClientUserAPI,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("SplitID failed")
|
util.GetLogger(req.Context()).WithError(err).Error("SplitID failed")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
|
@ -93,6 +94,7 @@ func SetPusher(
|
||||||
|
|
||||||
}
|
}
|
||||||
body.Localpart = localpart
|
body.Localpart = localpart
|
||||||
|
body.ServerName = domain
|
||||||
body.SessionID = device.SessionID
|
body.SessionID = device.SessionID
|
||||||
err = userAPI.PerformPusherSet(req.Context(), &body, &struct{}{})
|
err = userAPI.PerformPusherSet(req.Context(), &body, &struct{}{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -588,12 +588,15 @@ func Register(
|
||||||
}
|
}
|
||||||
// Auto generate a numeric username if r.Username is empty
|
// Auto generate a numeric username if r.Username is empty
|
||||||
if r.Username == "" {
|
if r.Username == "" {
|
||||||
res := &userapi.QueryNumericLocalpartResponse{}
|
nreq := &userapi.QueryNumericLocalpartRequest{
|
||||||
if err := userAPI.QueryNumericLocalpart(req.Context(), res); err != nil {
|
ServerName: cfg.Matrix.ServerName, // TODO: might not be right
|
||||||
|
}
|
||||||
|
nres := &userapi.QueryNumericLocalpartResponse{}
|
||||||
|
if err := userAPI.QueryNumericLocalpart(req.Context(), nreq, nres); err != nil {
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryNumericLocalpart failed")
|
util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryNumericLocalpart failed")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
r.Username = strconv.FormatInt(res.ID, 10)
|
r.Username = strconv.FormatInt(nres.ID, 10)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Is this an appservice registration? It will be if the access
|
// Is this an appservice registration? It will be if the access
|
||||||
|
@ -676,6 +679,7 @@ func handleGuestRegistration(
|
||||||
var devRes userapi.PerformDeviceCreationResponse
|
var devRes userapi.PerformDeviceCreationResponse
|
||||||
err = userAPI.PerformDeviceCreation(req.Context(), &userapi.PerformDeviceCreationRequest{
|
err = userAPI.PerformDeviceCreation(req.Context(), &userapi.PerformDeviceCreationRequest{
|
||||||
Localpart: res.Account.Localpart,
|
Localpart: res.Account.Localpart,
|
||||||
|
ServerName: res.Account.ServerName,
|
||||||
DeviceDisplayName: r.InitialDisplayName,
|
DeviceDisplayName: r.InitialDisplayName,
|
||||||
AccessToken: token,
|
AccessToken: token,
|
||||||
IPAddr: req.RemoteAddr,
|
IPAddr: req.RemoteAddr,
|
||||||
|
|
|
@ -157,7 +157,7 @@ func Setup(
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodGet, http.MethodOptions)
|
).Methods(http.MethodGet, http.MethodOptions)
|
||||||
|
|
||||||
dendriteAdminRouter.Handle("/admin/resetPassword/{localpart}",
|
dendriteAdminRouter.Handle("/admin/resetPassword/{userID}",
|
||||||
httputil.MakeAdminAPI("admin_reset_password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAdminAPI("admin_reset_password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
return AdminResetPassword(req, cfg, device, userAPI)
|
return AdminResetPassword(req, cfg, device, userAPI)
|
||||||
}),
|
}),
|
||||||
|
|
|
@ -286,6 +286,7 @@ func getSenderDevice(
|
||||||
err := userAPI.PerformAccountCreation(ctx, &userapi.PerformAccountCreationRequest{
|
err := userAPI.PerformAccountCreation(ctx, &userapi.PerformAccountCreationRequest{
|
||||||
AccountType: userapi.AccountTypeUser,
|
AccountType: userapi.AccountTypeUser,
|
||||||
Localpart: cfg.Matrix.ServerNotices.LocalPart,
|
Localpart: cfg.Matrix.ServerNotices.LocalPart,
|
||||||
|
ServerName: cfg.Matrix.ServerName,
|
||||||
OnConflict: userapi.ConflictUpdate,
|
OnConflict: userapi.ConflictUpdate,
|
||||||
}, &accRes)
|
}, &accRes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -295,8 +296,9 @@ func getSenderDevice(
|
||||||
// Set the avatarurl for the user
|
// Set the avatarurl for the user
|
||||||
avatarRes := &userapi.PerformSetAvatarURLResponse{}
|
avatarRes := &userapi.PerformSetAvatarURLResponse{}
|
||||||
if err = userAPI.SetAvatarURL(ctx, &userapi.PerformSetAvatarURLRequest{
|
if err = userAPI.SetAvatarURL(ctx, &userapi.PerformSetAvatarURLRequest{
|
||||||
Localpart: cfg.Matrix.ServerNotices.LocalPart,
|
Localpart: cfg.Matrix.ServerNotices.LocalPart,
|
||||||
AvatarURL: cfg.Matrix.ServerNotices.AvatarURL,
|
ServerName: cfg.Matrix.ServerName,
|
||||||
|
AvatarURL: cfg.Matrix.ServerNotices.AvatarURL,
|
||||||
}, avatarRes); err != nil {
|
}, avatarRes); err != nil {
|
||||||
util.GetLogger(ctx).WithError(err).Error("userAPI.SetAvatarURL failed")
|
util.GetLogger(ctx).WithError(err).Error("userAPI.SetAvatarURL failed")
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -308,6 +310,7 @@ func getSenderDevice(
|
||||||
displayNameRes := &userapi.PerformUpdateDisplayNameResponse{}
|
displayNameRes := &userapi.PerformUpdateDisplayNameResponse{}
|
||||||
if err = userAPI.SetDisplayName(ctx, &userapi.PerformUpdateDisplayNameRequest{
|
if err = userAPI.SetDisplayName(ctx, &userapi.PerformUpdateDisplayNameRequest{
|
||||||
Localpart: cfg.Matrix.ServerNotices.LocalPart,
|
Localpart: cfg.Matrix.ServerNotices.LocalPart,
|
||||||
|
ServerName: cfg.Matrix.ServerName,
|
||||||
DisplayName: cfg.Matrix.ServerNotices.DisplayName,
|
DisplayName: cfg.Matrix.ServerNotices.DisplayName,
|
||||||
}, displayNameRes); err != nil {
|
}, displayNameRes); err != nil {
|
||||||
util.GetLogger(ctx).WithError(err).Error("userAPI.SetDisplayName failed")
|
util.GetLogger(ctx).WithError(err).Error("userAPI.SetDisplayName failed")
|
||||||
|
@ -353,6 +356,7 @@ func getSenderDevice(
|
||||||
var devRes userapi.PerformDeviceCreationResponse
|
var devRes userapi.PerformDeviceCreationResponse
|
||||||
err = userAPI.PerformDeviceCreation(ctx, &userapi.PerformDeviceCreationRequest{
|
err = userAPI.PerformDeviceCreation(ctx, &userapi.PerformDeviceCreationRequest{
|
||||||
Localpart: cfg.Matrix.ServerNotices.LocalPart,
|
Localpart: cfg.Matrix.ServerNotices.LocalPart,
|
||||||
|
ServerName: cfg.Matrix.ServerName,
|
||||||
DeviceDisplayName: &cfg.Matrix.ServerNotices.LocalPart,
|
DeviceDisplayName: &cfg.Matrix.ServerNotices.LocalPart,
|
||||||
AccessToken: token,
|
AccessToken: token,
|
||||||
NoDeviceListUpdate: true,
|
NoDeviceListUpdate: true,
|
||||||
|
|
|
@ -136,16 +136,17 @@ func CheckAndSave3PIDAssociation(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save the association in the database
|
// Save the association in the database
|
||||||
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
|
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = threePIDAPI.PerformSaveThreePIDAssociation(req.Context(), &api.PerformSaveThreePIDAssociationRequest{
|
if err = threePIDAPI.PerformSaveThreePIDAssociation(req.Context(), &api.PerformSaveThreePIDAssociationRequest{
|
||||||
ThreePID: address,
|
ThreePID: address,
|
||||||
Localpart: localpart,
|
Localpart: localpart,
|
||||||
Medium: medium,
|
ServerName: domain,
|
||||||
|
Medium: medium,
|
||||||
}, &struct{}{}); err != nil {
|
}, &struct{}{}); err != nil {
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("threePIDAPI.PerformSaveThreePIDAssociation failed")
|
util.GetLogger(req.Context()).WithError(err).Error("threePIDAPI.PerformSaveThreePIDAssociation failed")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
|
@ -161,7 +162,7 @@ func CheckAndSave3PIDAssociation(
|
||||||
func GetAssociated3PIDs(
|
func GetAssociated3PIDs(
|
||||||
req *http.Request, threepidAPI api.ClientUserAPI, device *api.Device,
|
req *http.Request, threepidAPI api.ClientUserAPI, device *api.Device,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
|
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
|
@ -169,7 +170,8 @@ func GetAssociated3PIDs(
|
||||||
|
|
||||||
res := &api.QueryThreePIDsForLocalpartResponse{}
|
res := &api.QueryThreePIDsForLocalpartResponse{}
|
||||||
err = threepidAPI.QueryThreePIDsForLocalpart(req.Context(), &api.QueryThreePIDsForLocalpartRequest{
|
err = threepidAPI.QueryThreePIDsForLocalpart(req.Context(), &api.QueryThreePIDsForLocalpartRequest{
|
||||||
Localpart: localpart,
|
Localpart: localpart,
|
||||||
|
ServerName: domain,
|
||||||
}, res)
|
}, res)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("threepidAPI.QueryThreePIDsForLocalpart failed")
|
util.GetLogger(req.Context()).WithError(err).Error("threepidAPI.QueryThreePIDsForLocalpart failed")
|
||||||
|
|
|
@ -120,15 +120,23 @@ func NewInternalAPI(
|
||||||
|
|
||||||
js, nats := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream)
|
js, nats := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream)
|
||||||
|
|
||||||
|
signingInfo := map[gomatrixserverlib.ServerName]*queue.SigningInfo{}
|
||||||
|
for _, serverName := range append(
|
||||||
|
[]gomatrixserverlib.ServerName{base.Cfg.Global.ServerName},
|
||||||
|
base.Cfg.Global.SecondaryServerNames...,
|
||||||
|
) {
|
||||||
|
signingInfo[serverName] = &queue.SigningInfo{
|
||||||
|
KeyID: cfg.Matrix.KeyID,
|
||||||
|
PrivateKey: cfg.Matrix.PrivateKey,
|
||||||
|
ServerName: serverName,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
queues := queue.NewOutgoingQueues(
|
queues := queue.NewOutgoingQueues(
|
||||||
federationDB, base.ProcessContext,
|
federationDB, base.ProcessContext,
|
||||||
cfg.Matrix.DisableFederation,
|
cfg.Matrix.DisableFederation,
|
||||||
cfg.Matrix.ServerName, federation, rsAPI, &stats,
|
cfg.Matrix.ServerName, federation, rsAPI, &stats,
|
||||||
&queue.SigningInfo{
|
signingInfo,
|
||||||
KeyID: cfg.Matrix.KeyID,
|
|
||||||
PrivateKey: cfg.Matrix.PrivateKey,
|
|
||||||
ServerName: cfg.Matrix.ServerName,
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
rsConsumer := consumers.NewOutputRoomEventConsumer(
|
rsConsumer := consumers.NewOutputRoomEventConsumer(
|
||||||
|
|
|
@ -137,7 +137,7 @@ func (m *MockRoundTripper) RoundTrip(req *http.Request) (res *http.Response, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the keys and JSON-ify them.
|
// Get the keys and JSON-ify them.
|
||||||
keys := routing.LocalKeys(s.config)
|
keys := routing.LocalKeys(s.config, gomatrixserverlib.ServerName(req.Host))
|
||||||
body, err := json.MarshalIndent(keys.JSON, "", " ")
|
body, err := json.MarshalIndent(keys.JSON, "", " ")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
@ -50,7 +50,7 @@ type destinationQueue struct {
|
||||||
queues *OutgoingQueues
|
queues *OutgoingQueues
|
||||||
db storage.Database
|
db storage.Database
|
||||||
process *process.ProcessContext
|
process *process.ProcessContext
|
||||||
signing *SigningInfo
|
signing map[gomatrixserverlib.ServerName]*SigningInfo
|
||||||
rsAPI api.FederationRoomserverAPI
|
rsAPI api.FederationRoomserverAPI
|
||||||
client fedapi.FederationClient // federation client
|
client fedapi.FederationClient // federation client
|
||||||
origin gomatrixserverlib.ServerName // origin of requests
|
origin gomatrixserverlib.ServerName // origin of requests
|
||||||
|
|
|
@ -46,7 +46,7 @@ type OutgoingQueues struct {
|
||||||
origin gomatrixserverlib.ServerName
|
origin gomatrixserverlib.ServerName
|
||||||
client fedapi.FederationClient
|
client fedapi.FederationClient
|
||||||
statistics *statistics.Statistics
|
statistics *statistics.Statistics
|
||||||
signing *SigningInfo
|
signing map[gomatrixserverlib.ServerName]*SigningInfo
|
||||||
queuesMutex sync.Mutex // protects the below
|
queuesMutex sync.Mutex // protects the below
|
||||||
queues map[gomatrixserverlib.ServerName]*destinationQueue
|
queues map[gomatrixserverlib.ServerName]*destinationQueue
|
||||||
}
|
}
|
||||||
|
@ -91,7 +91,7 @@ func NewOutgoingQueues(
|
||||||
client fedapi.FederationClient,
|
client fedapi.FederationClient,
|
||||||
rsAPI api.FederationRoomserverAPI,
|
rsAPI api.FederationRoomserverAPI,
|
||||||
statistics *statistics.Statistics,
|
statistics *statistics.Statistics,
|
||||||
signing *SigningInfo,
|
signing map[gomatrixserverlib.ServerName]*SigningInfo,
|
||||||
) *OutgoingQueues {
|
) *OutgoingQueues {
|
||||||
queues := &OutgoingQueues{
|
queues := &OutgoingQueues{
|
||||||
disabled: disabled,
|
disabled: disabled,
|
||||||
|
@ -199,11 +199,10 @@ func (oqs *OutgoingQueues) SendEvent(
|
||||||
log.Trace("Federation is disabled, not sending event")
|
log.Trace("Federation is disabled, not sending event")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if origin != oqs.origin {
|
if _, ok := oqs.signing[origin]; !ok {
|
||||||
// TODO: Support virtual hosting; gh issue #577.
|
|
||||||
return fmt.Errorf(
|
return fmt.Errorf(
|
||||||
"sendevent: unexpected server to send as: got %q expected %q",
|
"sendevent: unexpected server to send as %q",
|
||||||
origin, oqs.origin,
|
origin,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -214,7 +213,9 @@ func (oqs *OutgoingQueues) SendEvent(
|
||||||
destmap[d] = struct{}{}
|
destmap[d] = struct{}{}
|
||||||
}
|
}
|
||||||
delete(destmap, oqs.origin)
|
delete(destmap, oqs.origin)
|
||||||
delete(destmap, oqs.signing.ServerName)
|
for local := range oqs.signing {
|
||||||
|
delete(destmap, local)
|
||||||
|
}
|
||||||
|
|
||||||
// Check if any of the destinations are prohibited by server ACLs.
|
// Check if any of the destinations are prohibited by server ACLs.
|
||||||
for destination := range destmap {
|
for destination := range destmap {
|
||||||
|
@ -288,11 +289,10 @@ func (oqs *OutgoingQueues) SendEDU(
|
||||||
log.Trace("Federation is disabled, not sending EDU")
|
log.Trace("Federation is disabled, not sending EDU")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if origin != oqs.origin {
|
if _, ok := oqs.signing[origin]; !ok {
|
||||||
// TODO: Support virtual hosting; gh issue #577.
|
|
||||||
return fmt.Errorf(
|
return fmt.Errorf(
|
||||||
"sendevent: unexpected server to send as: got %q expected %q",
|
"sendevent: unexpected server to send as %q",
|
||||||
origin, oqs.origin,
|
origin,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -303,7 +303,9 @@ func (oqs *OutgoingQueues) SendEDU(
|
||||||
destmap[d] = struct{}{}
|
destmap[d] = struct{}{}
|
||||||
}
|
}
|
||||||
delete(destmap, oqs.origin)
|
delete(destmap, oqs.origin)
|
||||||
delete(destmap, oqs.signing.ServerName)
|
for local := range oqs.signing {
|
||||||
|
delete(destmap, local)
|
||||||
|
}
|
||||||
|
|
||||||
// There is absolutely no guarantee that the EDU will have a room_id
|
// There is absolutely no guarantee that the EDU will have a room_id
|
||||||
// field, as it is not required by the spec. However, if it *does*
|
// field, as it is not required by the spec. However, if it *does*
|
||||||
|
|
|
@ -350,10 +350,12 @@ func testSetup(failuresUntilBlacklist uint32, shouldTxSucceed bool, t *testing.T
|
||||||
}
|
}
|
||||||
rs := &stubFederationRoomServerAPI{}
|
rs := &stubFederationRoomServerAPI{}
|
||||||
stats := statistics.NewStatistics(db, failuresUntilBlacklist)
|
stats := statistics.NewStatistics(db, failuresUntilBlacklist)
|
||||||
signingInfo := &SigningInfo{
|
signingInfo := map[gomatrixserverlib.ServerName]*SigningInfo{
|
||||||
KeyID: "ed21019:auto",
|
"localhost": {
|
||||||
PrivateKey: test.PrivateKeyA,
|
KeyID: "ed21019:auto",
|
||||||
ServerName: "localhost",
|
PrivateKey: test.PrivateKeyA,
|
||||||
|
ServerName: "localhost",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
queues := NewOutgoingQueues(db, processContext, false, "localhost", fc, rs, &stats, signingInfo)
|
queues := NewOutgoingQueues(db, processContext, false, "localhost", fc, rs, &stats, signingInfo)
|
||||||
|
|
||||||
|
|
|
@ -16,6 +16,7 @@ package routing
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -134,18 +135,21 @@ func ClaimOneTimeKeys(
|
||||||
|
|
||||||
// LocalKeys returns the local keys for the server.
|
// LocalKeys returns the local keys for the server.
|
||||||
// See https://matrix.org/docs/spec/server_server/unstable.html#publishing-keys
|
// See https://matrix.org/docs/spec/server_server/unstable.html#publishing-keys
|
||||||
func LocalKeys(cfg *config.FederationAPI) util.JSONResponse {
|
func LocalKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerName) util.JSONResponse {
|
||||||
keys, err := localKeys(cfg, time.Now().Add(cfg.Matrix.KeyValidityPeriod))
|
keys, err := localKeys(cfg, serverName, time.Now().Add(cfg.Matrix.KeyValidityPeriod))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
}
|
}
|
||||||
return util.JSONResponse{Code: http.StatusOK, JSON: keys}
|
return util.JSONResponse{Code: http.StatusOK, JSON: keys}
|
||||||
}
|
}
|
||||||
|
|
||||||
func localKeys(cfg *config.FederationAPI, validUntil time.Time) (*gomatrixserverlib.ServerKeys, error) {
|
func localKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerName, validUntil time.Time) (*gomatrixserverlib.ServerKeys, error) {
|
||||||
var keys gomatrixserverlib.ServerKeys
|
var keys gomatrixserverlib.ServerKeys
|
||||||
|
if !cfg.Matrix.IsLocalServerName(serverName) {
|
||||||
|
return nil, fmt.Errorf("server name not known")
|
||||||
|
}
|
||||||
|
|
||||||
keys.ServerName = cfg.Matrix.ServerName
|
keys.ServerName = serverName
|
||||||
keys.ValidUntilTS = gomatrixserverlib.AsTimestamp(validUntil)
|
keys.ValidUntilTS = gomatrixserverlib.AsTimestamp(validUntil)
|
||||||
|
|
||||||
publicKey := cfg.Matrix.PrivateKey.Public().(ed25519.PublicKey)
|
publicKey := cfg.Matrix.PrivateKey.Public().(ed25519.PublicKey)
|
||||||
|
@ -172,7 +176,7 @@ func localKeys(cfg *config.FederationAPI, validUntil time.Time) (*gomatrixserver
|
||||||
}
|
}
|
||||||
|
|
||||||
keys.Raw, err = gomatrixserverlib.SignJSON(
|
keys.Raw, err = gomatrixserverlib.SignJSON(
|
||||||
string(cfg.Matrix.ServerName), cfg.Matrix.KeyID, cfg.Matrix.PrivateKey, toSign,
|
string(serverName), cfg.Matrix.KeyID, cfg.Matrix.PrivateKey, toSign,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -186,6 +190,14 @@ func NotaryKeys(
|
||||||
fsAPI federationAPI.FederationInternalAPI,
|
fsAPI federationAPI.FederationInternalAPI,
|
||||||
req *gomatrixserverlib.PublicKeyNotaryLookupRequest,
|
req *gomatrixserverlib.PublicKeyNotaryLookupRequest,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
|
serverName := gomatrixserverlib.ServerName(httpReq.Host) // TODO: this is not ideal
|
||||||
|
if !cfg.Matrix.IsLocalServerName(serverName) {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusNotFound,
|
||||||
|
JSON: jsonerror.NotFound("Server name not known"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if req == nil {
|
if req == nil {
|
||||||
req = &gomatrixserverlib.PublicKeyNotaryLookupRequest{}
|
req = &gomatrixserverlib.PublicKeyNotaryLookupRequest{}
|
||||||
if reqErr := clienthttputil.UnmarshalJSONRequest(httpReq, &req); reqErr != nil {
|
if reqErr := clienthttputil.UnmarshalJSONRequest(httpReq, &req); reqErr != nil {
|
||||||
|
@ -201,7 +213,7 @@ func NotaryKeys(
|
||||||
for serverName, kidToCriteria := range req.ServerKeys {
|
for serverName, kidToCriteria := range req.ServerKeys {
|
||||||
var keyList []gomatrixserverlib.ServerKeys
|
var keyList []gomatrixserverlib.ServerKeys
|
||||||
if serverName == cfg.Matrix.ServerName {
|
if serverName == cfg.Matrix.ServerName {
|
||||||
if k, err := localKeys(cfg, time.Now().Add(cfg.Matrix.KeyValidityPeriod)); err == nil {
|
if k, err := localKeys(cfg, serverName, time.Now().Add(cfg.Matrix.KeyValidityPeriod)); err == nil {
|
||||||
keyList = append(keyList, *k)
|
keyList = append(keyList, *k)
|
||||||
} else {
|
} else {
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
|
|
|
@ -74,7 +74,7 @@ func Setup(
|
||||||
}
|
}
|
||||||
|
|
||||||
localKeys := httputil.MakeExternalAPI("localkeys", func(req *http.Request) util.JSONResponse {
|
localKeys := httputil.MakeExternalAPI("localkeys", func(req *http.Request) util.JSONResponse {
|
||||||
return LocalKeys(cfg)
|
return LocalKeys(cfg, gomatrixserverlib.ServerName(req.Host))
|
||||||
})
|
})
|
||||||
|
|
||||||
notaryKeys := httputil.MakeExternalAPI("notarykeys", func(req *http.Request) util.JSONResponse {
|
notaryKeys := httputil.MakeExternalAPI("notarykeys", func(req *http.Request) util.JSONResponse {
|
||||||
|
|
|
@ -33,16 +33,17 @@ import (
|
||||||
"github.com/matrix-org/dendrite/keyserver/api"
|
"github.com/matrix-org/dendrite/keyserver/api"
|
||||||
"github.com/matrix-org/dendrite/keyserver/producers"
|
"github.com/matrix-org/dendrite/keyserver/producers"
|
||||||
"github.com/matrix-org/dendrite/keyserver/storage"
|
"github.com/matrix-org/dendrite/keyserver/storage"
|
||||||
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
type KeyInternalAPI struct {
|
type KeyInternalAPI struct {
|
||||||
DB storage.Database
|
DB storage.Database
|
||||||
ThisServer gomatrixserverlib.ServerName
|
Cfg *config.KeyServer
|
||||||
FedClient fedsenderapi.KeyserverFederationAPI
|
FedClient fedsenderapi.KeyserverFederationAPI
|
||||||
UserAPI userapi.KeyserverUserAPI
|
UserAPI userapi.KeyserverUserAPI
|
||||||
Producer *producers.KeyChange
|
Producer *producers.KeyChange
|
||||||
Updater *DeviceListUpdater
|
Updater *DeviceListUpdater
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *KeyInternalAPI) SetUserAPI(i userapi.KeyserverUserAPI) {
|
func (a *KeyInternalAPI) SetUserAPI(i userapi.KeyserverUserAPI) {
|
||||||
|
@ -95,8 +96,11 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC
|
||||||
nested[userID] = val
|
nested[userID] = val
|
||||||
domainToDeviceKeys[string(serverName)] = nested
|
domainToDeviceKeys[string(serverName)] = nested
|
||||||
}
|
}
|
||||||
// claim local keys
|
for domain, local := range domainToDeviceKeys {
|
||||||
if local, ok := domainToDeviceKeys[string(a.ThisServer)]; ok {
|
if !a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// claim local keys
|
||||||
keys, err := a.DB.ClaimKeys(ctx, local)
|
keys, err := a.DB.ClaimKeys(ctx, local)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
res.Error = &api.KeyError{
|
res.Error = &api.KeyError{
|
||||||
|
@ -117,7 +121,7 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC
|
||||||
res.OneTimeKeys[key.UserID][key.DeviceID][keyID] = keyJSON
|
res.OneTimeKeys[key.UserID][key.DeviceID][keyID] = keyJSON
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
delete(domainToDeviceKeys, string(a.ThisServer))
|
delete(domainToDeviceKeys, domain)
|
||||||
}
|
}
|
||||||
if len(domainToDeviceKeys) > 0 {
|
if len(domainToDeviceKeys) > 0 {
|
||||||
a.claimRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys)
|
a.claimRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys)
|
||||||
|
@ -258,7 +262,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
|
||||||
}
|
}
|
||||||
domain := string(serverName)
|
domain := string(serverName)
|
||||||
// query local devices
|
// query local devices
|
||||||
if serverName == a.ThisServer {
|
if a.Cfg.Matrix.IsLocalServerName(serverName) {
|
||||||
deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false)
|
deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
res.Error = &api.KeyError{
|
res.Error = &api.KeyError{
|
||||||
|
@ -437,13 +441,13 @@ func (a *KeyInternalAPI) queryRemoteKeys(
|
||||||
|
|
||||||
domains := map[string]struct{}{}
|
domains := map[string]struct{}{}
|
||||||
for domain := range domainToDeviceKeys {
|
for domain := range domainToDeviceKeys {
|
||||||
if domain == string(a.ThisServer) {
|
if a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
domains[domain] = struct{}{}
|
domains[domain] = struct{}{}
|
||||||
}
|
}
|
||||||
for domain := range domainToCrossSigningKeys {
|
for domain := range domainToCrossSigningKeys {
|
||||||
if domain == string(a.ThisServer) {
|
if a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
domains[domain] = struct{}{}
|
domains[domain] = struct{}{}
|
||||||
|
@ -689,7 +693,7 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue // ignore invalid users
|
continue // ignore invalid users
|
||||||
}
|
}
|
||||||
if serverName != a.ThisServer {
|
if !a.Cfg.Matrix.IsLocalServerName(serverName) {
|
||||||
continue // ignore remote users
|
continue // ignore remote users
|
||||||
}
|
}
|
||||||
if len(key.KeyJSON) == 0 {
|
if len(key.KeyJSON) == 0 {
|
||||||
|
|
|
@ -53,10 +53,10 @@ func NewInternalAPI(
|
||||||
DB: db,
|
DB: db,
|
||||||
}
|
}
|
||||||
ap := &internal.KeyInternalAPI{
|
ap := &internal.KeyInternalAPI{
|
||||||
DB: db,
|
DB: db,
|
||||||
ThisServer: cfg.Matrix.ServerName,
|
Cfg: cfg,
|
||||||
FedClient: fedClient,
|
FedClient: fedClient,
|
||||||
Producer: keyChangeProducer,
|
Producer: keyChangeProducer,
|
||||||
}
|
}
|
||||||
updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8) // 8 workers TODO: configurable
|
updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8) // 8 workers TODO: configurable
|
||||||
ap.Updater = updater
|
ap.Updater = updater
|
||||||
|
|
|
@ -78,7 +78,7 @@ type ClientUserAPI interface {
|
||||||
QueryAcccessTokenAPI
|
QueryAcccessTokenAPI
|
||||||
LoginTokenInternalAPI
|
LoginTokenInternalAPI
|
||||||
UserLoginAPI
|
UserLoginAPI
|
||||||
QueryNumericLocalpart(ctx context.Context, res *QueryNumericLocalpartResponse) error
|
QueryNumericLocalpart(ctx context.Context, req *QueryNumericLocalpartRequest, res *QueryNumericLocalpartResponse) error
|
||||||
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
|
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
|
||||||
QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error
|
QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error
|
||||||
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
|
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
|
||||||
|
@ -335,9 +335,10 @@ type PerformAccountCreationResponse struct {
|
||||||
|
|
||||||
// PerformAccountCreationRequest is the request for PerformAccountCreation
|
// PerformAccountCreationRequest is the request for PerformAccountCreation
|
||||||
type PerformPasswordUpdateRequest struct {
|
type PerformPasswordUpdateRequest struct {
|
||||||
Localpart string // Required: The localpart for this account.
|
Localpart string // Required: The localpart for this account.
|
||||||
Password string // Required: The new password to set.
|
ServerName gomatrixserverlib.ServerName // Required: The domain for this account.
|
||||||
LogoutDevices bool // Optional: Whether to log out all user devices.
|
Password string // Required: The new password to set.
|
||||||
|
LogoutDevices bool // Optional: Whether to log out all user devices.
|
||||||
}
|
}
|
||||||
|
|
||||||
// PerformAccountCreationResponse is the response for PerformAccountCreation
|
// PerformAccountCreationResponse is the response for PerformAccountCreation
|
||||||
|
@ -518,7 +519,8 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
type QueryPushersRequest struct {
|
type QueryPushersRequest struct {
|
||||||
Localpart string
|
Localpart string
|
||||||
|
ServerName gomatrixserverlib.ServerName
|
||||||
}
|
}
|
||||||
|
|
||||||
type QueryPushersResponse struct {
|
type QueryPushersResponse struct {
|
||||||
|
@ -526,14 +528,16 @@ type QueryPushersResponse struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type PerformPusherSetRequest struct {
|
type PerformPusherSetRequest struct {
|
||||||
Pusher // Anonymous field because that's how clientapi unmarshals it.
|
Pusher // Anonymous field because that's how clientapi unmarshals it.
|
||||||
Localpart string
|
Localpart string
|
||||||
Append bool `json:"append"`
|
ServerName gomatrixserverlib.ServerName
|
||||||
|
Append bool `json:"append"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type PerformPusherDeletionRequest struct {
|
type PerformPusherDeletionRequest struct {
|
||||||
Localpart string
|
Localpart string
|
||||||
SessionID int64
|
ServerName gomatrixserverlib.ServerName
|
||||||
|
SessionID int64
|
||||||
}
|
}
|
||||||
|
|
||||||
// Pusher represents a push notification subscriber
|
// Pusher represents a push notification subscriber
|
||||||
|
@ -571,10 +575,11 @@ type QueryPushRulesResponse struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type QueryNotificationsRequest struct {
|
type QueryNotificationsRequest struct {
|
||||||
Localpart string `json:"localpart"` // Required.
|
Localpart string `json:"localpart"` // Required.
|
||||||
From string `json:"from,omitempty"`
|
ServerName gomatrixserverlib.ServerName `json:"server_name"` // Required.
|
||||||
Limit int `json:"limit,omitempty"`
|
From string `json:"from,omitempty"`
|
||||||
Only string `json:"only,omitempty"`
|
Limit int `json:"limit,omitempty"`
|
||||||
|
Only string `json:"only,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type QueryNotificationsResponse struct {
|
type QueryNotificationsResponse struct {
|
||||||
|
@ -601,12 +606,17 @@ type PerformSetAvatarURLResponse struct {
|
||||||
Changed bool `json:"changed"`
|
Changed bool `json:"changed"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type QueryNumericLocalpartRequest struct {
|
||||||
|
ServerName gomatrixserverlib.ServerName
|
||||||
|
}
|
||||||
|
|
||||||
type QueryNumericLocalpartResponse struct {
|
type QueryNumericLocalpartResponse struct {
|
||||||
ID int64
|
ID int64
|
||||||
}
|
}
|
||||||
|
|
||||||
type QueryAccountAvailabilityRequest struct {
|
type QueryAccountAvailabilityRequest struct {
|
||||||
Localpart string
|
Localpart string
|
||||||
|
ServerName gomatrixserverlib.ServerName
|
||||||
}
|
}
|
||||||
|
|
||||||
type QueryAccountAvailabilityResponse struct {
|
type QueryAccountAvailabilityResponse struct {
|
||||||
|
@ -614,7 +624,9 @@ type QueryAccountAvailabilityResponse struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type QueryAccountByPasswordRequest struct {
|
type QueryAccountByPasswordRequest struct {
|
||||||
Localpart, PlaintextPassword string
|
Localpart string
|
||||||
|
ServerName gomatrixserverlib.ServerName
|
||||||
|
PlaintextPassword string
|
||||||
}
|
}
|
||||||
|
|
||||||
type QueryAccountByPasswordResponse struct {
|
type QueryAccountByPasswordResponse struct {
|
||||||
|
@ -638,11 +650,13 @@ type QueryLocalpartForThreePIDRequest struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type QueryLocalpartForThreePIDResponse struct {
|
type QueryLocalpartForThreePIDResponse struct {
|
||||||
Localpart string
|
Localpart string
|
||||||
|
ServerName gomatrixserverlib.ServerName
|
||||||
}
|
}
|
||||||
|
|
||||||
type QueryThreePIDsForLocalpartRequest struct {
|
type QueryThreePIDsForLocalpartRequest struct {
|
||||||
Localpart string
|
Localpart string
|
||||||
|
ServerName gomatrixserverlib.ServerName
|
||||||
}
|
}
|
||||||
|
|
||||||
type QueryThreePIDsForLocalpartResponse struct {
|
type QueryThreePIDsForLocalpartResponse struct {
|
||||||
|
@ -652,5 +666,8 @@ type QueryThreePIDsForLocalpartResponse struct {
|
||||||
type PerformForgetThreePIDRequest QueryLocalpartForThreePIDRequest
|
type PerformForgetThreePIDRequest QueryLocalpartForThreePIDRequest
|
||||||
|
|
||||||
type PerformSaveThreePIDAssociationRequest struct {
|
type PerformSaveThreePIDAssociationRequest struct {
|
||||||
ThreePID, Localpart, Medium string
|
ThreePID string
|
||||||
|
Localpart string
|
||||||
|
ServerName gomatrixserverlib.ServerName
|
||||||
|
Medium string
|
||||||
}
|
}
|
||||||
|
|
|
@ -156,8 +156,8 @@ func (t *UserInternalAPITrace) SetAvatarURL(ctx context.Context, req *PerformSet
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *UserInternalAPITrace) QueryNumericLocalpart(ctx context.Context, res *QueryNumericLocalpartResponse) error {
|
func (t *UserInternalAPITrace) QueryNumericLocalpart(ctx context.Context, req *QueryNumericLocalpartRequest, res *QueryNumericLocalpartResponse) error {
|
||||||
err := t.Impl.QueryNumericLocalpart(ctx, res)
|
err := t.Impl.QueryNumericLocalpart(ctx, req, res)
|
||||||
util.GetLogger(ctx).Infof("QueryNumericLocalpart req= res=%+v", js(res))
|
util.GetLogger(ctx).Infof("QueryNumericLocalpart req= res=%+v", js(res))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -104,7 +104,7 @@ func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
updated, err := s.db.SetNotificationsRead(ctx, localpart, roomID, uint64(gomatrixserverlib.AsTimestamp(metadata.Timestamp)), true)
|
updated, err := s.db.SetNotificationsRead(ctx, localpart, domain, roomID, uint64(gomatrixserverlib.AsTimestamp(metadata.Timestamp)), true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).Error("userapi EDU consumer")
|
log.WithError(err).Error("userapi EDU consumer")
|
||||||
return false
|
return false
|
||||||
|
@ -118,7 +118,7 @@ func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats
|
||||||
if !updated {
|
if !updated {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
if err = util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, s.db); err != nil {
|
if err = util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, domain, s.db); err != nil {
|
||||||
log.WithError(err).Error("userapi EDU consumer: NotifyUserCounts failed")
|
log.WithError(err).Error("userapi EDU consumer: NotifyUserCounts failed")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
|
@ -192,25 +192,25 @@ func (s *OutputRoomEventConsumer) storeMessageStats(ctx context.Context, eventTy
|
||||||
func (s *OutputRoomEventConsumer) handleRoomUpgrade(ctx context.Context, oldRoomID, newRoomID string, localMembers []*localMembership, roomSize int) error {
|
func (s *OutputRoomEventConsumer) handleRoomUpgrade(ctx context.Context, oldRoomID, newRoomID string, localMembers []*localMembership, roomSize int) error {
|
||||||
for _, membership := range localMembers {
|
for _, membership := range localMembers {
|
||||||
// Copy any existing push rules from old -> new room
|
// Copy any existing push rules from old -> new room
|
||||||
if err := s.copyPushrules(ctx, oldRoomID, newRoomID, membership.Localpart); err != nil {
|
if err := s.copyPushrules(ctx, oldRoomID, newRoomID, membership.Localpart, membership.Domain); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// preserve m.direct room state
|
// preserve m.direct room state
|
||||||
if err := s.updateMDirect(ctx, oldRoomID, newRoomID, membership.Localpart, roomSize); err != nil {
|
if err := s.updateMDirect(ctx, oldRoomID, newRoomID, membership.Localpart, membership.Domain, roomSize); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// copy existing m.tag entries, if any
|
// copy existing m.tag entries, if any
|
||||||
if err := s.copyTags(ctx, oldRoomID, newRoomID, membership.Localpart); err != nil {
|
if err := s.copyTags(ctx, oldRoomID, newRoomID, membership.Localpart, membership.Domain); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID, newRoomID string, localpart string) error {
|
func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID, newRoomID string, localpart string, serverName gomatrixserverlib.ServerName) error {
|
||||||
pushRules, err := s.db.QueryPushRules(ctx, localpart)
|
pushRules, err := s.db.QueryPushRules(ctx, localpart, serverName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to query pushrules for user: %w", err)
|
return fmt.Errorf("failed to query pushrules for user: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -229,7 +229,7 @@ func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID,
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err = s.db.SaveAccountData(ctx, localpart, "", "m.push_rules", rules); err != nil {
|
if err = s.db.SaveAccountData(ctx, localpart, serverName, "", "m.push_rules", rules); err != nil {
|
||||||
return fmt.Errorf("failed to update pushrules: %w", err)
|
return fmt.Errorf("failed to update pushrules: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -237,13 +237,13 @@ func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID,
|
||||||
}
|
}
|
||||||
|
|
||||||
// updateMDirect copies the "is_direct" flag from oldRoomID to newROomID
|
// updateMDirect copies the "is_direct" flag from oldRoomID to newROomID
|
||||||
func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID, newRoomID, localpart string, roomSize int) error {
|
func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID, newRoomID, localpart string, serverName gomatrixserverlib.ServerName, roomSize int) error {
|
||||||
// this is most likely not a DM, so skip updating m.direct state
|
// this is most likely not a DM, so skip updating m.direct state
|
||||||
if roomSize > 2 {
|
if roomSize > 2 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
// Get direct message state
|
// Get direct message state
|
||||||
directChatsRaw, err := s.db.GetAccountDataByType(ctx, localpart, "", "m.direct")
|
directChatsRaw, err := s.db.GetAccountDataByType(ctx, localpart, serverName, "", "m.direct")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get m.direct from database: %w", err)
|
return fmt.Errorf("failed to get m.direct from database: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -267,7 +267,7 @@ func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID,
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
if err = s.db.SaveAccountData(ctx, localpart, "", "m.direct", data); err != nil {
|
if err = s.db.SaveAccountData(ctx, localpart, serverName, "", "m.direct", data); err != nil {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -279,15 +279,15 @@ func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID,
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OutputRoomEventConsumer) copyTags(ctx context.Context, oldRoomID, newRoomID, localpart string) error {
|
func (s *OutputRoomEventConsumer) copyTags(ctx context.Context, oldRoomID, newRoomID, localpart string, serverName gomatrixserverlib.ServerName) error {
|
||||||
tag, err := s.db.GetAccountDataByType(ctx, localpart, oldRoomID, "m.tag")
|
tag, err := s.db.GetAccountDataByType(ctx, localpart, serverName, oldRoomID, "m.tag")
|
||||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if tag == nil {
|
if tag == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return s.db.SaveAccountData(ctx, localpart, newRoomID, "m.tag", tag)
|
return s.db.SaveAccountData(ctx, localpart, serverName, newRoomID, "m.tag", tag)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, streamPos uint64) error {
|
func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, streamPos uint64) error {
|
||||||
|
@ -492,11 +492,11 @@ func unmarshalCanonicalAlias(event *gomatrixserverlib.HeaderedEvent) (string, er
|
||||||
func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, mem *localMembership, roomSize int, roomName string, streamPos uint64) error {
|
func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, mem *localMembership, roomSize int, roomName string, streamPos uint64) error {
|
||||||
actions, err := s.evaluatePushRules(ctx, event, mem, roomSize)
|
actions, err := s.evaluatePushRules(ctx, event, mem, roomSize)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("s.evaluatePushRules: %w", err)
|
||||||
}
|
}
|
||||||
a, tweaks, err := pushrules.ActionsToTweaks(actions)
|
a, tweaks, err := pushrules.ActionsToTweaks(actions)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("pushrules.ActionsToTweaks: %w", err)
|
||||||
}
|
}
|
||||||
// TODO: support coalescing.
|
// TODO: support coalescing.
|
||||||
if a != pushrules.NotifyAction && a != pushrules.CoalesceAction {
|
if a != pushrules.NotifyAction && a != pushrules.CoalesceAction {
|
||||||
|
@ -508,9 +508,9 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatr
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
devicesByURLAndFormat, profileTag, err := s.localPushDevices(ctx, mem.Localpart, tweaks)
|
devicesByURLAndFormat, profileTag, err := s.localPushDevices(ctx, mem.Localpart, mem.Domain, tweaks)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("s.localPushDevices: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
n := &api.Notification{
|
n := &api.Notification{
|
||||||
|
@ -527,18 +527,18 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatr
|
||||||
RoomID: event.RoomID(),
|
RoomID: event.RoomID(),
|
||||||
TS: gomatrixserverlib.AsTimestamp(time.Now()),
|
TS: gomatrixserverlib.AsTimestamp(time.Now()),
|
||||||
}
|
}
|
||||||
if err = s.db.InsertNotification(ctx, mem.Localpart, event.EventID(), streamPos, tweaks, n); err != nil {
|
if err = s.db.InsertNotification(ctx, mem.Localpart, mem.Domain, event.EventID(), streamPos, tweaks, n); err != nil {
|
||||||
return err
|
return fmt.Errorf("s.db.InsertNotification: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = s.syncProducer.GetAndSendNotificationData(ctx, mem.UserID, event.RoomID()); err != nil {
|
if err = s.syncProducer.GetAndSendNotificationData(ctx, mem.UserID, event.RoomID()); err != nil {
|
||||||
return err
|
return fmt.Errorf("s.syncProducer.GetAndSendNotificationData: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// We do this after InsertNotification. Thus, this should always return >=1.
|
// We do this after InsertNotification. Thus, this should always return >=1.
|
||||||
userNumUnreadNotifs, err := s.db.GetNotificationCount(ctx, mem.Localpart, tables.AllNotifications)
|
userNumUnreadNotifs, err := s.db.GetNotificationCount(ctx, mem.Localpart, mem.Domain, tables.AllNotifications)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("s.db.GetNotificationCount: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.WithFields(log.Fields{
|
log.WithFields(log.Fields{
|
||||||
|
@ -589,7 +589,7 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatr
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(rejected) > 0 {
|
if len(rejected) > 0 {
|
||||||
s.deleteRejectedPushers(ctx, rejected, mem.Localpart)
|
s.deleteRejectedPushers(ctx, rejected, mem.Localpart, mem.Domain)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
@ -606,7 +606,7 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get accountdata to check if the event.Sender() is ignored by mem.LocalPart
|
// Get accountdata to check if the event.Sender() is ignored by mem.LocalPart
|
||||||
data, err := s.db.GetAccountDataByType(ctx, mem.Localpart, "", "m.ignored_user_list")
|
data, err := s.db.GetAccountDataByType(ctx, mem.Localpart, mem.Domain, "", "m.ignored_user_list")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -621,7 +621,7 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *
|
||||||
return nil, fmt.Errorf("user %s is ignored", sender)
|
return nil, fmt.Errorf("user %s is ignored", sender)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ruleSets, err := s.db.QueryPushRules(ctx, mem.Localpart)
|
ruleSets, err := s.db.QueryPushRules(ctx, mem.Localpart, mem.Domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -693,10 +693,10 @@ func (rse *ruleSetEvalContext) HasPowerLevel(userID, levelKey string) (bool, err
|
||||||
|
|
||||||
// localPushDevices pushes to the configured devices of a local
|
// localPushDevices pushes to the configured devices of a local
|
||||||
// user. The map keys are [url][format].
|
// user. The map keys are [url][format].
|
||||||
func (s *OutputRoomEventConsumer) localPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) {
|
func (s *OutputRoomEventConsumer) localPushDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) {
|
||||||
pusherDevices, err := util.GetPushDevices(ctx, localpart, tweaks, s.db)
|
pusherDevices, err := util.GetPushDevices(ctx, localpart, serverName, tweaks, s.db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", fmt.Errorf("util.GetPushDevices: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var profileTag string
|
var profileTag string
|
||||||
|
@ -791,7 +791,7 @@ func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *gomatri
|
||||||
}
|
}
|
||||||
|
|
||||||
// deleteRejectedPushers deletes the pushers associated with the given devices.
|
// deleteRejectedPushers deletes the pushers associated with the given devices.
|
||||||
func (s *OutputRoomEventConsumer) deleteRejectedPushers(ctx context.Context, devices []*pushgateway.Device, localpart string) {
|
func (s *OutputRoomEventConsumer) deleteRejectedPushers(ctx context.Context, devices []*pushgateway.Device, localpart string, serverName gomatrixserverlib.ServerName) {
|
||||||
log.WithFields(log.Fields{
|
log.WithFields(log.Fields{
|
||||||
"localpart": localpart,
|
"localpart": localpart,
|
||||||
"app_id0": devices[0].AppID,
|
"app_id0": devices[0].AppID,
|
||||||
|
@ -799,7 +799,7 @@ func (s *OutputRoomEventConsumer) deleteRejectedPushers(ctx context.Context, dev
|
||||||
}).Warnf("Deleting pushers rejected by the HTTP push gateway")
|
}).Warnf("Deleting pushers rejected by the HTTP push gateway")
|
||||||
|
|
||||||
for _, d := range devices {
|
for _, d := range devices {
|
||||||
if err := s.db.RemovePusher(ctx, d.AppID, d.PushKey, localpart); err != nil {
|
if err := s.db.RemovePusher(ctx, d.AppID, d.PushKey, localpart, serverName); err != nil {
|
||||||
log.WithFields(log.Fields{
|
log.WithFields(log.Fields{
|
||||||
"localpart": localpart,
|
"localpart": localpart,
|
||||||
}).WithError(err).Errorf("Unable to delete rejected pusher")
|
}).WithError(err).Errorf("Unable to delete rejected pusher")
|
||||||
|
|
|
@ -68,7 +68,7 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc
|
||||||
if req.DataType == "" {
|
if req.DataType == "" {
|
||||||
return fmt.Errorf("data type must not be empty")
|
return fmt.Errorf("data type must not be empty")
|
||||||
}
|
}
|
||||||
if err := a.DB.SaveAccountData(ctx, local, req.RoomID, req.DataType, req.AccountData); err != nil {
|
if err := a.DB.SaveAccountData(ctx, local, domain, req.RoomID, req.DataType, req.AccountData); err != nil {
|
||||||
util.GetLogger(ctx).WithError(err).Error("a.DB.SaveAccountData failed")
|
util.GetLogger(ctx).WithError(err).Error("a.DB.SaveAccountData failed")
|
||||||
return fmt.Errorf("failed to save account data: %w", err)
|
return fmt.Errorf("failed to save account data: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -108,7 +108,7 @@ func (a *UserInternalAPI) setFullyRead(ctx context.Context, req *api.InputAccoun
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
deleted, err := a.DB.DeleteNotificationsUpTo(ctx, localpart, req.RoomID, uint64(gomatrixserverlib.AsTimestamp(time.Now())))
|
deleted, err := a.DB.DeleteNotificationsUpTo(ctx, localpart, domain, req.RoomID, uint64(gomatrixserverlib.AsTimestamp(time.Now())))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).Errorf("UserInternalAPI.setFullyRead: DeleteNotificationsUpTo failed")
|
logrus.WithError(err).Errorf("UserInternalAPI.setFullyRead: DeleteNotificationsUpTo failed")
|
||||||
return err
|
return err
|
||||||
|
@ -124,7 +124,7 @@ func (a *UserInternalAPI) setFullyRead(ctx context.Context, req *api.InputAccoun
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = userapiUtil.NotifyUserCountsAsync(ctx, a.PgClient, localpart, a.DB); err != nil {
|
if err = userapiUtil.NotifyUserCountsAsync(ctx, a.PgClient, localpart, domain, a.DB); err != nil {
|
||||||
logrus.WithError(err).Error("UserInternalAPI.setFullyRead: NotifyUserCounts failed")
|
logrus.WithError(err).Error("UserInternalAPI.setFullyRead: NotifyUserCounts failed")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -175,8 +175,10 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
|
||||||
if serverName == "" {
|
if serverName == "" {
|
||||||
serverName = a.Config.Matrix.ServerName
|
serverName = a.Config.Matrix.ServerName
|
||||||
}
|
}
|
||||||
// XXXX: Use the server name here
|
if !a.Config.Matrix.IsLocalServerName(serverName) {
|
||||||
acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType)
|
return fmt.Errorf("server name %s is not local", serverName)
|
||||||
|
}
|
||||||
|
acc, err := a.DB.CreateAccount(ctx, req.Localpart, serverName, req.Password, req.AppServiceID, req.AccountType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists
|
if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists
|
||||||
switch req.OnConflict {
|
switch req.OnConflict {
|
||||||
|
@ -215,8 +217,8 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, _, err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil {
|
if _, _, err = a.DB.SetDisplayName(ctx, req.Localpart, serverName, req.Localpart); err != nil {
|
||||||
return err
|
return fmt.Errorf("a.DB.SetDisplayName: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
postRegisterJoinRooms(a.Cfg, acc, a.RSAPI)
|
postRegisterJoinRooms(a.Cfg, acc, a.RSAPI)
|
||||||
|
@ -227,11 +229,14 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error {
|
func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error {
|
||||||
if err := a.DB.SetPassword(ctx, req.Localpart, req.Password); err != nil {
|
if !a.Config.Matrix.IsLocalServerName(req.ServerName) {
|
||||||
|
return fmt.Errorf("server name %s is not local", req.ServerName)
|
||||||
|
}
|
||||||
|
if err := a.DB.SetPassword(ctx, req.Localpart, req.ServerName, req.Password); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if req.LogoutDevices {
|
if req.LogoutDevices {
|
||||||
if _, err := a.DB.RemoveAllDevices(context.Background(), req.Localpart, ""); err != nil {
|
if _, err := a.DB.RemoveAllDevices(context.Background(), req.Localpart, req.ServerName, ""); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -244,14 +249,15 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe
|
||||||
if serverName == "" {
|
if serverName == "" {
|
||||||
serverName = a.Config.Matrix.ServerName
|
serverName = a.Config.Matrix.ServerName
|
||||||
}
|
}
|
||||||
_ = serverName
|
if !a.Config.Matrix.IsLocalServerName(serverName) {
|
||||||
// XXXX: Use the server name here
|
return fmt.Errorf("server name %s is not local", serverName)
|
||||||
|
}
|
||||||
util.GetLogger(ctx).WithFields(logrus.Fields{
|
util.GetLogger(ctx).WithFields(logrus.Fields{
|
||||||
"localpart": req.Localpart,
|
"localpart": req.Localpart,
|
||||||
"device_id": req.DeviceID,
|
"device_id": req.DeviceID,
|
||||||
"display_name": req.DeviceDisplayName,
|
"display_name": req.DeviceDisplayName,
|
||||||
}).Info("PerformDeviceCreation")
|
}).Info("PerformDeviceCreation")
|
||||||
dev, err := a.DB.CreateDevice(ctx, req.Localpart, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent)
|
dev, err := a.DB.CreateDevice(ctx, req.Localpart, serverName, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -276,12 +282,12 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe
|
||||||
deletedDeviceIDs := req.DeviceIDs
|
deletedDeviceIDs := req.DeviceIDs
|
||||||
if len(req.DeviceIDs) == 0 {
|
if len(req.DeviceIDs) == 0 {
|
||||||
var devices []api.Device
|
var devices []api.Device
|
||||||
devices, err = a.DB.RemoveAllDevices(ctx, local, req.ExceptDeviceID)
|
devices, err = a.DB.RemoveAllDevices(ctx, local, domain, req.ExceptDeviceID)
|
||||||
for _, d := range devices {
|
for _, d := range devices {
|
||||||
deletedDeviceIDs = append(deletedDeviceIDs, d.ID)
|
deletedDeviceIDs = append(deletedDeviceIDs, d.ID)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
err = a.DB.RemoveDevices(ctx, local, req.DeviceIDs)
|
err = a.DB.RemoveDevices(ctx, local, domain, req.DeviceIDs)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -335,23 +341,29 @@ func (a *UserInternalAPI) PerformLastSeenUpdate(
|
||||||
req *api.PerformLastSeenUpdateRequest,
|
req *api.PerformLastSeenUpdateRequest,
|
||||||
res *api.PerformLastSeenUpdateResponse,
|
res *api.PerformLastSeenUpdateResponse,
|
||||||
) error {
|
) error {
|
||||||
localpart, _, err := gomatrixserverlib.SplitID('@', req.UserID)
|
localpart, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
|
return fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
|
||||||
}
|
}
|
||||||
if err := a.DB.UpdateDeviceLastSeen(ctx, localpart, req.DeviceID, req.RemoteAddr, req.UserAgent); err != nil {
|
if !a.Config.Matrix.IsLocalServerName(domain) {
|
||||||
|
return fmt.Errorf("server name %s is not local", domain)
|
||||||
|
}
|
||||||
|
if err := a.DB.UpdateDeviceLastSeen(ctx, localpart, domain, req.DeviceID, req.RemoteAddr, req.UserAgent); err != nil {
|
||||||
return fmt.Errorf("a.DeviceDB.UpdateDeviceLastSeen: %w", err)
|
return fmt.Errorf("a.DeviceDB.UpdateDeviceLastSeen: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.PerformDeviceUpdateRequest, res *api.PerformDeviceUpdateResponse) error {
|
func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.PerformDeviceUpdateRequest, res *api.PerformDeviceUpdateResponse) error {
|
||||||
localpart, _, err := gomatrixserverlib.SplitID('@', req.RequestingUserID)
|
localpart, domain, err := gomatrixserverlib.SplitID('@', req.RequestingUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed")
|
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
dev, err := a.DB.GetDeviceByID(ctx, localpart, req.DeviceID)
|
if !a.Config.Matrix.IsLocalServerName(domain) {
|
||||||
|
return fmt.Errorf("server name %s is not local", domain)
|
||||||
|
}
|
||||||
|
dev, err := a.DB.GetDeviceByID(ctx, localpart, domain, req.DeviceID)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
res.DeviceExists = false
|
res.DeviceExists = false
|
||||||
return nil
|
return nil
|
||||||
|
@ -366,7 +378,7 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
err = a.DB.UpdateDevice(ctx, localpart, req.DeviceID, req.DisplayName)
|
err = a.DB.UpdateDevice(ctx, localpart, domain, req.DeviceID, req.DisplayName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed")
|
util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed")
|
||||||
return err
|
return err
|
||||||
|
@ -406,7 +418,7 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil
|
||||||
if !a.Config.Matrix.IsLocalServerName(domain) {
|
if !a.Config.Matrix.IsLocalServerName(domain) {
|
||||||
return fmt.Errorf("cannot query profile of remote users (server name %s)", domain)
|
return fmt.Errorf("cannot query profile of remote users (server name %s)", domain)
|
||||||
}
|
}
|
||||||
prof, err := a.DB.GetProfileByLocalpart(ctx, local)
|
prof, err := a.DB.GetProfileByLocalpart(ctx, local, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return nil
|
return nil
|
||||||
|
@ -457,7 +469,7 @@ func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevice
|
||||||
if !a.Config.Matrix.IsLocalServerName(domain) {
|
if !a.Config.Matrix.IsLocalServerName(domain) {
|
||||||
return fmt.Errorf("cannot query devices of remote users (server name %s)", domain)
|
return fmt.Errorf("cannot query devices of remote users (server name %s)", domain)
|
||||||
}
|
}
|
||||||
devs, err := a.DB.GetDevicesByLocalpart(ctx, local)
|
devs, err := a.DB.GetDevicesByLocalpart(ctx, local, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -476,7 +488,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc
|
||||||
}
|
}
|
||||||
if req.DataType != "" {
|
if req.DataType != "" {
|
||||||
var data json.RawMessage
|
var data json.RawMessage
|
||||||
data, err = a.DB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType)
|
data, err = a.DB.GetAccountDataByType(ctx, local, domain, req.RoomID, req.DataType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -494,7 +506,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
global, rooms, err := a.DB.GetAccountData(ctx, local)
|
global, rooms, err := a.DB.GetAccountData(ctx, local, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -527,7 +539,7 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc
|
||||||
if !a.Config.Matrix.IsLocalServerName(domain) {
|
if !a.Config.Matrix.IsLocalServerName(domain) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
acc, err := a.DB.GetAccountByLocalpart(ctx, localPart)
|
acc, err := a.DB.GetAccountByLocalpart(ctx, localPart, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -561,14 +573,14 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe
|
||||||
AccountType: api.AccountTypeAppService,
|
AccountType: api.AccountTypeAppService,
|
||||||
}
|
}
|
||||||
|
|
||||||
localpart, _, err := userutil.ParseUsernameParam(appServiceUserID, a.Config.Matrix)
|
localpart, domain, err := userutil.ParseUsernameParam(appServiceUserID, a.Config.Matrix)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if localpart != "" { // AS is masquerading as another user
|
if localpart != "" { // AS is masquerading as another user
|
||||||
// Verify that the user is registered
|
// Verify that the user is registered
|
||||||
account, err := a.DB.GetAccountByLocalpart(ctx, localpart)
|
account, err := a.DB.GetAccountByLocalpart(ctx, localpart, domain)
|
||||||
// Verify that the account exists and either appServiceID matches or
|
// Verify that the account exists and either appServiceID matches or
|
||||||
// it belongs to the appservice user namespaces
|
// it belongs to the appservice user namespaces
|
||||||
if err == nil && (account.AppServiceID == appService.ID || appService.IsInterestedInUserID(appServiceUserID)) {
|
if err == nil && (account.AppServiceID == appService.ID || appService.IsInterestedInUserID(appServiceUserID)) {
|
||||||
|
@ -620,7 +632,7 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err := a.DB.DeactivateAccount(ctx, req.Localpart)
|
err := a.DB.DeactivateAccount(ctx, req.Localpart, serverName)
|
||||||
res.AccountDeactivated = err == nil
|
res.AccountDeactivated = err == nil
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -783,7 +795,7 @@ func (a *UserInternalAPI) QueryNotifications(ctx context.Context, req *api.Query
|
||||||
if req.Only == "highlight" {
|
if req.Only == "highlight" {
|
||||||
filter = tables.HighlightNotifications
|
filter = tables.HighlightNotifications
|
||||||
}
|
}
|
||||||
notifs, lastID, err := a.DB.GetNotifications(ctx, req.Localpart, fromID, req.Limit, filter)
|
notifs, lastID, err := a.DB.GetNotifications(ctx, req.Localpart, req.ServerName, fromID, req.Limit, filter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -811,23 +823,23 @@ func (a *UserInternalAPI) PerformPusherSet(ctx context.Context, req *api.Perform
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if req.Pusher.Kind == "" {
|
if req.Pusher.Kind == "" {
|
||||||
return a.DB.RemovePusher(ctx, req.Pusher.AppID, req.Pusher.PushKey, req.Localpart)
|
return a.DB.RemovePusher(ctx, req.Pusher.AppID, req.Pusher.PushKey, req.Localpart, req.ServerName)
|
||||||
}
|
}
|
||||||
if req.Pusher.PushKeyTS == 0 {
|
if req.Pusher.PushKeyTS == 0 {
|
||||||
req.Pusher.PushKeyTS = int64(time.Now().Unix())
|
req.Pusher.PushKeyTS = int64(time.Now().Unix())
|
||||||
}
|
}
|
||||||
return a.DB.UpsertPusher(ctx, req.Pusher, req.Localpart)
|
return a.DB.UpsertPusher(ctx, req.Pusher, req.Localpart, req.ServerName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *UserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.PerformPusherDeletionRequest, res *struct{}) error {
|
func (a *UserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.PerformPusherDeletionRequest, res *struct{}) error {
|
||||||
pushers, err := a.DB.GetPushers(ctx, req.Localpart)
|
pushers, err := a.DB.GetPushers(ctx, req.Localpart, req.ServerName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
for i := range pushers {
|
for i := range pushers {
|
||||||
logrus.Warnf("pusher session: %d, req session: %d", pushers[i].SessionID, req.SessionID)
|
logrus.Warnf("pusher session: %d, req session: %d", pushers[i].SessionID, req.SessionID)
|
||||||
if pushers[i].SessionID != req.SessionID {
|
if pushers[i].SessionID != req.SessionID {
|
||||||
err := a.DB.RemovePusher(ctx, pushers[i].AppID, pushers[i].PushKey, req.Localpart)
|
err := a.DB.RemovePusher(ctx, pushers[i].AppID, pushers[i].PushKey, req.Localpart, req.ServerName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -838,7 +850,7 @@ func (a *UserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.Pe
|
||||||
|
|
||||||
func (a *UserInternalAPI) QueryPushers(ctx context.Context, req *api.QueryPushersRequest, res *api.QueryPushersResponse) error {
|
func (a *UserInternalAPI) QueryPushers(ctx context.Context, req *api.QueryPushersRequest, res *api.QueryPushersResponse) error {
|
||||||
var err error
|
var err error
|
||||||
res.Pushers, err = a.DB.GetPushers(ctx, req.Localpart)
|
res.Pushers, err = a.DB.GetPushers(ctx, req.Localpart, req.ServerName)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -864,11 +876,11 @@ func (a *UserInternalAPI) PerformPushRulesPut(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPushRulesRequest, res *api.QueryPushRulesResponse) error {
|
func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPushRulesRequest, res *api.QueryPushRulesResponse) error {
|
||||||
localpart, _, err := gomatrixserverlib.SplitID('@', req.UserID)
|
localpart, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to split user ID %q for push rules", req.UserID)
|
return fmt.Errorf("failed to split user ID %q for push rules", req.UserID)
|
||||||
}
|
}
|
||||||
pushRules, err := a.DB.QueryPushRules(ctx, localpart)
|
pushRules, err := a.DB.QueryPushRules(ctx, localpart, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to query push rules: %w", err)
|
return fmt.Errorf("failed to query push rules: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -877,14 +889,14 @@ func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPush
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, req *api.PerformSetAvatarURLRequest, res *api.PerformSetAvatarURLResponse) error {
|
func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, req *api.PerformSetAvatarURLRequest, res *api.PerformSetAvatarURLResponse) error {
|
||||||
profile, changed, err := a.DB.SetAvatarURL(ctx, req.Localpart, req.AvatarURL)
|
profile, changed, err := a.DB.SetAvatarURL(ctx, req.Localpart, req.ServerName, req.AvatarURL)
|
||||||
res.Profile = profile
|
res.Profile = profile
|
||||||
res.Changed = changed
|
res.Changed = changed
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.QueryNumericLocalpartResponse) error {
|
func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, req *api.QueryNumericLocalpartRequest, res *api.QueryNumericLocalpartResponse) error {
|
||||||
id, err := a.DB.GetNewNumericLocalpart(ctx)
|
id, err := a.DB.GetNewNumericLocalpart(ctx, req.ServerName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -894,12 +906,12 @@ func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.Qu
|
||||||
|
|
||||||
func (a *UserInternalAPI) QueryAccountAvailability(ctx context.Context, req *api.QueryAccountAvailabilityRequest, res *api.QueryAccountAvailabilityResponse) error {
|
func (a *UserInternalAPI) QueryAccountAvailability(ctx context.Context, req *api.QueryAccountAvailabilityRequest, res *api.QueryAccountAvailabilityResponse) error {
|
||||||
var err error
|
var err error
|
||||||
res.Available, err = a.DB.CheckAccountAvailability(ctx, req.Localpart)
|
res.Available, err = a.DB.CheckAccountAvailability(ctx, req.Localpart, req.ServerName)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *UserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.QueryAccountByPasswordRequest, res *api.QueryAccountByPasswordResponse) error {
|
func (a *UserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.QueryAccountByPasswordRequest, res *api.QueryAccountByPasswordResponse) error {
|
||||||
acc, err := a.DB.GetAccountByPassword(ctx, req.Localpart, req.PlaintextPassword)
|
acc, err := a.DB.GetAccountByPassword(ctx, req.Localpart, req.ServerName, req.PlaintextPassword)
|
||||||
switch err {
|
switch err {
|
||||||
case sql.ErrNoRows: // user does not exist
|
case sql.ErrNoRows: // user does not exist
|
||||||
return nil
|
return nil
|
||||||
|
@ -915,23 +927,24 @@ func (a *UserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.Q
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *UserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, res *api.PerformUpdateDisplayNameResponse) error {
|
func (a *UserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, res *api.PerformUpdateDisplayNameResponse) error {
|
||||||
profile, changed, err := a.DB.SetDisplayName(ctx, req.Localpart, req.DisplayName)
|
profile, changed, err := a.DB.SetDisplayName(ctx, req.Localpart, req.ServerName, req.DisplayName)
|
||||||
res.Profile = profile
|
res.Profile = profile
|
||||||
res.Changed = changed
|
res.Changed = changed
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *UserInternalAPI) QueryLocalpartForThreePID(ctx context.Context, req *api.QueryLocalpartForThreePIDRequest, res *api.QueryLocalpartForThreePIDResponse) error {
|
func (a *UserInternalAPI) QueryLocalpartForThreePID(ctx context.Context, req *api.QueryLocalpartForThreePIDRequest, res *api.QueryLocalpartForThreePIDResponse) error {
|
||||||
localpart, err := a.DB.GetLocalpartForThreePID(ctx, req.ThreePID, req.Medium)
|
localpart, domain, err := a.DB.GetLocalpartForThreePID(ctx, req.ThreePID, req.Medium)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
res.Localpart = localpart
|
res.Localpart = localpart
|
||||||
|
res.ServerName = domain
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *UserInternalAPI) QueryThreePIDsForLocalpart(ctx context.Context, req *api.QueryThreePIDsForLocalpartRequest, res *api.QueryThreePIDsForLocalpartResponse) error {
|
func (a *UserInternalAPI) QueryThreePIDsForLocalpart(ctx context.Context, req *api.QueryThreePIDsForLocalpartRequest, res *api.QueryThreePIDsForLocalpartResponse) error {
|
||||||
r, err := a.DB.GetThreePIDsForLocalpart(ctx, req.Localpart)
|
r, err := a.DB.GetThreePIDsForLocalpart(ctx, req.Localpart, req.ServerName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -944,7 +957,7 @@ func (a *UserInternalAPI) PerformForgetThreePID(ctx context.Context, req *api.Pe
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *UserInternalAPI) PerformSaveThreePIDAssociation(ctx context.Context, req *api.PerformSaveThreePIDAssociationRequest, res *struct{}) error {
|
func (a *UserInternalAPI) PerformSaveThreePIDAssociation(ctx context.Context, req *api.PerformSaveThreePIDAssociationRequest, res *struct{}) error {
|
||||||
return a.DB.SaveThreePIDAssociation(ctx, req.ThreePID, req.Localpart, req.Medium)
|
return a.DB.SaveThreePIDAssociation(ctx, req.ThreePID, req.Localpart, req.ServerName, req.Medium)
|
||||||
}
|
}
|
||||||
|
|
||||||
const pushRulesAccountDataType = "m.push_rules"
|
const pushRulesAccountDataType = "m.push_rules"
|
||||||
|
|
|
@ -66,7 +66,7 @@ func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLog
|
||||||
if !a.Config.Matrix.IsLocalServerName(domain) {
|
if !a.Config.Matrix.IsLocalServerName(domain) {
|
||||||
return fmt.Errorf("cannot return a login token for a remote user (server name %s)", domain)
|
return fmt.Errorf("cannot return a login token for a remote user (server name %s)", domain)
|
||||||
}
|
}
|
||||||
if _, err := a.DB.GetAccountByLocalpart(ctx, localpart); err != nil {
|
if _, err := a.DB.GetAccountByLocalpart(ctx, localpart, domain); err != nil {
|
||||||
res.Data = nil
|
res.Data = nil
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -355,11 +355,12 @@ func (h *httpUserInternalAPI) SetAvatarURL(
|
||||||
|
|
||||||
func (h *httpUserInternalAPI) QueryNumericLocalpart(
|
func (h *httpUserInternalAPI) QueryNumericLocalpart(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
request *api.QueryNumericLocalpartRequest,
|
||||||
response *api.QueryNumericLocalpartResponse,
|
response *api.QueryNumericLocalpartResponse,
|
||||||
) error {
|
) error {
|
||||||
return httputil.CallInternalRPCAPI(
|
return httputil.CallInternalRPCAPI(
|
||||||
"QueryNumericLocalpart", h.apiURL+QueryNumericLocalpartPath,
|
"QueryNumericLocalpart", h.apiURL+QueryNumericLocalpartPath,
|
||||||
h.httpClient, ctx, &struct{}{}, response,
|
h.httpClient, ctx, request, response,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -15,12 +15,9 @@
|
||||||
package inthttp
|
package inthttp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
"github.com/matrix-org/dendrite/internal/httputil"
|
"github.com/matrix-org/dendrite/internal/httputil"
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/util"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// nolint: gocyclo
|
// nolint: gocyclo
|
||||||
|
@ -152,15 +149,9 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
|
||||||
httputil.MakeInternalRPCAPI("UserAPIPerformSetAvatarURL", s.SetAvatarURL),
|
httputil.MakeInternalRPCAPI("UserAPIPerformSetAvatarURL", s.SetAvatarURL),
|
||||||
)
|
)
|
||||||
|
|
||||||
// TODO: Look at the shape of this
|
internalAPIMux.Handle(
|
||||||
internalAPIMux.Handle(QueryNumericLocalpartPath,
|
QueryNumericLocalpartPath,
|
||||||
httputil.MakeInternalAPI("UserAPIQueryNumericLocalpart", func(req *http.Request) util.JSONResponse {
|
httputil.MakeInternalRPCAPI("UserAPIQueryNumericLocalpart", s.QueryNumericLocalpart),
|
||||||
response := api.QueryNumericLocalpartResponse{}
|
|
||||||
if err := s.QueryNumericLocalpart(req.Context(), &response); err != nil {
|
|
||||||
return util.ErrorResponse(err)
|
|
||||||
}
|
|
||||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
|
||||||
}),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
internalAPIMux.Handle(
|
internalAPIMux.Handle(
|
||||||
|
|
|
@ -61,12 +61,12 @@ func (p *SyncAPI) SendAccountData(userID string, data eventutil.AccountData) err
|
||||||
// GetAndSendNotificationData reads the database and sends data about unread
|
// GetAndSendNotificationData reads the database and sends data about unread
|
||||||
// notifications to the Sync API server.
|
// notifications to the Sync API server.
|
||||||
func (p *SyncAPI) GetAndSendNotificationData(ctx context.Context, userID, roomID string) error {
|
func (p *SyncAPI) GetAndSendNotificationData(ctx context.Context, userID, roomID string) error {
|
||||||
localpart, _, err := gomatrixserverlib.SplitID('@', userID)
|
localpart, domain, err := gomatrixserverlib.SplitID('@', userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
ntotal, nhighlight, err := p.db.GetRoomNotificationCounts(ctx, localpart, roomID)
|
ntotal, nhighlight, err := p.db.GetRoomNotificationCounts(ctx, localpart, domain, roomID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,40 +29,40 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type Profile interface {
|
type Profile interface {
|
||||||
GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
|
GetProfileByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*authtypes.Profile, error)
|
||||||
SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
|
SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
|
||||||
SetAvatarURL(ctx context.Context, localpart string, avatarURL string) (*authtypes.Profile, bool, error)
|
SetAvatarURL(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, avatarURL string) (*authtypes.Profile, bool, error)
|
||||||
SetDisplayName(ctx context.Context, localpart string, displayName string) (*authtypes.Profile, bool, error)
|
SetDisplayName(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, displayName string) (*authtypes.Profile, bool, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Account interface {
|
type Account interface {
|
||||||
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
|
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
|
||||||
// for this account. If no password is supplied, the account will be a passwordless account. If the
|
// for this account. If no password is supplied, the account will be a passwordless account. If the
|
||||||
// account already exists, it will return nil, ErrUserExists.
|
// account already exists, it will return nil, ErrUserExists.
|
||||||
CreateAccount(ctx context.Context, localpart string, plaintextPassword string, appserviceID string, accountType api.AccountType) (*api.Account, error)
|
CreateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, plaintextPassword string, appserviceID string, accountType api.AccountType) (*api.Account, error)
|
||||||
GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error)
|
GetAccountByPassword(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, plaintextPassword string) (*api.Account, error)
|
||||||
GetNewNumericLocalpart(ctx context.Context) (int64, error)
|
GetNewNumericLocalpart(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error)
|
||||||
CheckAccountAvailability(ctx context.Context, localpart string) (bool, error)
|
CheckAccountAvailability(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (bool, error)
|
||||||
GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
|
GetAccountByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*api.Account, error)
|
||||||
DeactivateAccount(ctx context.Context, localpart string) (err error)
|
DeactivateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (err error)
|
||||||
SetPassword(ctx context.Context, localpart string, plaintextPassword string) error
|
SetPassword(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, plaintextPassword string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type AccountData interface {
|
type AccountData interface {
|
||||||
SaveAccountData(ctx context.Context, localpart, roomID, dataType string, content json.RawMessage) error
|
SaveAccountData(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID, dataType string, content json.RawMessage) error
|
||||||
GetAccountData(ctx context.Context, localpart string) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error)
|
GetAccountData(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error)
|
||||||
// GetAccountDataByType returns account data matching a given
|
// GetAccountDataByType returns account data matching a given
|
||||||
// localpart, room ID and type.
|
// localpart, room ID and type.
|
||||||
// If no account data could be found, returns nil
|
// If no account data could be found, returns nil
|
||||||
// Returns an error if there was an issue with the retrieval
|
// Returns an error if there was an issue with the retrieval
|
||||||
GetAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data json.RawMessage, err error)
|
GetAccountDataByType(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID, dataType string) (data json.RawMessage, err error)
|
||||||
QueryPushRules(ctx context.Context, localpart string) (*pushrules.AccountRuleSets, error)
|
QueryPushRules(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*pushrules.AccountRuleSets, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Device interface {
|
type Device interface {
|
||||||
GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error)
|
GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error)
|
||||||
GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error)
|
GetDeviceByID(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID string) (*api.Device, error)
|
||||||
GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error)
|
GetDevicesByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) ([]api.Device, error)
|
||||||
GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error)
|
GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error)
|
||||||
// CreateDevice makes a new device associated with the given user ID localpart.
|
// CreateDevice makes a new device associated with the given user ID localpart.
|
||||||
// If there is already a device with the same device ID for this user, that access token will be revoked
|
// If there is already a device with the same device ID for this user, that access token will be revoked
|
||||||
|
@ -70,12 +70,12 @@ type Device interface {
|
||||||
// an error will be returned.
|
// an error will be returned.
|
||||||
// If no device ID is given one is generated.
|
// If no device ID is given one is generated.
|
||||||
// Returns the device on success.
|
// Returns the device on success.
|
||||||
CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error)
|
CreateDevice(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error)
|
||||||
UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error
|
UpdateDevice(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID string, displayName *string) error
|
||||||
UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr, userAgent string) error
|
UpdateDeviceLastSeen(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error
|
||||||
RemoveDevices(ctx context.Context, localpart string, devices []string) error
|
RemoveDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, devices []string) error
|
||||||
// RemoveAllDevices deleted all devices for this user. Returns the devices deleted.
|
// RemoveAllDevices deleted all devices for this user. Returns the devices deleted.
|
||||||
RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error)
|
RemoveAllDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, exceptDeviceID string) (devices []api.Device, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type KeyBackup interface {
|
type KeyBackup interface {
|
||||||
|
@ -107,26 +107,26 @@ type OpenID interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
type Pusher interface {
|
type Pusher interface {
|
||||||
UpsertPusher(ctx context.Context, p api.Pusher, localpart string) error
|
UpsertPusher(ctx context.Context, p api.Pusher, localpart string, serverName gomatrixserverlib.ServerName) error
|
||||||
GetPushers(ctx context.Context, localpart string) ([]api.Pusher, error)
|
GetPushers(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) ([]api.Pusher, error)
|
||||||
RemovePusher(ctx context.Context, appid, pushkey, localpart string) error
|
RemovePusher(ctx context.Context, appid, pushkey, localpart string, serverName gomatrixserverlib.ServerName) error
|
||||||
RemovePushers(ctx context.Context, appid, pushkey string) error
|
RemovePushers(ctx context.Context, appid, pushkey string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type ThreePID interface {
|
type ThreePID interface {
|
||||||
SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error)
|
SaveThreePIDAssociation(ctx context.Context, threepid, localpart string, serverName gomatrixserverlib.ServerName, medium string) (err error)
|
||||||
RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error)
|
RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error)
|
||||||
GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, err error)
|
GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, serverName gomatrixserverlib.ServerName, err error)
|
||||||
GetThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error)
|
GetThreePIDsForLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (threepids []authtypes.ThreePID, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Notification interface {
|
type Notification interface {
|
||||||
InsertNotification(ctx context.Context, localpart, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error
|
InsertNotification(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error
|
||||||
DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos uint64) (affected bool, err error)
|
DeleteNotificationsUpTo(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, err error)
|
||||||
SetNotificationsRead(ctx context.Context, localpart, roomID string, pos uint64, read bool) (affected bool, err error)
|
SetNotificationsRead(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, read bool) (affected bool, err error)
|
||||||
GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error)
|
GetNotifications(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error)
|
||||||
GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error)
|
GetNotificationCount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, filter tables.NotificationFilter) (int64, error)
|
||||||
GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error)
|
GetRoomNotificationCounts(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, _ error)
|
||||||
DeleteOldNotifications(ctx context.Context) error
|
DeleteOldNotifications(ctx context.Context) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -22,6 +22,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
const accountDataSchema = `
|
const accountDataSchema = `
|
||||||
|
@ -29,27 +30,28 @@ const accountDataSchema = `
|
||||||
CREATE TABLE IF NOT EXISTS userapi_account_datas (
|
CREATE TABLE IF NOT EXISTS userapi_account_datas (
|
||||||
-- The Matrix user ID localpart for this account
|
-- The Matrix user ID localpart for this account
|
||||||
localpart TEXT NOT NULL,
|
localpart TEXT NOT NULL,
|
||||||
|
server_name TEXT NOT NULL,
|
||||||
-- The room ID for this data (empty string if not specific to a room)
|
-- The room ID for this data (empty string if not specific to a room)
|
||||||
room_id TEXT,
|
room_id TEXT,
|
||||||
-- The account data type
|
-- The account data type
|
||||||
type TEXT NOT NULL,
|
type TEXT NOT NULL,
|
||||||
-- The account data content
|
-- The account data content
|
||||||
content TEXT NOT NULL,
|
content TEXT NOT NULL
|
||||||
|
|
||||||
PRIMARY KEY(localpart, room_id, type)
|
|
||||||
);
|
);
|
||||||
|
|
||||||
|
CREATE UNIQUE INDEX IF NOT EXISTS userapi_account_datas_idx ON userapi_account_datas(localpart, server_name, room_id, type);
|
||||||
`
|
`
|
||||||
|
|
||||||
const insertAccountDataSQL = `
|
const insertAccountDataSQL = `
|
||||||
INSERT INTO userapi_account_datas(localpart, room_id, type, content) VALUES($1, $2, $3, $4)
|
INSERT INTO userapi_account_datas(localpart, server_name, room_id, type, content) VALUES($1, $2, $3, $4, $5)
|
||||||
ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = EXCLUDED.content
|
ON CONFLICT (localpart, server_name, room_id, type) DO UPDATE SET content = EXCLUDED.content
|
||||||
`
|
`
|
||||||
|
|
||||||
const selectAccountDataSQL = "" +
|
const selectAccountDataSQL = "" +
|
||||||
"SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1"
|
"SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1 AND server_name = $2"
|
||||||
|
|
||||||
const selectAccountDataByTypeSQL = "" +
|
const selectAccountDataByTypeSQL = "" +
|
||||||
"SELECT content FROM userapi_account_datas WHERE localpart = $1 AND room_id = $2 AND type = $3"
|
"SELECT content FROM userapi_account_datas WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND type = $4"
|
||||||
|
|
||||||
type accountDataStatements struct {
|
type accountDataStatements struct {
|
||||||
insertAccountDataStmt *sql.Stmt
|
insertAccountDataStmt *sql.Stmt
|
||||||
|
@ -71,21 +73,24 @@ func NewPostgresAccountDataTable(db *sql.DB) (tables.AccountDataTable, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountDataStatements) InsertAccountData(
|
func (s *accountDataStatements) InsertAccountData(
|
||||||
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
roomID, dataType string, content json.RawMessage,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
stmt := sqlutil.TxStmt(txn, s.insertAccountDataStmt)
|
stmt := sqlutil.TxStmt(txn, s.insertAccountDataStmt)
|
||||||
_, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content)
|
_, err = stmt.ExecContext(ctx, localpart, serverName, roomID, dataType, content)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountDataStatements) SelectAccountData(
|
func (s *accountDataStatements) SelectAccountData(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) (
|
) (
|
||||||
/* global */ map[string]json.RawMessage,
|
/* global */ map[string]json.RawMessage,
|
||||||
/* rooms */ map[string]map[string]json.RawMessage,
|
/* rooms */ map[string]map[string]json.RawMessage,
|
||||||
error,
|
error,
|
||||||
) {
|
) {
|
||||||
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart)
|
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart, serverName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
@ -117,11 +122,13 @@ func (s *accountDataStatements) SelectAccountData(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountDataStatements) SelectAccountDataByType(
|
func (s *accountDataStatements) SelectAccountDataByType(
|
||||||
ctx context.Context, localpart, roomID, dataType string,
|
ctx context.Context,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
roomID, dataType string,
|
||||||
) (data json.RawMessage, err error) {
|
) (data json.RawMessage, err error) {
|
||||||
var bytes []byte
|
var bytes []byte
|
||||||
stmt := s.selectAccountDataByTypeStmt
|
stmt := s.selectAccountDataByTypeStmt
|
||||||
if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&bytes); err != nil {
|
if err = stmt.QueryRowContext(ctx, localpart, serverName, roomID, dataType).Scan(&bytes); err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,6 +17,7 @@ package postgres
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
@ -34,7 +35,8 @@ const accountsSchema = `
|
||||||
-- Stores data about accounts.
|
-- Stores data about accounts.
|
||||||
CREATE TABLE IF NOT EXISTS userapi_accounts (
|
CREATE TABLE IF NOT EXISTS userapi_accounts (
|
||||||
-- The Matrix user ID localpart for this account
|
-- The Matrix user ID localpart for this account
|
||||||
localpart TEXT NOT NULL PRIMARY KEY,
|
localpart TEXT NOT NULL,
|
||||||
|
server_name TEXT NOT NULL,
|
||||||
-- When this account was first created, as a unix timestamp (ms resolution).
|
-- When this account was first created, as a unix timestamp (ms resolution).
|
||||||
created_ts BIGINT NOT NULL,
|
created_ts BIGINT NOT NULL,
|
||||||
-- The password hash for this account. Can be NULL if this is a passwordless account.
|
-- The password hash for this account. Can be NULL if this is a passwordless account.
|
||||||
|
@ -48,25 +50,27 @@ CREATE TABLE IF NOT EXISTS userapi_accounts (
|
||||||
-- TODO:
|
-- TODO:
|
||||||
-- upgraded_ts, devices, any email reset stuff?
|
-- upgraded_ts, devices, any email reset stuff?
|
||||||
);
|
);
|
||||||
|
|
||||||
|
CREATE UNIQUE INDEX IF NOT EXISTS userapi_accounts_idx ON userapi_accounts(localpart, server_name);
|
||||||
`
|
`
|
||||||
|
|
||||||
const insertAccountSQL = "" +
|
const insertAccountSQL = "" +
|
||||||
"INSERT INTO userapi_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)"
|
"INSERT INTO userapi_accounts(localpart, server_name, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5, $6)"
|
||||||
|
|
||||||
const updatePasswordSQL = "" +
|
const updatePasswordSQL = "" +
|
||||||
"UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2"
|
"UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2 AND server_name = $3"
|
||||||
|
|
||||||
const deactivateAccountSQL = "" +
|
const deactivateAccountSQL = "" +
|
||||||
"UPDATE userapi_accounts SET is_deactivated = TRUE WHERE localpart = $1"
|
"UPDATE userapi_accounts SET is_deactivated = TRUE WHERE localpart = $1 AND server_name = $2"
|
||||||
|
|
||||||
const selectAccountByLocalpartSQL = "" +
|
const selectAccountByLocalpartSQL = "" +
|
||||||
"SELECT localpart, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1"
|
"SELECT localpart, server_name, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1 AND server_name = $2"
|
||||||
|
|
||||||
const selectPasswordHashSQL = "" +
|
const selectPasswordHashSQL = "" +
|
||||||
"SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND is_deactivated = FALSE"
|
"SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND server_name = $2 AND is_deactivated = FALSE"
|
||||||
|
|
||||||
const selectNewNumericLocalpartSQL = "" +
|
const selectNewNumericLocalpartSQL = "" +
|
||||||
"SELECT COALESCE(MAX(localpart::bigint), 0) FROM userapi_accounts WHERE localpart ~ '^[0-9]{1,}$'"
|
"SELECT COALESCE(MAX(localpart::bigint), 0) FROM userapi_accounts WHERE localpart ~ '^[0-9]{1,}$' AND server_name = $1"
|
||||||
|
|
||||||
type accountsStatements struct {
|
type accountsStatements struct {
|
||||||
insertAccountStmt *sql.Stmt
|
insertAccountStmt *sql.Stmt
|
||||||
|
@ -117,59 +121,62 @@ func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerNam
|
||||||
// this account will be passwordless. Returns an error if this account already exists. Returns the account
|
// this account will be passwordless. Returns an error if this account already exists. Returns the account
|
||||||
// on success.
|
// on success.
|
||||||
func (s *accountsStatements) InsertAccount(
|
func (s *accountsStatements) InsertAccount(
|
||||||
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType,
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
hash, appserviceID string, accountType api.AccountType,
|
||||||
) (*api.Account, error) {
|
) (*api.Account, error) {
|
||||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||||
stmt := sqlutil.TxStmt(txn, s.insertAccountStmt)
|
stmt := sqlutil.TxStmt(txn, s.insertAccountStmt)
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
if accountType != api.AccountTypeAppService {
|
if accountType != api.AccountTypeAppService {
|
||||||
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType)
|
_, err = stmt.ExecContext(ctx, localpart, serverName, createdTimeMS, hash, nil, accountType)
|
||||||
} else {
|
} else {
|
||||||
_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType)
|
_, err = stmt.ExecContext(ctx, localpart, serverName, createdTimeMS, hash, appserviceID, accountType)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("insertAccountStmt: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &api.Account{
|
return &api.Account{
|
||||||
Localpart: localpart,
|
Localpart: localpart,
|
||||||
UserID: userutil.MakeUserID(localpart, s.serverName),
|
UserID: userutil.MakeUserID(localpart, serverName),
|
||||||
ServerName: s.serverName,
|
ServerName: serverName,
|
||||||
AppServiceID: appserviceID,
|
AppServiceID: appserviceID,
|
||||||
AccountType: accountType,
|
AccountType: accountType,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountsStatements) UpdatePassword(
|
func (s *accountsStatements) UpdatePassword(
|
||||||
ctx context.Context, localpart, passwordHash string,
|
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
passwordHash string,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart)
|
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart, serverName)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountsStatements) DeactivateAccount(
|
func (s *accountsStatements) DeactivateAccount(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
_, err = s.deactivateAccountStmt.ExecContext(ctx, localpart)
|
_, err = s.deactivateAccountStmt.ExecContext(ctx, localpart, serverName)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountsStatements) SelectPasswordHash(
|
func (s *accountsStatements) SelectPasswordHash(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) (hash string, err error) {
|
) (hash string, err error) {
|
||||||
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash)
|
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart, serverName).Scan(&hash)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountsStatements) SelectAccountByLocalpart(
|
func (s *accountsStatements) SelectAccountByLocalpart(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) (*api.Account, error) {
|
) (*api.Account, error) {
|
||||||
var appserviceIDPtr sql.NullString
|
var appserviceIDPtr sql.NullString
|
||||||
var acc api.Account
|
var acc api.Account
|
||||||
|
|
||||||
stmt := s.selectAccountByLocalpartStmt
|
stmt := s.selectAccountByLocalpartStmt
|
||||||
err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &acc.AccountType)
|
err := stmt.QueryRowContext(ctx, localpart, serverName).Scan(&acc.Localpart, &acc.ServerName, &appserviceIDPtr, &acc.AccountType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err != sql.ErrNoRows {
|
if err != sql.ErrNoRows {
|
||||||
log.WithError(err).Error("Unable to retrieve user from the db")
|
log.WithError(err).Error("Unable to retrieve user from the db")
|
||||||
|
@ -180,19 +187,17 @@ func (s *accountsStatements) SelectAccountByLocalpart(
|
||||||
acc.AppServiceID = appserviceIDPtr.String
|
acc.AppServiceID = appserviceIDPtr.String
|
||||||
}
|
}
|
||||||
|
|
||||||
acc.UserID = userutil.MakeUserID(localpart, s.serverName)
|
acc.UserID = userutil.MakeUserID(acc.Localpart, acc.ServerName)
|
||||||
acc.ServerName = s.serverName
|
|
||||||
|
|
||||||
return &acc, nil
|
return &acc, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountsStatements) SelectNewNumericLocalpart(
|
func (s *accountsStatements) SelectNewNumericLocalpart(
|
||||||
ctx context.Context, txn *sql.Tx,
|
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
|
||||||
) (id int64, err error) {
|
) (id int64, err error) {
|
||||||
stmt := s.selectNewNumericLocalpartStmt
|
stmt := s.selectNewNumericLocalpartStmt
|
||||||
if txn != nil {
|
if txn != nil {
|
||||||
stmt = sqlutil.TxStmt(txn, stmt)
|
stmt = sqlutil.TxStmt(txn, stmt)
|
||||||
}
|
}
|
||||||
err = stmt.QueryRowContext(ctx).Scan(&id)
|
err = stmt.QueryRowContext(ctx, serverName).Scan(&id)
|
||||||
return id + 1, err
|
return id + 1, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,81 @@
|
||||||
|
package deltas
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/lib/pq"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
)
|
||||||
|
|
||||||
|
var serverNamesTables = []string{
|
||||||
|
"userapi_accounts",
|
||||||
|
"userapi_account_datas",
|
||||||
|
"userapi_devices",
|
||||||
|
"userapi_notifications",
|
||||||
|
"userapi_openid_tokens",
|
||||||
|
"userapi_profiles",
|
||||||
|
"userapi_pushers",
|
||||||
|
"userapi_threepids",
|
||||||
|
}
|
||||||
|
|
||||||
|
// These tables have a PRIMARY KEY constraint which we need to drop so
|
||||||
|
// that we can recreate a new unique index that contains the server name.
|
||||||
|
// If the new key doesn't exist (i.e. the database was created before the
|
||||||
|
// table rename migration) we'll try to drop the old one instead.
|
||||||
|
var serverNamesDropPK = map[string]string{
|
||||||
|
"userapi_accounts": "account_accounts",
|
||||||
|
"userapi_account_datas": "account_data",
|
||||||
|
"userapi_profiles": "account_profiles",
|
||||||
|
}
|
||||||
|
|
||||||
|
// These indices are out of date so let's drop them. They will get recreated
|
||||||
|
// automatically.
|
||||||
|
var serverNamesDropIndex = []string{
|
||||||
|
"userapi_pusher_localpart_idx",
|
||||||
|
"userapi_pusher_app_id_pushkey_localpart_idx",
|
||||||
|
}
|
||||||
|
|
||||||
|
// I know what you're thinking: you're wondering "why doesn't this use $1
|
||||||
|
// and pass variadic parameters to ExecContext?" — the answer is because
|
||||||
|
// PostgreSQL doesn't expect the table name to be specified as a substituted
|
||||||
|
// argument in that way so it results in a syntax error in the query.
|
||||||
|
|
||||||
|
func UpServerNames(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error {
|
||||||
|
for _, table := range serverNamesTables {
|
||||||
|
q := fmt.Sprintf(
|
||||||
|
"ALTER TABLE IF EXISTS %s ADD COLUMN IF NOT EXISTS server_name TEXT NOT NULL DEFAULT '';",
|
||||||
|
pq.QuoteIdentifier(table),
|
||||||
|
)
|
||||||
|
if _, err := tx.ExecContext(ctx, q); err != nil {
|
||||||
|
return fmt.Errorf("add server name to %q error: %w", table, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for newTable, oldTable := range serverNamesDropPK {
|
||||||
|
q := fmt.Sprintf(
|
||||||
|
"ALTER TABLE IF EXISTS %s DROP CONSTRAINT IF EXISTS %s;",
|
||||||
|
pq.QuoteIdentifier(newTable), pq.QuoteIdentifier(newTable+"_pkey"),
|
||||||
|
)
|
||||||
|
if _, err := tx.ExecContext(ctx, q); err != nil {
|
||||||
|
return fmt.Errorf("drop new PK from %q error: %w", newTable, err)
|
||||||
|
}
|
||||||
|
q = fmt.Sprintf(
|
||||||
|
"ALTER TABLE IF EXISTS %s DROP CONSTRAINT IF EXISTS %s;",
|
||||||
|
pq.QuoteIdentifier(newTable), pq.QuoteIdentifier(oldTable+"_pkey"),
|
||||||
|
)
|
||||||
|
if _, err := tx.ExecContext(ctx, q); err != nil {
|
||||||
|
return fmt.Errorf("drop old PK from %q error: %w", newTable, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, index := range serverNamesDropIndex {
|
||||||
|
q := fmt.Sprintf(
|
||||||
|
"DROP INDEX IF EXISTS %s;",
|
||||||
|
pq.QuoteIdentifier(index),
|
||||||
|
)
|
||||||
|
if _, err := tx.ExecContext(ctx, q); err != nil {
|
||||||
|
return fmt.Errorf("drop index %q error: %w", index, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -0,0 +1,28 @@
|
||||||
|
package deltas
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/lib/pq"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
)
|
||||||
|
|
||||||
|
// I know what you're thinking: you're wondering "why doesn't this use $1
|
||||||
|
// and pass variadic parameters to ExecContext?" — the answer is because
|
||||||
|
// PostgreSQL doesn't expect the table name to be specified as a substituted
|
||||||
|
// argument in that way so it results in a syntax error in the query.
|
||||||
|
|
||||||
|
func UpServerNamesPopulate(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error {
|
||||||
|
for _, table := range serverNamesTables {
|
||||||
|
q := fmt.Sprintf(
|
||||||
|
"UPDATE %s SET server_name = %s WHERE server_name = '';",
|
||||||
|
pq.QuoteIdentifier(table), pq.QuoteLiteral(string(serverName)),
|
||||||
|
)
|
||||||
|
if _, err := tx.ExecContext(ctx, q); err != nil {
|
||||||
|
return fmt.Errorf("write server names to %q error: %w", table, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -17,6 +17,7 @@ package postgres
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
|
@ -50,6 +51,7 @@ CREATE TABLE IF NOT EXISTS userapi_devices (
|
||||||
-- as it is smaller, makes it clearer that we only manage devices for our own users, and may make
|
-- as it is smaller, makes it clearer that we only manage devices for our own users, and may make
|
||||||
-- migration to different domain names easier.
|
-- migration to different domain names easier.
|
||||||
localpart TEXT NOT NULL,
|
localpart TEXT NOT NULL,
|
||||||
|
server_name TEXT NOT NULL,
|
||||||
-- When this devices was first recognised on the network, as a unix timestamp (ms resolution).
|
-- When this devices was first recognised on the network, as a unix timestamp (ms resolution).
|
||||||
created_ts BIGINT NOT NULL,
|
created_ts BIGINT NOT NULL,
|
||||||
-- The display name, human friendlier than device_id and updatable
|
-- The display name, human friendlier than device_id and updatable
|
||||||
|
@ -65,39 +67,39 @@ CREATE TABLE IF NOT EXISTS userapi_devices (
|
||||||
);
|
);
|
||||||
|
|
||||||
-- Device IDs must be unique for a given user.
|
-- Device IDs must be unique for a given user.
|
||||||
CREATE UNIQUE INDEX IF NOT EXISTS userapi_device_localpart_id_idx ON userapi_devices(localpart, device_id);
|
CREATE UNIQUE INDEX IF NOT EXISTS userapi_device_localpart_id_idx ON userapi_devices(localpart, server_name, device_id);
|
||||||
`
|
`
|
||||||
|
|
||||||
const insertDeviceSQL = "" +
|
const insertDeviceSQL = "" +
|
||||||
"INSERT INTO userapi_devices(device_id, localpart, access_token, created_ts, display_name, last_seen_ts, ip, user_agent) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" +
|
"INSERT INTO userapi_devices(device_id, localpart, server_name, access_token, created_ts, display_name, last_seen_ts, ip, user_agent) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" +
|
||||||
" RETURNING session_id"
|
" RETURNING session_id"
|
||||||
|
|
||||||
const selectDeviceByTokenSQL = "" +
|
const selectDeviceByTokenSQL = "" +
|
||||||
"SELECT session_id, device_id, localpart FROM userapi_devices WHERE access_token = $1"
|
"SELECT session_id, device_id, localpart, server_name FROM userapi_devices WHERE access_token = $1"
|
||||||
|
|
||||||
const selectDeviceByIDSQL = "" +
|
const selectDeviceByIDSQL = "" +
|
||||||
"SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 and device_id = $2"
|
"SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = $3"
|
||||||
|
|
||||||
const selectDevicesByLocalpartSQL = "" +
|
const selectDevicesByLocalpartSQL = "" +
|
||||||
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC"
|
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3 ORDER BY last_seen_ts DESC"
|
||||||
|
|
||||||
const updateDeviceNameSQL = "" +
|
const updateDeviceNameSQL = "" +
|
||||||
"UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
|
"UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND server_name = $3 AND device_id = $4"
|
||||||
|
|
||||||
const deleteDeviceSQL = "" +
|
const deleteDeviceSQL = "" +
|
||||||
"DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2"
|
"DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2 AND server_name = $3"
|
||||||
|
|
||||||
const deleteDevicesByLocalpartSQL = "" +
|
const deleteDevicesByLocalpartSQL = "" +
|
||||||
"DELETE FROM userapi_devices WHERE localpart = $1 AND device_id != $2"
|
"DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3"
|
||||||
|
|
||||||
const deleteDevicesSQL = "" +
|
const deleteDevicesSQL = "" +
|
||||||
"DELETE FROM userapi_devices WHERE localpart = $1 AND device_id = ANY($2)"
|
"DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = ANY($3)"
|
||||||
|
|
||||||
const selectDevicesByIDSQL = "" +
|
const selectDevicesByIDSQL = "" +
|
||||||
"SELECT device_id, localpart, display_name, last_seen_ts FROM userapi_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC"
|
"SELECT device_id, localpart, server_name, display_name, last_seen_ts FROM userapi_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC"
|
||||||
|
|
||||||
const updateDeviceLastSeen = "" +
|
const updateDeviceLastSeen = "" +
|
||||||
"UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5"
|
"UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND server_name = $5 AND device_id = $6"
|
||||||
|
|
||||||
type devicesStatements struct {
|
type devicesStatements struct {
|
||||||
insertDeviceStmt *sql.Stmt
|
insertDeviceStmt *sql.Stmt
|
||||||
|
@ -148,18 +150,19 @@ func NewPostgresDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName
|
||||||
// Returns an error if the user already has a device with the given device ID.
|
// Returns an error if the user already has a device with the given device ID.
|
||||||
// Returns the device on success.
|
// Returns the device on success.
|
||||||
func (s *devicesStatements) InsertDevice(
|
func (s *devicesStatements) InsertDevice(
|
||||||
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
|
ctx context.Context, txn *sql.Tx, id string,
|
||||||
displayName *string, ipAddr, userAgent string,
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
accessToken string, displayName *string, ipAddr, userAgent string,
|
||||||
) (*api.Device, error) {
|
) (*api.Device, error) {
|
||||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||||
var sessionID int64
|
var sessionID int64
|
||||||
stmt := sqlutil.TxStmt(txn, s.insertDeviceStmt)
|
stmt := sqlutil.TxStmt(txn, s.insertDeviceStmt)
|
||||||
if err := stmt.QueryRowContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, createdTimeMS, ipAddr, userAgent).Scan(&sessionID); err != nil {
|
if err := stmt.QueryRowContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, createdTimeMS, ipAddr, userAgent).Scan(&sessionID); err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("insertDeviceStmt: %w", err)
|
||||||
}
|
}
|
||||||
return &api.Device{
|
return &api.Device{
|
||||||
ID: id,
|
ID: id,
|
||||||
UserID: userutil.MakeUserID(localpart, s.serverName),
|
UserID: userutil.MakeUserID(localpart, serverName),
|
||||||
AccessToken: accessToken,
|
AccessToken: accessToken,
|
||||||
SessionID: sessionID,
|
SessionID: sessionID,
|
||||||
LastSeenTS: createdTimeMS,
|
LastSeenTS: createdTimeMS,
|
||||||
|
@ -170,38 +173,45 @@ func (s *devicesStatements) InsertDevice(
|
||||||
|
|
||||||
// deleteDevice removes a single device by id and user localpart.
|
// deleteDevice removes a single device by id and user localpart.
|
||||||
func (s *devicesStatements) DeleteDevice(
|
func (s *devicesStatements) DeleteDevice(
|
||||||
ctx context.Context, txn *sql.Tx, id, localpart string,
|
ctx context.Context, txn *sql.Tx, id string,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) error {
|
) error {
|
||||||
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
|
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
|
||||||
_, err := stmt.ExecContext(ctx, id, localpart)
|
_, err := stmt.ExecContext(ctx, id, localpart, serverName)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// deleteDevices removes a single or multiple devices by ids and user localpart.
|
// deleteDevices removes a single or multiple devices by ids and user localpart.
|
||||||
// Returns an error if the execution failed.
|
// Returns an error if the execution failed.
|
||||||
func (s *devicesStatements) DeleteDevices(
|
func (s *devicesStatements) DeleteDevices(
|
||||||
ctx context.Context, txn *sql.Tx, localpart string, devices []string,
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
devices []string,
|
||||||
) error {
|
) error {
|
||||||
stmt := sqlutil.TxStmt(txn, s.deleteDevicesStmt)
|
stmt := sqlutil.TxStmt(txn, s.deleteDevicesStmt)
|
||||||
_, err := stmt.ExecContext(ctx, localpart, pq.Array(devices))
|
_, err := stmt.ExecContext(ctx, localpart, serverName, pq.Array(devices))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// deleteDevicesByLocalpart removes all devices for the
|
// deleteDevicesByLocalpart removes all devices for the
|
||||||
// given user localpart.
|
// given user localpart.
|
||||||
func (s *devicesStatements) DeleteDevicesByLocalpart(
|
func (s *devicesStatements) DeleteDevicesByLocalpart(
|
||||||
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
exceptDeviceID string,
|
||||||
) error {
|
) error {
|
||||||
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
|
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
|
||||||
_, err := stmt.ExecContext(ctx, localpart, exceptDeviceID)
|
_, err := stmt.ExecContext(ctx, localpart, serverName, exceptDeviceID)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) UpdateDeviceName(
|
func (s *devicesStatements) UpdateDeviceName(
|
||||||
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
deviceID string, displayName *string,
|
||||||
) error {
|
) error {
|
||||||
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
|
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
|
||||||
_, err := stmt.ExecContext(ctx, displayName, localpart, deviceID)
|
_, err := stmt.ExecContext(ctx, displayName, localpart, serverName, deviceID)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -210,10 +220,11 @@ func (s *devicesStatements) SelectDeviceByToken(
|
||||||
) (*api.Device, error) {
|
) (*api.Device, error) {
|
||||||
var dev api.Device
|
var dev api.Device
|
||||||
var localpart string
|
var localpart string
|
||||||
|
var serverName gomatrixserverlib.ServerName
|
||||||
stmt := s.selectDeviceByTokenStmt
|
stmt := s.selectDeviceByTokenStmt
|
||||||
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart)
|
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart, &serverName)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
dev.UserID = userutil.MakeUserID(localpart, serverName)
|
||||||
dev.AccessToken = accessToken
|
dev.AccessToken = accessToken
|
||||||
}
|
}
|
||||||
return &dev, err
|
return &dev, err
|
||||||
|
@ -222,16 +233,18 @@ func (s *devicesStatements) SelectDeviceByToken(
|
||||||
// selectDeviceByID retrieves a device from the database with the given user
|
// selectDeviceByID retrieves a device from the database with the given user
|
||||||
// localpart and deviceID
|
// localpart and deviceID
|
||||||
func (s *devicesStatements) SelectDeviceByID(
|
func (s *devicesStatements) SelectDeviceByID(
|
||||||
ctx context.Context, localpart, deviceID string,
|
ctx context.Context,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
deviceID string,
|
||||||
) (*api.Device, error) {
|
) (*api.Device, error) {
|
||||||
var dev api.Device
|
var dev api.Device
|
||||||
var displayName, ip sql.NullString
|
var displayName, ip sql.NullString
|
||||||
var lastseenTS sql.NullInt64
|
var lastseenTS sql.NullInt64
|
||||||
stmt := s.selectDeviceByIDStmt
|
stmt := s.selectDeviceByIDStmt
|
||||||
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName, &lastseenTS, &ip)
|
err := stmt.QueryRowContext(ctx, localpart, serverName, deviceID).Scan(&displayName, &lastseenTS, &ip)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
dev.ID = deviceID
|
dev.ID = deviceID
|
||||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
dev.UserID = userutil.MakeUserID(localpart, serverName)
|
||||||
if displayName.Valid {
|
if displayName.Valid {
|
||||||
dev.DisplayName = displayName.String
|
dev.DisplayName = displayName.String
|
||||||
}
|
}
|
||||||
|
@ -254,10 +267,11 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s
|
||||||
var devices []api.Device
|
var devices []api.Device
|
||||||
var dev api.Device
|
var dev api.Device
|
||||||
var localpart string
|
var localpart string
|
||||||
|
var serverName gomatrixserverlib.ServerName
|
||||||
var lastseents sql.NullInt64
|
var lastseents sql.NullInt64
|
||||||
var displayName sql.NullString
|
var displayName sql.NullString
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
if err := rows.Scan(&dev.ID, &localpart, &displayName, &lastseents); err != nil {
|
if err := rows.Scan(&dev.ID, &localpart, &serverName, &displayName, &lastseents); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if displayName.Valid {
|
if displayName.Valid {
|
||||||
|
@ -266,17 +280,19 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s
|
||||||
if lastseents.Valid {
|
if lastseents.Valid {
|
||||||
dev.LastSeenTS = lastseents.Int64
|
dev.LastSeenTS = lastseents.Int64
|
||||||
}
|
}
|
||||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
dev.UserID = userutil.MakeUserID(localpart, serverName)
|
||||||
devices = append(devices, dev)
|
devices = append(devices, dev)
|
||||||
}
|
}
|
||||||
return devices, rows.Err()
|
return devices, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) SelectDevicesByLocalpart(
|
func (s *devicesStatements) SelectDevicesByLocalpart(
|
||||||
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
exceptDeviceID string,
|
||||||
) ([]api.Device, error) {
|
) ([]api.Device, error) {
|
||||||
devices := []api.Device{}
|
devices := []api.Device{}
|
||||||
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID)
|
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, serverName, exceptDeviceID)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return devices, err
|
return devices, err
|
||||||
|
@ -307,16 +323,16 @@ func (s *devicesStatements) SelectDevicesByLocalpart(
|
||||||
dev.UserAgent = useragent.String
|
dev.UserAgent = useragent.String
|
||||||
}
|
}
|
||||||
|
|
||||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
dev.UserID = userutil.MakeUserID(localpart, serverName)
|
||||||
devices = append(devices, dev)
|
devices = append(devices, dev)
|
||||||
}
|
}
|
||||||
|
|
||||||
return devices, rows.Err()
|
return devices, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr, userAgent string) error {
|
func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error {
|
||||||
lastSeenTs := time.Now().UnixNano() / 1000000
|
lastSeenTs := time.Now().UnixNano() / 1000000
|
||||||
stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt)
|
stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt)
|
||||||
_, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, deviceID)
|
_, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, serverName, deviceID)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -43,6 +43,7 @@ const notificationSchema = `
|
||||||
CREATE TABLE IF NOT EXISTS userapi_notifications (
|
CREATE TABLE IF NOT EXISTS userapi_notifications (
|
||||||
id BIGSERIAL PRIMARY KEY,
|
id BIGSERIAL PRIMARY KEY,
|
||||||
localpart TEXT NOT NULL,
|
localpart TEXT NOT NULL,
|
||||||
|
server_name TEXT NOT NULL,
|
||||||
room_id TEXT NOT NULL,
|
room_id TEXT NOT NULL,
|
||||||
event_id TEXT NOT NULL,
|
event_id TEXT NOT NULL,
|
||||||
stream_pos BIGINT NOT NULL,
|
stream_pos BIGINT NOT NULL,
|
||||||
|
@ -52,33 +53,33 @@ CREATE TABLE IF NOT EXISTS userapi_notifications (
|
||||||
read BOOLEAN NOT NULL DEFAULT FALSE
|
read BOOLEAN NOT NULL DEFAULT FALSE
|
||||||
);
|
);
|
||||||
|
|
||||||
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, room_id, event_id);
|
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, server_name, room_id, event_id);
|
||||||
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, room_id, id);
|
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, server_name, room_id, id);
|
||||||
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, id);
|
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, server_name, id);
|
||||||
`
|
`
|
||||||
|
|
||||||
const insertNotificationSQL = "" +
|
const insertNotificationSQL = "" +
|
||||||
"INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7)"
|
"INSERT INTO userapi_notifications (localpart, server_name, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
|
||||||
|
|
||||||
const deleteNotificationsUpToSQL = "" +
|
const deleteNotificationsUpToSQL = "" +
|
||||||
"DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND stream_pos <= $3"
|
"DELETE FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND stream_pos <= $4"
|
||||||
|
|
||||||
const updateNotificationReadSQL = "" +
|
const updateNotificationReadSQL = "" +
|
||||||
"UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND stream_pos <= $4 AND read <> $1"
|
"UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND server_name = $3 AND room_id = $4 AND stream_pos <= $5 AND read <> $1"
|
||||||
|
|
||||||
const selectNotificationSQL = "" +
|
const selectNotificationSQL = "" +
|
||||||
"SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" +
|
"SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND id > $3 AND (" +
|
||||||
"(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" +
|
"(($4 & 1) <> 0 AND highlight) OR (($4 & 2) <> 0 AND NOT highlight)" +
|
||||||
") AND NOT read ORDER BY localpart, id LIMIT $4"
|
") AND NOT read ORDER BY localpart, id LIMIT $5"
|
||||||
|
|
||||||
const selectNotificationCountSQL = "" +
|
const selectNotificationCountSQL = "" +
|
||||||
"SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND (" +
|
"SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND (" +
|
||||||
"(($2 & 1) <> 0 AND highlight) OR (($2 & 2) <> 0 AND NOT highlight)" +
|
"(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" +
|
||||||
") AND NOT read"
|
") AND NOT read"
|
||||||
|
|
||||||
const selectRoomNotificationCountsSQL = "" +
|
const selectRoomNotificationCountsSQL = "" +
|
||||||
"SELECT COUNT(*), COUNT(*) FILTER (WHERE highlight) FROM userapi_notifications " +
|
"SELECT COUNT(*), COUNT(*) FILTER (WHERE highlight) FROM userapi_notifications " +
|
||||||
"WHERE localpart = $1 AND room_id = $2 AND NOT read"
|
"WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND NOT read"
|
||||||
|
|
||||||
const cleanNotificationsSQL = "" +
|
const cleanNotificationsSQL = "" +
|
||||||
"DELETE FROM userapi_notifications WHERE" +
|
"DELETE FROM userapi_notifications WHERE" +
|
||||||
|
@ -111,7 +112,7 @@ func (s *notificationsStatements) Clean(ctx context.Context, txn *sql.Tx) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// Insert inserts a notification into the database.
|
// Insert inserts a notification into the database.
|
||||||
func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error {
|
func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, highlight bool, n *api.Notification) error {
|
||||||
roomID, tsMS := n.RoomID, n.TS
|
roomID, tsMS := n.RoomID, n.TS
|
||||||
nn := *n
|
nn := *n
|
||||||
// Clears out fields that have their own columns to (1) shrink the
|
// Clears out fields that have their own columns to (1) shrink the
|
||||||
|
@ -122,13 +123,13 @@ func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, local
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, pos, tsMS, highlight, string(bs))
|
_, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, serverName, roomID, eventID, pos, tsMS, highlight, string(bs))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteUpTo deletes all previous notifications, up to and including the event.
|
// DeleteUpTo deletes all previous notifications, up to and including the event.
|
||||||
func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error) {
|
func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, _ error) {
|
||||||
res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos)
|
res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, serverName, roomID, pos)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
@ -141,8 +142,8 @@ func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, l
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateRead updates the "read" value for an event.
|
// UpdateRead updates the "read" value for an event.
|
||||||
func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error) {
|
func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, v bool) (affected bool, _ error) {
|
||||||
res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos)
|
res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, serverName, roomID, pos)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
@ -154,8 +155,8 @@ func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, l
|
||||||
return nrows > 0, nil
|
return nrows > 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
|
func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
|
||||||
rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, fromID, uint32(filter), limit)
|
rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, serverName, fromID, uint32(filter), limit)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
|
@ -197,12 +198,12 @@ func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, local
|
||||||
return notifs, maxID, rows.Err()
|
return notifs, maxID, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (count int64, err error) {
|
func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, filter tables.NotificationFilter) (count int64, err error) {
|
||||||
err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, uint32(filter)).Scan(&count)
|
err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, serverName, uint32(filter)).Scan(&count)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, err error) {
|
func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, err error) {
|
||||||
err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, roomID).Scan(&total, &highlight)
|
err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, serverName, roomID).Scan(&total, &highlight)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,6 +3,7 @@ package postgres
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
@ -18,16 +19,17 @@ CREATE TABLE IF NOT EXISTS userapi_openid_tokens (
|
||||||
token TEXT NOT NULL PRIMARY KEY,
|
token TEXT NOT NULL PRIMARY KEY,
|
||||||
-- The Matrix user ID for this account
|
-- The Matrix user ID for this account
|
||||||
localpart TEXT NOT NULL,
|
localpart TEXT NOT NULL,
|
||||||
|
server_name TEXT NOT NULL,
|
||||||
-- When the token expires, as a unix timestamp (ms resolution).
|
-- When the token expires, as a unix timestamp (ms resolution).
|
||||||
token_expires_at_ms BIGINT NOT NULL
|
token_expires_at_ms BIGINT NOT NULL
|
||||||
);
|
);
|
||||||
`
|
`
|
||||||
|
|
||||||
const insertOpenIDTokenSQL = "" +
|
const insertOpenIDTokenSQL = "" +
|
||||||
"INSERT INTO userapi_openid_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
|
"INSERT INTO userapi_openid_tokens(token, localpart, server_name, token_expires_at_ms) VALUES ($1, $2, $3, $4)"
|
||||||
|
|
||||||
const selectOpenIDTokenSQL = "" +
|
const selectOpenIDTokenSQL = "" +
|
||||||
"SELECT localpart, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1"
|
"SELECT localpart, server_name, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1"
|
||||||
|
|
||||||
type openIDTokenStatements struct {
|
type openIDTokenStatements struct {
|
||||||
insertTokenStmt *sql.Stmt
|
insertTokenStmt *sql.Stmt
|
||||||
|
@ -54,11 +56,11 @@ func NewPostgresOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName)
|
||||||
func (s *openIDTokenStatements) InsertOpenIDToken(
|
func (s *openIDTokenStatements) InsertOpenIDToken(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
txn *sql.Tx,
|
txn *sql.Tx,
|
||||||
token, localpart string,
|
token, localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
expiresAtMS int64,
|
expiresAtMS int64,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
stmt := sqlutil.TxStmt(txn, s.insertTokenStmt)
|
stmt := sqlutil.TxStmt(txn, s.insertTokenStmt)
|
||||||
_, err = stmt.ExecContext(ctx, token, localpart, expiresAtMS)
|
_, err = stmt.ExecContext(ctx, token, localpart, serverName, expiresAtMS)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -69,10 +71,13 @@ func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes(
|
||||||
token string,
|
token string,
|
||||||
) (*api.OpenIDTokenAttributes, error) {
|
) (*api.OpenIDTokenAttributes, error) {
|
||||||
var openIDTokenAttrs api.OpenIDTokenAttributes
|
var openIDTokenAttrs api.OpenIDTokenAttributes
|
||||||
|
var localpart string
|
||||||
|
var serverName gomatrixserverlib.ServerName
|
||||||
err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan(
|
err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan(
|
||||||
&openIDTokenAttrs.UserID,
|
&localpart, &serverName,
|
||||||
&openIDTokenAttrs.ExpiresAtMS,
|
&openIDTokenAttrs.ExpiresAtMS,
|
||||||
)
|
)
|
||||||
|
openIDTokenAttrs.UserID = fmt.Sprintf("@%s:%s", localpart, serverName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err != sql.ErrNoRows {
|
if err != sql.ErrNoRows {
|
||||||
log.WithError(err).Error("Unable to retrieve token from the db")
|
log.WithError(err).Error("Unable to retrieve token from the db")
|
||||||
|
|
|
@ -23,42 +23,46 @@ import (
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
const profilesSchema = `
|
const profilesSchema = `
|
||||||
-- Stores data about accounts profiles.
|
-- Stores data about accounts profiles.
|
||||||
CREATE TABLE IF NOT EXISTS userapi_profiles (
|
CREATE TABLE IF NOT EXISTS userapi_profiles (
|
||||||
-- The Matrix user ID localpart for this account
|
-- The Matrix user ID localpart for this account
|
||||||
localpart TEXT NOT NULL PRIMARY KEY,
|
localpart TEXT NOT NULL,
|
||||||
|
server_name TEXT NOT NULL,
|
||||||
-- The display name for this account
|
-- The display name for this account
|
||||||
display_name TEXT,
|
display_name TEXT,
|
||||||
-- The URL of the avatar for this account
|
-- The URL of the avatar for this account
|
||||||
avatar_url TEXT
|
avatar_url TEXT
|
||||||
);
|
);
|
||||||
|
|
||||||
|
CREATE UNIQUE INDEX IF NOT EXISTS userapi_profiles_idx ON userapi_profiles(localpart, server_name);
|
||||||
`
|
`
|
||||||
|
|
||||||
const insertProfileSQL = "" +
|
const insertProfileSQL = "" +
|
||||||
"INSERT INTO userapi_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)"
|
"INSERT INTO userapi_profiles(localpart, server_name, display_name, avatar_url) VALUES ($1, $2, $3, $4)"
|
||||||
|
|
||||||
const selectProfileByLocalpartSQL = "" +
|
const selectProfileByLocalpartSQL = "" +
|
||||||
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1"
|
"SELECT localpart, server_name, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1 AND server_name = $2"
|
||||||
|
|
||||||
const setAvatarURLSQL = "" +
|
const setAvatarURLSQL = "" +
|
||||||
"UPDATE userapi_profiles AS new" +
|
"UPDATE userapi_profiles AS new" +
|
||||||
" SET avatar_url = $1" +
|
" SET avatar_url = $1" +
|
||||||
" FROM userapi_profiles AS old" +
|
" FROM userapi_profiles AS old" +
|
||||||
" WHERE new.localpart = $2" +
|
" WHERE new.localpart = $2 AND new.server_name = $3" +
|
||||||
" RETURNING new.display_name, old.avatar_url <> new.avatar_url"
|
" RETURNING new.display_name, old.avatar_url <> new.avatar_url"
|
||||||
|
|
||||||
const setDisplayNameSQL = "" +
|
const setDisplayNameSQL = "" +
|
||||||
"UPDATE userapi_profiles AS new" +
|
"UPDATE userapi_profiles AS new" +
|
||||||
" SET display_name = $1" +
|
" SET display_name = $1" +
|
||||||
" FROM userapi_profiles AS old" +
|
" FROM userapi_profiles AS old" +
|
||||||
" WHERE new.localpart = $2" +
|
" WHERE new.localpart = $2 AND new.server_name = $3" +
|
||||||
" RETURNING new.avatar_url, old.display_name <> new.display_name"
|
" RETURNING new.avatar_url, old.display_name <> new.display_name"
|
||||||
|
|
||||||
const selectProfilesBySearchSQL = "" +
|
const selectProfilesBySearchSQL = "" +
|
||||||
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
|
"SELECT localpart, server_name, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
|
||||||
|
|
||||||
type profilesStatements struct {
|
type profilesStatements struct {
|
||||||
serverNoticesLocalpart string
|
serverNoticesLocalpart string
|
||||||
|
@ -87,18 +91,20 @@ func NewPostgresProfilesTable(db *sql.DB, serverNoticesLocalpart string) (tables
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *profilesStatements) InsertProfile(
|
func (s *profilesStatements) InsertProfile(
|
||||||
ctx context.Context, txn *sql.Tx, localpart string,
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
_, err = sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
|
_, err = sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, serverName, "", "")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *profilesStatements) SelectProfileByLocalpart(
|
func (s *profilesStatements) SelectProfileByLocalpart(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) (*authtypes.Profile, error) {
|
) (*authtypes.Profile, error) {
|
||||||
var profile authtypes.Profile
|
var profile authtypes.Profile
|
||||||
err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart).Scan(
|
err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart, serverName).Scan(
|
||||||
&profile.Localpart, &profile.DisplayName, &profile.AvatarURL,
|
&profile.Localpart, &profile.ServerName, &profile.DisplayName, &profile.AvatarURL,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -107,28 +113,34 @@ func (s *profilesStatements) SelectProfileByLocalpart(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *profilesStatements) SetAvatarURL(
|
func (s *profilesStatements) SetAvatarURL(
|
||||||
ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
avatarURL string,
|
||||||
) (*authtypes.Profile, bool, error) {
|
) (*authtypes.Profile, bool, error) {
|
||||||
profile := &authtypes.Profile{
|
profile := &authtypes.Profile{
|
||||||
Localpart: localpart,
|
Localpart: localpart,
|
||||||
AvatarURL: avatarURL,
|
ServerName: string(serverName),
|
||||||
|
AvatarURL: avatarURL,
|
||||||
}
|
}
|
||||||
var changed bool
|
var changed bool
|
||||||
stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
|
stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
|
||||||
err := stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName, &changed)
|
err := stmt.QueryRowContext(ctx, avatarURL, localpart, serverName).Scan(&profile.DisplayName, &changed)
|
||||||
return profile, changed, err
|
return profile, changed, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *profilesStatements) SetDisplayName(
|
func (s *profilesStatements) SetDisplayName(
|
||||||
ctx context.Context, txn *sql.Tx, localpart string, displayName string,
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
displayName string,
|
||||||
) (*authtypes.Profile, bool, error) {
|
) (*authtypes.Profile, bool, error) {
|
||||||
profile := &authtypes.Profile{
|
profile := &authtypes.Profile{
|
||||||
Localpart: localpart,
|
Localpart: localpart,
|
||||||
|
ServerName: string(serverName),
|
||||||
DisplayName: displayName,
|
DisplayName: displayName,
|
||||||
}
|
}
|
||||||
var changed bool
|
var changed bool
|
||||||
stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
|
stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
|
||||||
err := stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL, &changed)
|
err := stmt.QueryRowContext(ctx, displayName, localpart, serverName).Scan(&profile.AvatarURL, &changed)
|
||||||
return profile, changed, err
|
return profile, changed, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -146,7 +158,7 @@ func (s *profilesStatements) SelectProfilesBySearch(
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "selectProfilesBySearch: rows.close() failed")
|
defer internal.CloseAndLogIfError(ctx, rows, "selectProfilesBySearch: rows.close() failed")
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var profile authtypes.Profile
|
var profile authtypes.Profile
|
||||||
if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL); err != nil {
|
if err := rows.Scan(&profile.Localpart, &profile.ServerName, &profile.DisplayName, &profile.AvatarURL); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if profile.Localpart != s.serverNoticesLocalpart {
|
if profile.Localpart != s.serverNoticesLocalpart {
|
||||||
|
|
|
@ -25,6 +25,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
// See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers
|
// See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers
|
||||||
|
@ -33,6 +34,7 @@ CREATE TABLE IF NOT EXISTS userapi_pushers (
|
||||||
id BIGSERIAL PRIMARY KEY,
|
id BIGSERIAL PRIMARY KEY,
|
||||||
-- The Matrix user ID localpart for this pusher
|
-- The Matrix user ID localpart for this pusher
|
||||||
localpart TEXT NOT NULL,
|
localpart TEXT NOT NULL,
|
||||||
|
server_name TEXT NOT NULL,
|
||||||
session_id BIGINT DEFAULT NULL,
|
session_id BIGINT DEFAULT NULL,
|
||||||
profile_tag TEXT,
|
profile_tag TEXT,
|
||||||
kind TEXT NOT NULL,
|
kind TEXT NOT NULL,
|
||||||
|
@ -49,22 +51,22 @@ CREATE TABLE IF NOT EXISTS userapi_pushers (
|
||||||
CREATE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey);
|
CREATE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey);
|
||||||
|
|
||||||
-- For faster retrieving by localpart.
|
-- For faster retrieving by localpart.
|
||||||
CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart);
|
CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart, server_name);
|
||||||
|
|
||||||
-- Pushkey must be unique for a given user and app.
|
-- Pushkey must be unique for a given user and app.
|
||||||
CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart);
|
CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart, server_name);
|
||||||
`
|
`
|
||||||
|
|
||||||
const insertPusherSQL = "" +
|
const insertPusherSQL = "" +
|
||||||
"INSERT INTO userapi_pushers (localpart, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" +
|
"INSERT INTO userapi_pushers (localpart, server_name, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" +
|
||||||
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" +
|
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)" +
|
||||||
"ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11"
|
"ON CONFLICT (app_id, pushkey, localpart, server_name) DO UPDATE SET session_id = $3, pushkey_ts_ms = $5, kind = $6, app_display_name = $8, device_display_name = $9, profile_tag = $10, lang = $11, data = $12"
|
||||||
|
|
||||||
const selectPushersSQL = "" +
|
const selectPushersSQL = "" +
|
||||||
"SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1"
|
"SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1 AND server_name = $2"
|
||||||
|
|
||||||
const deletePusherSQL = "" +
|
const deletePusherSQL = "" +
|
||||||
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3"
|
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3 AND server_name = $4"
|
||||||
|
|
||||||
const deletePushersByAppIdAndPushKeySQL = "" +
|
const deletePushersByAppIdAndPushKeySQL = "" +
|
||||||
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2"
|
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2"
|
||||||
|
@ -95,18 +97,19 @@ type pushersStatements struct {
|
||||||
// Returns nil error success.
|
// Returns nil error success.
|
||||||
func (s *pushersStatements) InsertPusher(
|
func (s *pushersStatements) InsertPusher(
|
||||||
ctx context.Context, txn *sql.Tx, session_id int64,
|
ctx context.Context, txn *sql.Tx, session_id int64,
|
||||||
pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string,
|
pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) error {
|
) error {
|
||||||
_, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
|
_, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, serverName, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
|
||||||
logrus.Debugf("Created pusher %d", session_id)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *pushersStatements) SelectPushers(
|
func (s *pushersStatements) SelectPushers(
|
||||||
ctx context.Context, txn *sql.Tx, localpart string,
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) ([]api.Pusher, error) {
|
) ([]api.Pusher, error) {
|
||||||
pushers := []api.Pusher{}
|
pushers := []api.Pusher{}
|
||||||
rows, err := sqlutil.TxStmt(txn, s.selectPushersStmt).QueryContext(ctx, localpart)
|
rows, err := sqlutil.TxStmt(txn, s.selectPushersStmt).QueryContext(ctx, localpart, serverName)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return pushers, err
|
return pushers, err
|
||||||
|
@ -143,9 +146,10 @@ func (s *pushersStatements) SelectPushers(
|
||||||
|
|
||||||
// deletePusher removes a single pusher by pushkey and user localpart.
|
// deletePusher removes a single pusher by pushkey and user localpart.
|
||||||
func (s *pushersStatements) DeletePusher(
|
func (s *pushersStatements) DeletePusher(
|
||||||
ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string,
|
ctx context.Context, txn *sql.Tx, appid, pushkey,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) error {
|
) error {
|
||||||
_, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart)
|
_, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart, serverName)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -15,6 +15,8 @@
|
||||||
package postgres
|
package postgres
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -43,18 +45,24 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
|
||||||
Up: deltas.UpRenameTables,
|
Up: deltas.UpRenameTables,
|
||||||
Down: deltas.DownRenameTables,
|
Down: deltas.DownRenameTables,
|
||||||
})
|
})
|
||||||
|
m.AddMigrations(sqlutil.Migration{
|
||||||
|
Version: "userapi: server names",
|
||||||
|
Up: func(ctx context.Context, txn *sql.Tx) error {
|
||||||
|
return deltas.UpServerNames(ctx, txn, serverName)
|
||||||
|
},
|
||||||
|
})
|
||||||
if err = m.Up(base.Context()); err != nil {
|
if err = m.Up(base.Context()); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
accountDataTable, err := NewPostgresAccountDataTable(db)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("NewPostgresAccountDataTable: %w", err)
|
|
||||||
}
|
|
||||||
accountsTable, err := NewPostgresAccountsTable(db, serverName)
|
accountsTable, err := NewPostgresAccountsTable(db, serverName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("NewPostgresAccountsTable: %w", err)
|
return nil, fmt.Errorf("NewPostgresAccountsTable: %w", err)
|
||||||
}
|
}
|
||||||
|
accountDataTable, err := NewPostgresAccountDataTable(db)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("NewPostgresAccountDataTable: %w", err)
|
||||||
|
}
|
||||||
devicesTable, err := NewPostgresDevicesTable(db, serverName)
|
devicesTable, err := NewPostgresDevicesTable(db, serverName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("NewPostgresDevicesTable: %w", err)
|
return nil, fmt.Errorf("NewPostgresDevicesTable: %w", err)
|
||||||
|
@ -95,6 +103,18 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("NewPostgresStatsTable: %w", err)
|
return nil, fmt.Errorf("NewPostgresStatsTable: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m = sqlutil.NewMigrator(db)
|
||||||
|
m.AddMigrations(sqlutil.Migration{
|
||||||
|
Version: "userapi: server names populate",
|
||||||
|
Up: func(ctx context.Context, txn *sql.Tx) error {
|
||||||
|
return deltas.UpServerNamesPopulate(ctx, txn, serverName)
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err = m.Up(base.Context()); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
return &shared.Database{
|
return &shared.Database{
|
||||||
AccountDatas: accountDataTable,
|
AccountDatas: accountDataTable,
|
||||||
Accounts: accountsTable,
|
Accounts: accountsTable,
|
||||||
|
|
|
@ -20,6 +20,7 @@ import (
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
)
|
)
|
||||||
|
@ -33,21 +34,22 @@ CREATE TABLE IF NOT EXISTS userapi_threepids (
|
||||||
medium TEXT NOT NULL DEFAULT 'email',
|
medium TEXT NOT NULL DEFAULT 'email',
|
||||||
-- The localpart of the Matrix user ID associated to this 3PID
|
-- The localpart of the Matrix user ID associated to this 3PID
|
||||||
localpart TEXT NOT NULL,
|
localpart TEXT NOT NULL,
|
||||||
|
server_name TEXT NOT NULL,
|
||||||
|
|
||||||
PRIMARY KEY(threepid, medium)
|
PRIMARY KEY(threepid, medium)
|
||||||
);
|
);
|
||||||
|
|
||||||
CREATE INDEX IF NOT EXISTS userapi_threepid_idx ON userapi_threepids(localpart);
|
CREATE INDEX IF NOT EXISTS userapi_threepid_idx ON userapi_threepids(localpart, server_name);
|
||||||
`
|
`
|
||||||
|
|
||||||
const selectLocalpartForThreePIDSQL = "" +
|
const selectLocalpartForThreePIDSQL = "" +
|
||||||
"SELECT localpart FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
|
"SELECT localpart, server_name FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
|
||||||
|
|
||||||
const selectThreePIDsForLocalpartSQL = "" +
|
const selectThreePIDsForLocalpartSQL = "" +
|
||||||
"SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1"
|
"SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1 AND server_name = $2"
|
||||||
|
|
||||||
const insertThreePIDSQL = "" +
|
const insertThreePIDSQL = "" +
|
||||||
"INSERT INTO userapi_threepids (threepid, medium, localpart) VALUES ($1, $2, $3)"
|
"INSERT INTO userapi_threepids (threepid, medium, localpart, server_name) VALUES ($1, $2, $3, $4)"
|
||||||
|
|
||||||
const deleteThreePIDSQL = "" +
|
const deleteThreePIDSQL = "" +
|
||||||
"DELETE FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
|
"DELETE FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
|
||||||
|
@ -75,19 +77,20 @@ func NewPostgresThreePIDTable(db *sql.DB) (tables.ThreePIDTable, error) {
|
||||||
|
|
||||||
func (s *threepidStatements) SelectLocalpartForThreePID(
|
func (s *threepidStatements) SelectLocalpartForThreePID(
|
||||||
ctx context.Context, txn *sql.Tx, threepid string, medium string,
|
ctx context.Context, txn *sql.Tx, threepid string, medium string,
|
||||||
) (localpart string, err error) {
|
) (localpart string, serverName gomatrixserverlib.ServerName, err error) {
|
||||||
stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt)
|
stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt)
|
||||||
err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart)
|
err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart, &serverName)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return "", nil
|
return "", "", nil
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *threepidStatements) SelectThreePIDsForLocalpart(
|
func (s *threepidStatements) SelectThreePIDsForLocalpart(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) (threepids []authtypes.ThreePID, err error) {
|
) (threepids []authtypes.ThreePID, err error) {
|
||||||
rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart)
|
rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart, serverName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -109,10 +112,11 @@ func (s *threepidStatements) SelectThreePIDsForLocalpart(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *threepidStatements) InsertThreePID(
|
func (s *threepidStatements) InsertThreePID(
|
||||||
ctx context.Context, txn *sql.Tx, threepid, medium, localpart string,
|
ctx context.Context, txn *sql.Tx, threepid, medium,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt)
|
stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt)
|
||||||
_, err = stmt.ExecContext(ctx, threepid, medium, localpart)
|
_, err = stmt.ExecContext(ctx, threepid, medium, localpart, serverName)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -68,9 +68,10 @@ const (
|
||||||
// GetAccountByPassword returns the account associated with the given localpart and password.
|
// GetAccountByPassword returns the account associated with the given localpart and password.
|
||||||
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
|
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
|
||||||
func (d *Database) GetAccountByPassword(
|
func (d *Database) GetAccountByPassword(
|
||||||
ctx context.Context, localpart, plaintextPassword string,
|
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
plaintextPassword string,
|
||||||
) (*api.Account, error) {
|
) (*api.Account, error) {
|
||||||
hash, err := d.Accounts.SelectPasswordHash(ctx, localpart)
|
hash, err := d.Accounts.SelectPasswordHash(ctx, localpart, serverName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -80,24 +81,27 @@ func (d *Database) GetAccountByPassword(
|
||||||
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil {
|
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return d.Accounts.SelectAccountByLocalpart(ctx, localpart)
|
return d.Accounts.SelectAccountByLocalpart(ctx, localpart, serverName)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetProfileByLocalpart returns the profile associated with the given localpart.
|
// GetProfileByLocalpart returns the profile associated with the given localpart.
|
||||||
// Returns sql.ErrNoRows if no profile exists which matches the given localpart.
|
// Returns sql.ErrNoRows if no profile exists which matches the given localpart.
|
||||||
func (d *Database) GetProfileByLocalpart(
|
func (d *Database) GetProfileByLocalpart(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) (*authtypes.Profile, error) {
|
) (*authtypes.Profile, error) {
|
||||||
return d.Profiles.SelectProfileByLocalpart(ctx, localpart)
|
return d.Profiles.SelectProfileByLocalpart(ctx, localpart, serverName)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetAvatarURL updates the avatar URL of the profile associated with the given
|
// SetAvatarURL updates the avatar URL of the profile associated with the given
|
||||||
// localpart. Returns an error if something went wrong with the SQL query
|
// localpart. Returns an error if something went wrong with the SQL query
|
||||||
func (d *Database) SetAvatarURL(
|
func (d *Database) SetAvatarURL(
|
||||||
ctx context.Context, localpart string, avatarURL string,
|
ctx context.Context,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
avatarURL string,
|
||||||
) (profile *authtypes.Profile, changed bool, err error) {
|
) (profile *authtypes.Profile, changed bool, err error) {
|
||||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
profile, changed, err = d.Profiles.SetAvatarURL(ctx, txn, localpart, avatarURL)
|
profile, changed, err = d.Profiles.SetAvatarURL(ctx, txn, localpart, serverName, avatarURL)
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
|
@ -106,10 +110,12 @@ func (d *Database) SetAvatarURL(
|
||||||
// SetDisplayName updates the display name of the profile associated with the given
|
// SetDisplayName updates the display name of the profile associated with the given
|
||||||
// localpart. Returns an error if something went wrong with the SQL query
|
// localpart. Returns an error if something went wrong with the SQL query
|
||||||
func (d *Database) SetDisplayName(
|
func (d *Database) SetDisplayName(
|
||||||
ctx context.Context, localpart string, displayName string,
|
ctx context.Context,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
displayName string,
|
||||||
) (profile *authtypes.Profile, changed bool, err error) {
|
) (profile *authtypes.Profile, changed bool, err error) {
|
||||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
profile, changed, err = d.Profiles.SetDisplayName(ctx, txn, localpart, displayName)
|
profile, changed, err = d.Profiles.SetDisplayName(ctx, txn, localpart, serverName, displayName)
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
|
@ -117,14 +123,15 @@ func (d *Database) SetDisplayName(
|
||||||
|
|
||||||
// SetPassword sets the account password to the given hash.
|
// SetPassword sets the account password to the given hash.
|
||||||
func (d *Database) SetPassword(
|
func (d *Database) SetPassword(
|
||||||
ctx context.Context, localpart, plaintextPassword string,
|
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
plaintextPassword string,
|
||||||
) error {
|
) error {
|
||||||
hash, err := d.hashPassword(plaintextPassword)
|
hash, err := d.hashPassword(plaintextPassword)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
|
return d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
|
||||||
return d.Accounts.UpdatePassword(ctx, localpart, hash)
|
return d.Accounts.UpdatePassword(ctx, localpart, serverName, hash)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -132,21 +139,22 @@ func (d *Database) SetPassword(
|
||||||
// for this account. If no password is supplied, the account will be a passwordless account. If the
|
// for this account. If no password is supplied, the account will be a passwordless account. If the
|
||||||
// account already exists, it will return nil, ErrUserExists.
|
// account already exists, it will return nil, ErrUserExists.
|
||||||
func (d *Database) CreateAccount(
|
func (d *Database) CreateAccount(
|
||||||
ctx context.Context, localpart, plaintextPassword, appserviceID string, accountType api.AccountType,
|
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
plaintextPassword, appserviceID string, accountType api.AccountType,
|
||||||
) (acc *api.Account, err error) {
|
) (acc *api.Account, err error) {
|
||||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
// For guest accounts, we create a new numeric local part
|
// For guest accounts, we create a new numeric local part
|
||||||
if accountType == api.AccountTypeGuest {
|
if accountType == api.AccountTypeGuest {
|
||||||
var numLocalpart int64
|
var numLocalpart int64
|
||||||
numLocalpart, err = d.Accounts.SelectNewNumericLocalpart(ctx, txn)
|
numLocalpart, err = d.Accounts.SelectNewNumericLocalpart(ctx, txn, serverName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("d.Accounts.SelectNewNumericLocalpart: %w", err)
|
||||||
}
|
}
|
||||||
localpart = strconv.FormatInt(numLocalpart, 10)
|
localpart = strconv.FormatInt(numLocalpart, 10)
|
||||||
plaintextPassword = ""
|
plaintextPassword = ""
|
||||||
appserviceID = ""
|
appserviceID = ""
|
||||||
}
|
}
|
||||||
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, accountType)
|
acc, err = d.createAccount(ctx, txn, localpart, serverName, plaintextPassword, appserviceID, accountType)
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
|
@ -155,7 +163,9 @@ func (d *Database) CreateAccount(
|
||||||
// WARNING! This function assumes that the relevant mutexes have already
|
// WARNING! This function assumes that the relevant mutexes have already
|
||||||
// been taken out by the caller (e.g. CreateAccount or CreateGuestAccount).
|
// been taken out by the caller (e.g. CreateAccount or CreateGuestAccount).
|
||||||
func (d *Database) createAccount(
|
func (d *Database) createAccount(
|
||||||
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, accountType api.AccountType,
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
plaintextPassword, appserviceID string, accountType api.AccountType,
|
||||||
) (*api.Account, error) {
|
) (*api.Account, error) {
|
||||||
var err error
|
var err error
|
||||||
var account *api.Account
|
var account *api.Account
|
||||||
|
@ -167,28 +177,28 @@ func (d *Database) createAccount(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, hash, appserviceID, accountType); err != nil {
|
if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, serverName, hash, appserviceID, accountType); err != nil {
|
||||||
return nil, sqlutil.ErrUserExists
|
return nil, sqlutil.ErrUserExists
|
||||||
}
|
}
|
||||||
if err = d.Profiles.InsertProfile(ctx, txn, localpart); err != nil {
|
if err = d.Profiles.InsertProfile(ctx, txn, localpart, serverName); err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("d.Profiles.InsertProfile: %w", err)
|
||||||
}
|
}
|
||||||
pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, d.ServerName)
|
pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, serverName)
|
||||||
prbs, err := json.Marshal(pushRuleSets)
|
prbs, err := json.Marshal(pushRuleSets)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("json.Marshal: %w", err)
|
||||||
}
|
}
|
||||||
if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(prbs)); err != nil {
|
if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, serverName, "", "m.push_rules", json.RawMessage(prbs)); err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("d.AccountDatas.InsertAccountData: %w", err)
|
||||||
}
|
}
|
||||||
return account, nil
|
return account, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) QueryPushRules(
|
func (d *Database) QueryPushRules(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
localpart string,
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) (*pushrules.AccountRuleSets, error) {
|
) (*pushrules.AccountRuleSets, error) {
|
||||||
data, err := d.AccountDatas.SelectAccountDataByType(ctx, localpart, "", "m.push_rules")
|
data, err := d.AccountDatas.SelectAccountDataByType(ctx, localpart, serverName, "", "m.push_rules")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -196,13 +206,13 @@ func (d *Database) QueryPushRules(
|
||||||
// If we didn't find any default push rules then we should just generate some
|
// If we didn't find any default push rules then we should just generate some
|
||||||
// fresh ones.
|
// fresh ones.
|
||||||
if len(data) == 0 {
|
if len(data) == 0 {
|
||||||
pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, d.ServerName)
|
pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, serverName)
|
||||||
prbs, err := json.Marshal(pushRuleSets)
|
prbs, err := json.Marshal(pushRuleSets)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to marshal default push rules: %w", err)
|
return nil, fmt.Errorf("failed to marshal default push rules: %w", err)
|
||||||
}
|
}
|
||||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
if dbErr := d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", prbs); dbErr != nil {
|
if dbErr := d.AccountDatas.InsertAccountData(ctx, txn, localpart, serverName, "", "m.push_rules", prbs); dbErr != nil {
|
||||||
return fmt.Errorf("failed to save default push rules: %w", dbErr)
|
return fmt.Errorf("failed to save default push rules: %w", dbErr)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
@ -225,22 +235,23 @@ func (d *Database) QueryPushRules(
|
||||||
// update the corresponding row with the new content
|
// update the corresponding row with the new content
|
||||||
// Returns a SQL error if there was an issue with the insertion/update
|
// Returns a SQL error if there was an issue with the insertion/update
|
||||||
func (d *Database) SaveAccountData(
|
func (d *Database) SaveAccountData(
|
||||||
ctx context.Context, localpart, roomID, dataType string, content json.RawMessage,
|
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
roomID, dataType string, content json.RawMessage,
|
||||||
) error {
|
) error {
|
||||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
return d.AccountDatas.InsertAccountData(ctx, txn, localpart, roomID, dataType, content)
|
return d.AccountDatas.InsertAccountData(ctx, txn, localpart, serverName, roomID, dataType, content)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountData returns account data related to a given localpart
|
// GetAccountData returns account data related to a given localpart
|
||||||
// If no account data could be found, returns an empty arrays
|
// If no account data could be found, returns an empty arrays
|
||||||
// Returns an error if there was an issue with the retrieval
|
// Returns an error if there was an issue with the retrieval
|
||||||
func (d *Database) GetAccountData(ctx context.Context, localpart string) (
|
func (d *Database) GetAccountData(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (
|
||||||
global map[string]json.RawMessage,
|
global map[string]json.RawMessage,
|
||||||
rooms map[string]map[string]json.RawMessage,
|
rooms map[string]map[string]json.RawMessage,
|
||||||
err error,
|
err error,
|
||||||
) {
|
) {
|
||||||
return d.AccountDatas.SelectAccountData(ctx, localpart)
|
return d.AccountDatas.SelectAccountData(ctx, localpart, serverName)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountDataByType returns account data matching a given
|
// GetAccountDataByType returns account data matching a given
|
||||||
|
@ -248,18 +259,19 @@ func (d *Database) GetAccountData(ctx context.Context, localpart string) (
|
||||||
// If no account data could be found, returns nil
|
// If no account data could be found, returns nil
|
||||||
// Returns an error if there was an issue with the retrieval
|
// Returns an error if there was an issue with the retrieval
|
||||||
func (d *Database) GetAccountDataByType(
|
func (d *Database) GetAccountDataByType(
|
||||||
ctx context.Context, localpart, roomID, dataType string,
|
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
roomID, dataType string,
|
||||||
) (data json.RawMessage, err error) {
|
) (data json.RawMessage, err error) {
|
||||||
return d.AccountDatas.SelectAccountDataByType(
|
return d.AccountDatas.SelectAccountDataByType(
|
||||||
ctx, localpart, roomID, dataType,
|
ctx, localpart, serverName, roomID, dataType,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetNewNumericLocalpart generates and returns a new unused numeric localpart
|
// GetNewNumericLocalpart generates and returns a new unused numeric localpart
|
||||||
func (d *Database) GetNewNumericLocalpart(
|
func (d *Database) GetNewNumericLocalpart(
|
||||||
ctx context.Context,
|
ctx context.Context, serverName gomatrixserverlib.ServerName,
|
||||||
) (int64, error) {
|
) (int64, error) {
|
||||||
return d.Accounts.SelectNewNumericLocalpart(ctx, nil)
|
return d.Accounts.SelectNewNumericLocalpart(ctx, nil, serverName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) hashPassword(plaintext string) (hash string, err error) {
|
func (d *Database) hashPassword(plaintext string) (hash string, err error) {
|
||||||
|
@ -276,10 +288,12 @@ var Err3PIDInUse = errors.New("this third-party identifier is already in use")
|
||||||
// If the third-party identifier is already part of an association, returns Err3PIDInUse.
|
// If the third-party identifier is already part of an association, returns Err3PIDInUse.
|
||||||
// Returns an error if there was a problem talking to the database.
|
// Returns an error if there was a problem talking to the database.
|
||||||
func (d *Database) SaveThreePIDAssociation(
|
func (d *Database) SaveThreePIDAssociation(
|
||||||
ctx context.Context, threepid, localpart, medium string,
|
ctx context.Context, threepid string,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
medium string,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
user, err := d.ThreePIDs.SelectLocalpartForThreePID(
|
user, _, err := d.ThreePIDs.SelectLocalpartForThreePID(
|
||||||
ctx, txn, threepid, medium,
|
ctx, txn, threepid, medium,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -290,7 +304,7 @@ func (d *Database) SaveThreePIDAssociation(
|
||||||
return Err3PIDInUse
|
return Err3PIDInUse
|
||||||
}
|
}
|
||||||
|
|
||||||
return d.ThreePIDs.InsertThreePID(ctx, txn, threepid, medium, localpart)
|
return d.ThreePIDs.InsertThreePID(ctx, txn, threepid, medium, localpart, serverName)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -313,7 +327,7 @@ func (d *Database) RemoveThreePIDAssociation(
|
||||||
// Returns an error if there was a problem talking to the database.
|
// Returns an error if there was a problem talking to the database.
|
||||||
func (d *Database) GetLocalpartForThreePID(
|
func (d *Database) GetLocalpartForThreePID(
|
||||||
ctx context.Context, threepid string, medium string,
|
ctx context.Context, threepid string, medium string,
|
||||||
) (localpart string, err error) {
|
) (localpart string, serverName gomatrixserverlib.ServerName, err error) {
|
||||||
return d.ThreePIDs.SelectLocalpartForThreePID(ctx, nil, threepid, medium)
|
return d.ThreePIDs.SelectLocalpartForThreePID(ctx, nil, threepid, medium)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -322,16 +336,17 @@ func (d *Database) GetLocalpartForThreePID(
|
||||||
// If no association is known for this user, returns an empty slice.
|
// If no association is known for this user, returns an empty slice.
|
||||||
// Returns an error if there was an issue talking to the database.
|
// Returns an error if there was an issue talking to the database.
|
||||||
func (d *Database) GetThreePIDsForLocalpart(
|
func (d *Database) GetThreePIDsForLocalpart(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) (threepids []authtypes.ThreePID, err error) {
|
) (threepids []authtypes.ThreePID, err error) {
|
||||||
return d.ThreePIDs.SelectThreePIDsForLocalpart(ctx, localpart)
|
return d.ThreePIDs.SelectThreePIDsForLocalpart(ctx, localpart, serverName)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CheckAccountAvailability checks if the username/localpart is already present
|
// CheckAccountAvailability checks if the username/localpart is already present
|
||||||
// in the database.
|
// in the database.
|
||||||
// If the DB returns sql.ErrNoRows the Localpart isn't taken.
|
// If the DB returns sql.ErrNoRows the Localpart isn't taken.
|
||||||
func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) {
|
func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (bool, error) {
|
||||||
_, err := d.Accounts.SelectAccountByLocalpart(ctx, localpart)
|
_, err := d.Accounts.SelectAccountByLocalpart(ctx, localpart, serverName)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
@ -341,12 +356,12 @@ func (d *Database) CheckAccountAvailability(ctx context.Context, localpart strin
|
||||||
// GetAccountByLocalpart returns the account associated with the given localpart.
|
// GetAccountByLocalpart returns the account associated with the given localpart.
|
||||||
// This function assumes the request is authenticated or the account data is used only internally.
|
// This function assumes the request is authenticated or the account data is used only internally.
|
||||||
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
|
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
|
||||||
func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string,
|
func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) (*api.Account, error) {
|
) (*api.Account, error) {
|
||||||
// try to get the account with lowercase localpart (majority)
|
// try to get the account with lowercase localpart (majority)
|
||||||
acc, err := d.Accounts.SelectAccountByLocalpart(ctx, strings.ToLower(localpart))
|
acc, err := d.Accounts.SelectAccountByLocalpart(ctx, strings.ToLower(localpart), serverName)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
acc, err = d.Accounts.SelectAccountByLocalpart(ctx, localpart) // try with localpart as passed by the request
|
acc, err = d.Accounts.SelectAccountByLocalpart(ctx, localpart, serverName) // try with localpart as passed by the request
|
||||||
}
|
}
|
||||||
return acc, err
|
return acc, err
|
||||||
}
|
}
|
||||||
|
@ -359,20 +374,24 @@ func (d *Database) SearchProfiles(ctx context.Context, searchString string, limi
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeactivateAccount deactivates the user's account, removing all ability for the user to login again.
|
// DeactivateAccount deactivates the user's account, removing all ability for the user to login again.
|
||||||
func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) {
|
func (d *Database) DeactivateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (err error) {
|
||||||
return d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
|
return d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
|
||||||
return d.Accounts.DeactivateAccount(ctx, localpart)
|
return d.Accounts.DeactivateAccount(ctx, localpart, serverName)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateOpenIDToken persists a new token that was issued for OpenID Connect
|
// CreateOpenIDToken persists a new token that was issued for OpenID Connect
|
||||||
func (d *Database) CreateOpenIDToken(
|
func (d *Database) CreateOpenIDToken(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
token, localpart string,
|
token, userID string,
|
||||||
) (int64, error) {
|
) (int64, error) {
|
||||||
|
localpart, domain, err := gomatrixserverlib.SplitID('@', userID)
|
||||||
|
if err != nil {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.OpenIDTokenLifetimeMS
|
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.OpenIDTokenLifetimeMS
|
||||||
err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
return d.OpenIDTokens.InsertOpenIDToken(ctx, txn, token, localpart, expiresAtMS)
|
return d.OpenIDTokens.InsertOpenIDToken(ctx, txn, token, localpart, domain, expiresAtMS)
|
||||||
})
|
})
|
||||||
return expiresAtMS, err
|
return expiresAtMS, err
|
||||||
}
|
}
|
||||||
|
@ -539,16 +558,19 @@ func (d *Database) GetDeviceByAccessToken(
|
||||||
// GetDeviceByID returns the device matching the given ID.
|
// GetDeviceByID returns the device matching the given ID.
|
||||||
// Returns sql.ErrNoRows if no matching device was found.
|
// Returns sql.ErrNoRows if no matching device was found.
|
||||||
func (d *Database) GetDeviceByID(
|
func (d *Database) GetDeviceByID(
|
||||||
ctx context.Context, localpart, deviceID string,
|
ctx context.Context,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
deviceID string,
|
||||||
) (*api.Device, error) {
|
) (*api.Device, error) {
|
||||||
return d.Devices.SelectDeviceByID(ctx, localpart, deviceID)
|
return d.Devices.SelectDeviceByID(ctx, localpart, serverName, deviceID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetDevicesByLocalpart returns the devices matching the given localpart.
|
// GetDevicesByLocalpart returns the devices matching the given localpart.
|
||||||
func (d *Database) GetDevicesByLocalpart(
|
func (d *Database) GetDevicesByLocalpart(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) ([]api.Device, error) {
|
) ([]api.Device, error) {
|
||||||
return d.Devices.SelectDevicesByLocalpart(ctx, nil, localpart, "")
|
return d.Devices.SelectDevicesByLocalpart(ctx, nil, localpart, serverName, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
|
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
|
||||||
|
@ -562,18 +584,18 @@ func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]ap
|
||||||
// If no device ID is given one is generated.
|
// If no device ID is given one is generated.
|
||||||
// Returns the device on success.
|
// Returns the device on success.
|
||||||
func (d *Database) CreateDevice(
|
func (d *Database) CreateDevice(
|
||||||
ctx context.Context, localpart string, deviceID *string, accessToken string,
|
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
displayName *string, ipAddr, userAgent string,
|
deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string,
|
||||||
) (dev *api.Device, returnErr error) {
|
) (dev *api.Device, returnErr error) {
|
||||||
if deviceID != nil {
|
if deviceID != nil {
|
||||||
returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
var err error
|
var err error
|
||||||
// Revoke existing tokens for this device
|
// Revoke existing tokens for this device
|
||||||
if err = d.Devices.DeleteDevice(ctx, txn, *deviceID, localpart); err != nil {
|
if err = d.Devices.DeleteDevice(ctx, txn, *deviceID, localpart, serverName); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
|
dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, serverName, accessToken, displayName, ipAddr, userAgent)
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
|
@ -588,7 +610,7 @@ func (d *Database) CreateDevice(
|
||||||
|
|
||||||
returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
var err error
|
var err error
|
||||||
dev, err = d.Devices.InsertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
|
dev, err = d.Devices.InsertDevice(ctx, txn, newDeviceID, localpart, serverName, accessToken, displayName, ipAddr, userAgent)
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
if returnErr == nil {
|
if returnErr == nil {
|
||||||
|
@ -614,10 +636,12 @@ func generateDeviceID() (string, error) {
|
||||||
// UpdateDevice updates the given device with the display name.
|
// UpdateDevice updates the given device with the display name.
|
||||||
// Returns SQL error if there are problems and nil on success.
|
// Returns SQL error if there are problems and nil on success.
|
||||||
func (d *Database) UpdateDevice(
|
func (d *Database) UpdateDevice(
|
||||||
ctx context.Context, localpart, deviceID string, displayName *string,
|
ctx context.Context,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
deviceID string, displayName *string,
|
||||||
) error {
|
) error {
|
||||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
return d.Devices.UpdateDeviceName(ctx, txn, localpart, deviceID, displayName)
|
return d.Devices.UpdateDeviceName(ctx, txn, localpart, serverName, deviceID, displayName)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -626,10 +650,12 @@ func (d *Database) UpdateDevice(
|
||||||
// If the devices don't exist, it will not return an error
|
// If the devices don't exist, it will not return an error
|
||||||
// If something went wrong during the deletion, it will return the SQL error.
|
// If something went wrong during the deletion, it will return the SQL error.
|
||||||
func (d *Database) RemoveDevices(
|
func (d *Database) RemoveDevices(
|
||||||
ctx context.Context, localpart string, devices []string,
|
ctx context.Context,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
devices []string,
|
||||||
) error {
|
) error {
|
||||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
if err := d.Devices.DeleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
|
if err := d.Devices.DeleteDevices(ctx, txn, localpart, serverName, devices); err != sql.ErrNoRows {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
@ -640,14 +666,16 @@ func (d *Database) RemoveDevices(
|
||||||
// database matching the given user ID localpart.
|
// database matching the given user ID localpart.
|
||||||
// If something went wrong during the deletion, it will return the SQL error.
|
// If something went wrong during the deletion, it will return the SQL error.
|
||||||
func (d *Database) RemoveAllDevices(
|
func (d *Database) RemoveAllDevices(
|
||||||
ctx context.Context, localpart, exceptDeviceID string,
|
ctx context.Context,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
exceptDeviceID string,
|
||||||
) (devices []api.Device, err error) {
|
) (devices []api.Device, err error) {
|
||||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
devices, err = d.Devices.SelectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
|
devices, err = d.Devices.SelectDevicesByLocalpart(ctx, txn, localpart, serverName, exceptDeviceID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := d.Devices.DeleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
|
if err := d.Devices.DeleteDevicesByLocalpart(ctx, txn, localpart, serverName, exceptDeviceID); err != sql.ErrNoRows {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
@ -656,9 +684,9 @@ func (d *Database) RemoveAllDevices(
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateDeviceLastSeen updates a last seen timestamp and the ip address.
|
// UpdateDeviceLastSeen updates a last seen timestamp and the ip address.
|
||||||
func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr, userAgent string) error {
|
func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error {
|
||||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
return d.Devices.UpdateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr, userAgent)
|
return d.Devices.UpdateDeviceLastSeen(ctx, txn, localpart, serverName, deviceID, ipAddr, userAgent)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -706,38 +734,38 @@ func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (
|
||||||
return d.LoginTokens.SelectLoginToken(ctx, token)
|
return d.LoginTokens.SelectLoginToken(ctx, token)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) InsertNotification(ctx context.Context, localpart, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error {
|
func (d *Database) InsertNotification(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error {
|
||||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
return d.Notifications.Insert(ctx, txn, localpart, eventID, pos, pushrules.BoolTweakOr(tweaks, pushrules.HighlightTweak, false), n)
|
return d.Notifications.Insert(ctx, txn, localpart, serverName, eventID, pos, pushrules.BoolTweakOr(tweaks, pushrules.HighlightTweak, false), n)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos uint64) (affected bool, err error) {
|
func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, err error) {
|
||||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
affected, err = d.Notifications.DeleteUpTo(ctx, txn, localpart, roomID, pos)
|
affected, err = d.Notifications.DeleteUpTo(ctx, txn, localpart, serverName, roomID, pos)
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) SetNotificationsRead(ctx context.Context, localpart, roomID string, pos uint64, b bool) (affected bool, err error) {
|
func (d *Database) SetNotificationsRead(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, b bool) (affected bool, err error) {
|
||||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
affected, err = d.Notifications.UpdateRead(ctx, txn, localpart, roomID, pos, b)
|
affected, err = d.Notifications.UpdateRead(ctx, txn, localpart, serverName, roomID, pos, b)
|
||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
|
func (d *Database) GetNotifications(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
|
||||||
return d.Notifications.Select(ctx, nil, localpart, fromID, limit, filter)
|
return d.Notifications.Select(ctx, nil, localpart, serverName, fromID, limit, filter)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error) {
|
func (d *Database) GetNotificationCount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, filter tables.NotificationFilter) (int64, error) {
|
||||||
return d.Notifications.SelectCount(ctx, nil, localpart, filter)
|
return d.Notifications.SelectCount(ctx, nil, localpart, serverName, filter)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error) {
|
func (d *Database) GetRoomNotificationCounts(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, _ error) {
|
||||||
return d.Notifications.SelectRoomCounts(ctx, nil, localpart, roomID)
|
return d.Notifications.SelectRoomCounts(ctx, nil, localpart, serverName, roomID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) DeleteOldNotifications(ctx context.Context) error {
|
func (d *Database) DeleteOldNotifications(ctx context.Context) error {
|
||||||
|
@ -747,7 +775,8 @@ func (d *Database) DeleteOldNotifications(ctx context.Context) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) UpsertPusher(
|
func (d *Database) UpsertPusher(
|
||||||
ctx context.Context, p api.Pusher, localpart string,
|
ctx context.Context, p api.Pusher,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) error {
|
) error {
|
||||||
data, err := json.Marshal(p.Data)
|
data, err := json.Marshal(p.Data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -766,25 +795,26 @@ func (d *Database) UpsertPusher(
|
||||||
p.ProfileTag,
|
p.ProfileTag,
|
||||||
p.Language,
|
p.Language,
|
||||||
string(data),
|
string(data),
|
||||||
localpart)
|
localpart,
|
||||||
|
serverName)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPushers returns the pushers matching the given localpart.
|
// GetPushers returns the pushers matching the given localpart.
|
||||||
func (d *Database) GetPushers(
|
func (d *Database) GetPushers(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) ([]api.Pusher, error) {
|
) ([]api.Pusher, error) {
|
||||||
return d.Pushers.SelectPushers(ctx, nil, localpart)
|
return d.Pushers.SelectPushers(ctx, nil, localpart, serverName)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemovePusher deletes one pusher
|
// RemovePusher deletes one pusher
|
||||||
// Invoked when `append` is true and `kind` is null in
|
// Invoked when `append` is true and `kind` is null in
|
||||||
// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-pushers-set
|
// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-pushers-set
|
||||||
func (d *Database) RemovePusher(
|
func (d *Database) RemovePusher(
|
||||||
ctx context.Context, appid, pushkey, localpart string,
|
ctx context.Context, appid, pushkey, localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) error {
|
) error {
|
||||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
err := d.Pushers.DeletePusher(ctx, txn, appid, pushkey, localpart)
|
err := d.Pushers.DeletePusher(ctx, txn, appid, pushkey, localpart, serverName)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,6 +21,7 @@ import (
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
const accountDataSchema = `
|
const accountDataSchema = `
|
||||||
|
@ -28,27 +29,28 @@ const accountDataSchema = `
|
||||||
CREATE TABLE IF NOT EXISTS userapi_account_datas (
|
CREATE TABLE IF NOT EXISTS userapi_account_datas (
|
||||||
-- The Matrix user ID localpart for this account
|
-- The Matrix user ID localpart for this account
|
||||||
localpart TEXT NOT NULL,
|
localpart TEXT NOT NULL,
|
||||||
|
server_name TEXT NOT NULL,
|
||||||
-- The room ID for this data (empty string if not specific to a room)
|
-- The room ID for this data (empty string if not specific to a room)
|
||||||
room_id TEXT,
|
room_id TEXT,
|
||||||
-- The account data type
|
-- The account data type
|
||||||
type TEXT NOT NULL,
|
type TEXT NOT NULL,
|
||||||
-- The account data content
|
-- The account data content
|
||||||
content TEXT NOT NULL,
|
content TEXT NOT NULL
|
||||||
|
|
||||||
PRIMARY KEY(localpart, room_id, type)
|
|
||||||
);
|
);
|
||||||
|
|
||||||
|
CREATE UNIQUE INDEX IF NOT EXISTS userapi_account_datas_idx ON userapi_account_datas(localpart, server_name, room_id, type);
|
||||||
`
|
`
|
||||||
|
|
||||||
const insertAccountDataSQL = `
|
const insertAccountDataSQL = `
|
||||||
INSERT INTO userapi_account_datas(localpart, room_id, type, content) VALUES($1, $2, $3, $4)
|
INSERT INTO userapi_account_datas(localpart, server_name, room_id, type, content) VALUES($1, $2, $3, $4, $5)
|
||||||
ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = $4
|
ON CONFLICT (localpart, server_name, room_id, type) DO UPDATE SET content = $5
|
||||||
`
|
`
|
||||||
|
|
||||||
const selectAccountDataSQL = "" +
|
const selectAccountDataSQL = "" +
|
||||||
"SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1"
|
"SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1 AND server_name = $2"
|
||||||
|
|
||||||
const selectAccountDataByTypeSQL = "" +
|
const selectAccountDataByTypeSQL = "" +
|
||||||
"SELECT content FROM userapi_account_datas WHERE localpart = $1 AND room_id = $2 AND type = $3"
|
"SELECT content FROM userapi_account_datas WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND type = $4"
|
||||||
|
|
||||||
type accountDataStatements struct {
|
type accountDataStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
@ -73,20 +75,23 @@ func NewSQLiteAccountDataTable(db *sql.DB) (tables.AccountDataTable, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountDataStatements) InsertAccountData(
|
func (s *accountDataStatements) InsertAccountData(
|
||||||
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
roomID, dataType string, content json.RawMessage,
|
||||||
) error {
|
) error {
|
||||||
_, err := sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content)
|
_, err := sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, localpart, serverName, roomID, dataType, content)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountDataStatements) SelectAccountData(
|
func (s *accountDataStatements) SelectAccountData(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) (
|
) (
|
||||||
/* global */ map[string]json.RawMessage,
|
/* global */ map[string]json.RawMessage,
|
||||||
/* rooms */ map[string]map[string]json.RawMessage,
|
/* rooms */ map[string]map[string]json.RawMessage,
|
||||||
error,
|
error,
|
||||||
) {
|
) {
|
||||||
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart)
|
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart, serverName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
@ -117,11 +122,13 @@ func (s *accountDataStatements) SelectAccountData(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountDataStatements) SelectAccountDataByType(
|
func (s *accountDataStatements) SelectAccountDataByType(
|
||||||
ctx context.Context, localpart, roomID, dataType string,
|
ctx context.Context,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
roomID, dataType string,
|
||||||
) (data json.RawMessage, err error) {
|
) (data json.RawMessage, err error) {
|
||||||
var bytes []byte
|
var bytes []byte
|
||||||
stmt := s.selectAccountDataByTypeStmt
|
stmt := s.selectAccountDataByTypeStmt
|
||||||
if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&bytes); err != nil {
|
if err = stmt.QueryRowContext(ctx, localpart, serverName, roomID, dataType).Scan(&bytes); err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -34,7 +34,8 @@ const accountsSchema = `
|
||||||
-- Stores data about accounts.
|
-- Stores data about accounts.
|
||||||
CREATE TABLE IF NOT EXISTS userapi_accounts (
|
CREATE TABLE IF NOT EXISTS userapi_accounts (
|
||||||
-- The Matrix user ID localpart for this account
|
-- The Matrix user ID localpart for this account
|
||||||
localpart TEXT NOT NULL PRIMARY KEY,
|
localpart TEXT NOT NULL,
|
||||||
|
server_name TEXT NOT NULL,
|
||||||
-- When this account was first created, as a unix timestamp (ms resolution).
|
-- When this account was first created, as a unix timestamp (ms resolution).
|
||||||
created_ts BIGINT NOT NULL,
|
created_ts BIGINT NOT NULL,
|
||||||
-- The password hash for this account. Can be NULL if this is a passwordless account.
|
-- The password hash for this account. Can be NULL if this is a passwordless account.
|
||||||
|
@ -48,25 +49,27 @@ CREATE TABLE IF NOT EXISTS userapi_accounts (
|
||||||
-- TODO:
|
-- TODO:
|
||||||
-- upgraded_ts, devices, any email reset stuff?
|
-- upgraded_ts, devices, any email reset stuff?
|
||||||
);
|
);
|
||||||
|
|
||||||
|
CREATE UNIQUE INDEX IF NOT EXISTS userapi_accounts_idx ON userapi_accounts(localpart, server_name);
|
||||||
`
|
`
|
||||||
|
|
||||||
const insertAccountSQL = "" +
|
const insertAccountSQL = "" +
|
||||||
"INSERT INTO userapi_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)"
|
"INSERT INTO userapi_accounts(localpart, server_name, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5, $6)"
|
||||||
|
|
||||||
const updatePasswordSQL = "" +
|
const updatePasswordSQL = "" +
|
||||||
"UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2"
|
"UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2 AND server_name = $3"
|
||||||
|
|
||||||
const deactivateAccountSQL = "" +
|
const deactivateAccountSQL = "" +
|
||||||
"UPDATE userapi_accounts SET is_deactivated = 1 WHERE localpart = $1"
|
"UPDATE userapi_accounts SET is_deactivated = 1 WHERE localpart = $1 AND server_name = $2"
|
||||||
|
|
||||||
const selectAccountByLocalpartSQL = "" +
|
const selectAccountByLocalpartSQL = "" +
|
||||||
"SELECT localpart, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1"
|
"SELECT localpart, server_name, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1 AND server_name = $2"
|
||||||
|
|
||||||
const selectPasswordHashSQL = "" +
|
const selectPasswordHashSQL = "" +
|
||||||
"SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND is_deactivated = 0"
|
"SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND server_name = $2 AND is_deactivated = 0"
|
||||||
|
|
||||||
const selectNewNumericLocalpartSQL = "" +
|
const selectNewNumericLocalpartSQL = "" +
|
||||||
"SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM userapi_accounts WHERE CAST(localpart AS INT) <> 0"
|
"SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM userapi_accounts WHERE CAST(localpart AS INT) <> 0 AND server_name = $1"
|
||||||
|
|
||||||
type accountsStatements struct {
|
type accountsStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
@ -119,16 +122,17 @@ func NewSQLiteAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName)
|
||||||
// this account will be passwordless. Returns an error if this account already exists. Returns the account
|
// this account will be passwordless. Returns an error if this account already exists. Returns the account
|
||||||
// on success.
|
// on success.
|
||||||
func (s *accountsStatements) InsertAccount(
|
func (s *accountsStatements) InsertAccount(
|
||||||
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType,
|
ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
hash, appserviceID string, accountType api.AccountType,
|
||||||
) (*api.Account, error) {
|
) (*api.Account, error) {
|
||||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||||
stmt := s.insertAccountStmt
|
stmt := s.insertAccountStmt
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
if accountType != api.AccountTypeAppService {
|
if accountType != api.AccountTypeAppService {
|
||||||
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType)
|
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, serverName, createdTimeMS, hash, nil, accountType)
|
||||||
} else {
|
} else {
|
||||||
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType)
|
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, serverName, createdTimeMS, hash, appserviceID, accountType)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -136,42 +140,43 @@ func (s *accountsStatements) InsertAccount(
|
||||||
|
|
||||||
return &api.Account{
|
return &api.Account{
|
||||||
Localpart: localpart,
|
Localpart: localpart,
|
||||||
UserID: userutil.MakeUserID(localpart, s.serverName),
|
UserID: userutil.MakeUserID(localpart, serverName),
|
||||||
ServerName: s.serverName,
|
ServerName: serverName,
|
||||||
AppServiceID: appserviceID,
|
AppServiceID: appserviceID,
|
||||||
AccountType: accountType,
|
AccountType: accountType,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountsStatements) UpdatePassword(
|
func (s *accountsStatements) UpdatePassword(
|
||||||
ctx context.Context, localpart, passwordHash string,
|
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
passwordHash string,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart)
|
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart, serverName)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountsStatements) DeactivateAccount(
|
func (s *accountsStatements) DeactivateAccount(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
_, err = s.deactivateAccountStmt.ExecContext(ctx, localpart)
|
_, err = s.deactivateAccountStmt.ExecContext(ctx, localpart, serverName)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountsStatements) SelectPasswordHash(
|
func (s *accountsStatements) SelectPasswordHash(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) (hash string, err error) {
|
) (hash string, err error) {
|
||||||
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash)
|
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart, serverName).Scan(&hash)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountsStatements) SelectAccountByLocalpart(
|
func (s *accountsStatements) SelectAccountByLocalpart(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) (*api.Account, error) {
|
) (*api.Account, error) {
|
||||||
var appserviceIDPtr sql.NullString
|
var appserviceIDPtr sql.NullString
|
||||||
var acc api.Account
|
var acc api.Account
|
||||||
|
|
||||||
stmt := s.selectAccountByLocalpartStmt
|
stmt := s.selectAccountByLocalpartStmt
|
||||||
err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &acc.AccountType)
|
err := stmt.QueryRowContext(ctx, localpart, serverName).Scan(&acc.Localpart, &acc.ServerName, &appserviceIDPtr, &acc.AccountType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err != sql.ErrNoRows {
|
if err != sql.ErrNoRows {
|
||||||
log.WithError(err).Error("Unable to retrieve user from the db")
|
log.WithError(err).Error("Unable to retrieve user from the db")
|
||||||
|
@ -182,20 +187,18 @@ func (s *accountsStatements) SelectAccountByLocalpart(
|
||||||
acc.AppServiceID = appserviceIDPtr.String
|
acc.AppServiceID = appserviceIDPtr.String
|
||||||
}
|
}
|
||||||
|
|
||||||
acc.UserID = userutil.MakeUserID(localpart, s.serverName)
|
acc.UserID = userutil.MakeUserID(acc.Localpart, acc.ServerName)
|
||||||
acc.ServerName = s.serverName
|
|
||||||
|
|
||||||
return &acc, nil
|
return &acc, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountsStatements) SelectNewNumericLocalpart(
|
func (s *accountsStatements) SelectNewNumericLocalpart(
|
||||||
ctx context.Context, txn *sql.Tx,
|
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
|
||||||
) (id int64, err error) {
|
) (id int64, err error) {
|
||||||
stmt := s.selectNewNumericLocalpartStmt
|
stmt := s.selectNewNumericLocalpartStmt
|
||||||
if txn != nil {
|
if txn != nil {
|
||||||
stmt = sqlutil.TxStmt(txn, stmt)
|
stmt = sqlutil.TxStmt(txn, stmt)
|
||||||
}
|
}
|
||||||
err = stmt.QueryRowContext(ctx).Scan(&id)
|
err = stmt.QueryRowContext(ctx, serverName).Scan(&id)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return 1, nil
|
return 1, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,6 +11,7 @@ func UpIsActive(ctx context.Context, tx *sql.Tx) error {
|
||||||
ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp;
|
ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp;
|
||||||
CREATE TABLE userapi_accounts (
|
CREATE TABLE userapi_accounts (
|
||||||
localpart TEXT NOT NULL PRIMARY KEY,
|
localpart TEXT NOT NULL PRIMARY KEY,
|
||||||
|
server_name TEXT NOT NULL,
|
||||||
created_ts BIGINT NOT NULL,
|
created_ts BIGINT NOT NULL,
|
||||||
password_hash TEXT,
|
password_hash TEXT,
|
||||||
appservice_id TEXT,
|
appservice_id TEXT,
|
||||||
|
|
|
@ -14,6 +14,7 @@ func UpLastSeenTSIP(ctx context.Context, tx *sql.Tx) error {
|
||||||
session_id INTEGER,
|
session_id INTEGER,
|
||||||
device_id TEXT ,
|
device_id TEXT ,
|
||||||
localpart TEXT ,
|
localpart TEXT ,
|
||||||
|
server_name TEXT NOT NULL,
|
||||||
created_ts BIGINT,
|
created_ts BIGINT,
|
||||||
display_name TEXT,
|
display_name TEXT,
|
||||||
last_seen_ts BIGINT,
|
last_seen_ts BIGINT,
|
||||||
|
|
|
@ -12,6 +12,7 @@ func UpAddAccountType(ctx context.Context, tx *sql.Tx) error {
|
||||||
_, err := tx.ExecContext(ctx, `ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp;
|
_, err := tx.ExecContext(ctx, `ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp;
|
||||||
CREATE TABLE userapi_accounts (
|
CREATE TABLE userapi_accounts (
|
||||||
localpart TEXT NOT NULL PRIMARY KEY,
|
localpart TEXT NOT NULL PRIMARY KEY,
|
||||||
|
server_name TEXT NOT NULL,
|
||||||
created_ts BIGINT NOT NULL,
|
created_ts BIGINT NOT NULL,
|
||||||
password_hash TEXT,
|
password_hash TEXT,
|
||||||
appservice_id TEXT,
|
appservice_id TEXT,
|
||||||
|
|
108
userapi/storage/sqlite3/deltas/2022110411000000_server_names.go
Normal file
108
userapi/storage/sqlite3/deltas/2022110411000000_server_names.go
Normal file
|
@ -0,0 +1,108 @@
|
||||||
|
package deltas
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/lib/pq"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
var serverNamesTables = []string{
|
||||||
|
"userapi_accounts",
|
||||||
|
"userapi_account_datas",
|
||||||
|
"userapi_devices",
|
||||||
|
"userapi_notifications",
|
||||||
|
"userapi_openid_tokens",
|
||||||
|
"userapi_profiles",
|
||||||
|
"userapi_pushers",
|
||||||
|
"userapi_threepids",
|
||||||
|
}
|
||||||
|
|
||||||
|
// These tables have a PRIMARY KEY constraint which we need to drop so
|
||||||
|
// that we can recreate a new unique index that contains the server name.
|
||||||
|
var serverNamesDropPK = []string{
|
||||||
|
"userapi_accounts",
|
||||||
|
"userapi_account_datas",
|
||||||
|
"userapi_profiles",
|
||||||
|
}
|
||||||
|
|
||||||
|
// These indices are out of date so let's drop them. They will get recreated
|
||||||
|
// automatically.
|
||||||
|
var serverNamesDropIndex = []string{
|
||||||
|
"userapi_pusher_localpart_idx",
|
||||||
|
"userapi_pusher_app_id_pushkey_localpart_idx",
|
||||||
|
}
|
||||||
|
|
||||||
|
// I know what you're thinking: you're wondering "why doesn't this use $1
|
||||||
|
// and pass variadic parameters to ExecContext?" — the answer is because
|
||||||
|
// PostgreSQL doesn't expect the table name to be specified as a substituted
|
||||||
|
// argument in that way so it results in a syntax error in the query.
|
||||||
|
|
||||||
|
func UpServerNames(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error {
|
||||||
|
for _, table := range serverNamesTables {
|
||||||
|
q := fmt.Sprintf(
|
||||||
|
"SELECT COUNT(name) FROM sqlite_schema WHERE type='table' AND name=%s;",
|
||||||
|
pq.QuoteIdentifier(table),
|
||||||
|
)
|
||||||
|
var c int
|
||||||
|
if err := tx.QueryRowContext(ctx, q).Scan(&c); err != nil || c == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
q = fmt.Sprintf(
|
||||||
|
"SELECT COUNT(*) FROM pragma_table_info(%s) WHERE name='server_name'",
|
||||||
|
pq.QuoteIdentifier(table),
|
||||||
|
)
|
||||||
|
if err := tx.QueryRowContext(ctx, q).Scan(&c); err != nil || c == 1 {
|
||||||
|
logrus.Infof("Table %s already has column, skipping", table)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if c == 0 {
|
||||||
|
q = fmt.Sprintf(
|
||||||
|
"ALTER TABLE %s ADD COLUMN server_name TEXT NOT NULL DEFAULT '';",
|
||||||
|
pq.QuoteIdentifier(table),
|
||||||
|
)
|
||||||
|
if _, err := tx.ExecContext(ctx, q); err != nil {
|
||||||
|
return fmt.Errorf("add server name to %q error: %w", table, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, table := range serverNamesDropPK {
|
||||||
|
q := fmt.Sprintf(
|
||||||
|
"SELECT COUNT(name), sql FROM sqlite_schema WHERE type='table' AND name=%s;",
|
||||||
|
pq.QuoteIdentifier(table),
|
||||||
|
)
|
||||||
|
var c int
|
||||||
|
var sql string
|
||||||
|
if err := tx.QueryRowContext(ctx, q).Scan(&c, &sql); err != nil || c == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
q = fmt.Sprintf(`
|
||||||
|
%s; -- create temporary table
|
||||||
|
INSERT INTO %s SELECT * FROM %s; -- copy data
|
||||||
|
DROP TABLE %s; -- drop original table
|
||||||
|
ALTER TABLE %s RENAME TO %s; -- rename new table
|
||||||
|
`,
|
||||||
|
strings.Replace(sql, table, table+"_tmp", 1), // create temporary table
|
||||||
|
table+"_tmp", table, // copy data
|
||||||
|
table, // drop original table
|
||||||
|
table+"_tmp", table, // rename new table
|
||||||
|
)
|
||||||
|
if _, err := tx.ExecContext(ctx, q); err != nil {
|
||||||
|
return fmt.Errorf("drop PK from %q error: %w", table, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, index := range serverNamesDropIndex {
|
||||||
|
q := fmt.Sprintf(
|
||||||
|
"DROP INDEX IF EXISTS %s;",
|
||||||
|
pq.QuoteIdentifier(index),
|
||||||
|
)
|
||||||
|
if _, err := tx.ExecContext(ctx, q); err != nil {
|
||||||
|
return fmt.Errorf("drop index %q error: %w", index, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -0,0 +1,28 @@
|
||||||
|
package deltas
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/lib/pq"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
)
|
||||||
|
|
||||||
|
// I know what you're thinking: you're wondering "why doesn't this use $1
|
||||||
|
// and pass variadic parameters to ExecContext?" — the answer is because
|
||||||
|
// PostgreSQL doesn't expect the table name to be specified as a substituted
|
||||||
|
// argument in that way so it results in a syntax error in the query.
|
||||||
|
|
||||||
|
func UpServerNamesPopulate(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error {
|
||||||
|
for _, table := range serverNamesTables {
|
||||||
|
q := fmt.Sprintf(
|
||||||
|
"UPDATE %s SET server_name = %s WHERE server_name = '';",
|
||||||
|
pq.QuoteIdentifier(table), pq.QuoteLiteral(string(serverName)),
|
||||||
|
)
|
||||||
|
if _, err := tx.ExecContext(ctx, q); err != nil {
|
||||||
|
return fmt.Errorf("write server names to %q error: %w", table, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -40,49 +40,50 @@ CREATE TABLE IF NOT EXISTS userapi_devices (
|
||||||
session_id INTEGER,
|
session_id INTEGER,
|
||||||
device_id TEXT ,
|
device_id TEXT ,
|
||||||
localpart TEXT ,
|
localpart TEXT ,
|
||||||
|
server_name TEXT NOT NULL,
|
||||||
created_ts BIGINT,
|
created_ts BIGINT,
|
||||||
display_name TEXT,
|
display_name TEXT,
|
||||||
last_seen_ts BIGINT,
|
last_seen_ts BIGINT,
|
||||||
ip TEXT,
|
ip TEXT,
|
||||||
user_agent TEXT,
|
user_agent TEXT,
|
||||||
|
|
||||||
UNIQUE (localpart, device_id)
|
UNIQUE (localpart, server_name, device_id)
|
||||||
);
|
);
|
||||||
`
|
`
|
||||||
|
|
||||||
const insertDeviceSQL = "" +
|
const insertDeviceSQL = "" +
|
||||||
"INSERT INTO userapi_devices (device_id, localpart, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" +
|
"INSERT INTO userapi_devices (device_id, localpart, server_name, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" +
|
||||||
" VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"
|
" VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)"
|
||||||
|
|
||||||
const selectDevicesCountSQL = "" +
|
const selectDevicesCountSQL = "" +
|
||||||
"SELECT COUNT(access_token) FROM userapi_devices"
|
"SELECT COUNT(access_token) FROM userapi_devices"
|
||||||
|
|
||||||
const selectDeviceByTokenSQL = "" +
|
const selectDeviceByTokenSQL = "" +
|
||||||
"SELECT session_id, device_id, localpart FROM userapi_devices WHERE access_token = $1"
|
"SELECT session_id, device_id, localpart, server_name FROM userapi_devices WHERE access_token = $1"
|
||||||
|
|
||||||
const selectDeviceByIDSQL = "" +
|
const selectDeviceByIDSQL = "" +
|
||||||
"SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 and device_id = $2"
|
"SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = $3"
|
||||||
|
|
||||||
const selectDevicesByLocalpartSQL = "" +
|
const selectDevicesByLocalpartSQL = "" +
|
||||||
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC"
|
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3 ORDER BY last_seen_ts DESC"
|
||||||
|
|
||||||
const updateDeviceNameSQL = "" +
|
const updateDeviceNameSQL = "" +
|
||||||
"UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
|
"UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND server_name = $3 AND device_id = $4"
|
||||||
|
|
||||||
const deleteDeviceSQL = "" +
|
const deleteDeviceSQL = "" +
|
||||||
"DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2"
|
"DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2 AND server_name = $3"
|
||||||
|
|
||||||
const deleteDevicesByLocalpartSQL = "" +
|
const deleteDevicesByLocalpartSQL = "" +
|
||||||
"DELETE FROM userapi_devices WHERE localpart = $1 AND device_id != $2"
|
"DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3"
|
||||||
|
|
||||||
const deleteDevicesSQL = "" +
|
const deleteDevicesSQL = "" +
|
||||||
"DELETE FROM userapi_devices WHERE localpart = $1 AND device_id IN ($2)"
|
"DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id IN ($3)"
|
||||||
|
|
||||||
const selectDevicesByIDSQL = "" +
|
const selectDevicesByIDSQL = "" +
|
||||||
"SELECT device_id, localpart, display_name, last_seen_ts FROM userapi_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC"
|
"SELECT device_id, localpart, server_name, display_name, last_seen_ts FROM userapi_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC"
|
||||||
|
|
||||||
const updateDeviceLastSeen = "" +
|
const updateDeviceLastSeen = "" +
|
||||||
"UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5"
|
"UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND server_name = $5 AND device_id = $6"
|
||||||
|
|
||||||
type devicesStatements struct {
|
type devicesStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
@ -135,8 +136,9 @@ func NewSQLiteDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName)
|
||||||
// Returns an error if the user already has a device with the given device ID.
|
// Returns an error if the user already has a device with the given device ID.
|
||||||
// Returns the device on success.
|
// Returns the device on success.
|
||||||
func (s *devicesStatements) InsertDevice(
|
func (s *devicesStatements) InsertDevice(
|
||||||
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
|
ctx context.Context, txn *sql.Tx, id string,
|
||||||
displayName *string, ipAddr, userAgent string,
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
accessToken string, displayName *string, ipAddr, userAgent string,
|
||||||
) (*api.Device, error) {
|
) (*api.Device, error) {
|
||||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||||
var sessionID int64
|
var sessionID int64
|
||||||
|
@ -146,12 +148,12 @@ func (s *devicesStatements) InsertDevice(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
sessionID++
|
sessionID++
|
||||||
if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil {
|
if _, err := insertStmt.ExecContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &api.Device{
|
return &api.Device{
|
||||||
ID: id,
|
ID: id,
|
||||||
UserID: userutil.MakeUserID(localpart, s.serverName),
|
UserID: userutil.MakeUserID(localpart, serverName),
|
||||||
AccessToken: accessToken,
|
AccessToken: accessToken,
|
||||||
SessionID: sessionID,
|
SessionID: sessionID,
|
||||||
LastSeenTS: createdTimeMS,
|
LastSeenTS: createdTimeMS,
|
||||||
|
@ -161,44 +163,52 @@ func (s *devicesStatements) InsertDevice(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) DeleteDevice(
|
func (s *devicesStatements) DeleteDevice(
|
||||||
ctx context.Context, txn *sql.Tx, id, localpart string,
|
ctx context.Context, txn *sql.Tx, id string,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) error {
|
) error {
|
||||||
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
|
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
|
||||||
_, err := stmt.ExecContext(ctx, id, localpart)
|
_, err := stmt.ExecContext(ctx, id, localpart, serverName)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) DeleteDevices(
|
func (s *devicesStatements) DeleteDevices(
|
||||||
ctx context.Context, txn *sql.Tx, localpart string, devices []string,
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
devices []string,
|
||||||
) error {
|
) error {
|
||||||
orig := strings.Replace(deleteDevicesSQL, "($2)", sqlutil.QueryVariadicOffset(len(devices), 1), 1)
|
orig := strings.Replace(deleteDevicesSQL, "($3)", sqlutil.QueryVariadicOffset(len(devices), 2), 1)
|
||||||
prep, err := s.db.Prepare(orig)
|
prep, err := s.db.Prepare(orig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
stmt := sqlutil.TxStmt(txn, prep)
|
stmt := sqlutil.TxStmt(txn, prep)
|
||||||
params := make([]interface{}, len(devices)+1)
|
params := make([]interface{}, len(devices)+2)
|
||||||
params[0] = localpart
|
params[0] = localpart
|
||||||
|
params[1] = serverName
|
||||||
for i, v := range devices {
|
for i, v := range devices {
|
||||||
params[i+1] = v
|
params[i+2] = v
|
||||||
}
|
}
|
||||||
_, err = stmt.ExecContext(ctx, params...)
|
_, err = stmt.ExecContext(ctx, params...)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) DeleteDevicesByLocalpart(
|
func (s *devicesStatements) DeleteDevicesByLocalpart(
|
||||||
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
exceptDeviceID string,
|
||||||
) error {
|
) error {
|
||||||
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
|
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
|
||||||
_, err := stmt.ExecContext(ctx, localpart, exceptDeviceID)
|
_, err := stmt.ExecContext(ctx, localpart, serverName, exceptDeviceID)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) UpdateDeviceName(
|
func (s *devicesStatements) UpdateDeviceName(
|
||||||
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
deviceID string, displayName *string,
|
||||||
) error {
|
) error {
|
||||||
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
|
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
|
||||||
_, err := stmt.ExecContext(ctx, displayName, localpart, deviceID)
|
_, err := stmt.ExecContext(ctx, displayName, localpart, serverName, deviceID)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -207,10 +217,11 @@ func (s *devicesStatements) SelectDeviceByToken(
|
||||||
) (*api.Device, error) {
|
) (*api.Device, error) {
|
||||||
var dev api.Device
|
var dev api.Device
|
||||||
var localpart string
|
var localpart string
|
||||||
|
var serverName gomatrixserverlib.ServerName
|
||||||
stmt := s.selectDeviceByTokenStmt
|
stmt := s.selectDeviceByTokenStmt
|
||||||
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart)
|
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart, &serverName)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
dev.UserID = userutil.MakeUserID(localpart, serverName)
|
||||||
dev.AccessToken = accessToken
|
dev.AccessToken = accessToken
|
||||||
}
|
}
|
||||||
return &dev, err
|
return &dev, err
|
||||||
|
@ -219,16 +230,18 @@ func (s *devicesStatements) SelectDeviceByToken(
|
||||||
// selectDeviceByID retrieves a device from the database with the given user
|
// selectDeviceByID retrieves a device from the database with the given user
|
||||||
// localpart and deviceID
|
// localpart and deviceID
|
||||||
func (s *devicesStatements) SelectDeviceByID(
|
func (s *devicesStatements) SelectDeviceByID(
|
||||||
ctx context.Context, localpart, deviceID string,
|
ctx context.Context,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
deviceID string,
|
||||||
) (*api.Device, error) {
|
) (*api.Device, error) {
|
||||||
var dev api.Device
|
var dev api.Device
|
||||||
var displayName, ip sql.NullString
|
var displayName, ip sql.NullString
|
||||||
stmt := s.selectDeviceByIDStmt
|
stmt := s.selectDeviceByIDStmt
|
||||||
var lastseenTS sql.NullInt64
|
var lastseenTS sql.NullInt64
|
||||||
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName, &lastseenTS, &ip)
|
err := stmt.QueryRowContext(ctx, localpart, serverName, deviceID).Scan(&displayName, &lastseenTS, &ip)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
dev.ID = deviceID
|
dev.ID = deviceID
|
||||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
dev.UserID = userutil.MakeUserID(localpart, serverName)
|
||||||
if displayName.Valid {
|
if displayName.Valid {
|
||||||
dev.DisplayName = displayName.String
|
dev.DisplayName = displayName.String
|
||||||
}
|
}
|
||||||
|
@ -243,10 +256,12 @@ func (s *devicesStatements) SelectDeviceByID(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) SelectDevicesByLocalpart(
|
func (s *devicesStatements) SelectDevicesByLocalpart(
|
||||||
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
exceptDeviceID string,
|
||||||
) ([]api.Device, error) {
|
) ([]api.Device, error) {
|
||||||
devices := []api.Device{}
|
devices := []api.Device{}
|
||||||
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID)
|
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, serverName, exceptDeviceID)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return devices, err
|
return devices, err
|
||||||
|
@ -276,7 +291,7 @@ func (s *devicesStatements) SelectDevicesByLocalpart(
|
||||||
dev.UserAgent = useragent.String
|
dev.UserAgent = useragent.String
|
||||||
}
|
}
|
||||||
|
|
||||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
dev.UserID = userutil.MakeUserID(localpart, serverName)
|
||||||
devices = append(devices, dev)
|
devices = append(devices, dev)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -298,10 +313,11 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s
|
||||||
var devices []api.Device
|
var devices []api.Device
|
||||||
var dev api.Device
|
var dev api.Device
|
||||||
var localpart string
|
var localpart string
|
||||||
|
var serverName gomatrixserverlib.ServerName
|
||||||
var displayName sql.NullString
|
var displayName sql.NullString
|
||||||
var lastseents sql.NullInt64
|
var lastseents sql.NullInt64
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
if err := rows.Scan(&dev.ID, &localpart, &displayName, &lastseents); err != nil {
|
if err := rows.Scan(&dev.ID, &localpart, &serverName, &displayName, &lastseents); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if displayName.Valid {
|
if displayName.Valid {
|
||||||
|
@ -310,15 +326,15 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s
|
||||||
if lastseents.Valid {
|
if lastseents.Valid {
|
||||||
dev.LastSeenTS = lastseents.Int64
|
dev.LastSeenTS = lastseents.Int64
|
||||||
}
|
}
|
||||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
dev.UserID = userutil.MakeUserID(localpart, serverName)
|
||||||
devices = append(devices, dev)
|
devices = append(devices, dev)
|
||||||
}
|
}
|
||||||
return devices, rows.Err()
|
return devices, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr, userAgent string) error {
|
func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error {
|
||||||
lastSeenTs := time.Now().UnixNano() / 1000000
|
lastSeenTs := time.Now().UnixNano() / 1000000
|
||||||
stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt)
|
stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt)
|
||||||
_, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, deviceID)
|
_, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, serverName, deviceID)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -43,6 +43,7 @@ const notificationSchema = `
|
||||||
CREATE TABLE IF NOT EXISTS userapi_notifications (
|
CREATE TABLE IF NOT EXISTS userapi_notifications (
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
localpart TEXT NOT NULL,
|
localpart TEXT NOT NULL,
|
||||||
|
server_name TEXT NOT NULL,
|
||||||
room_id TEXT NOT NULL,
|
room_id TEXT NOT NULL,
|
||||||
event_id TEXT NOT NULL,
|
event_id TEXT NOT NULL,
|
||||||
stream_pos BIGINT NOT NULL,
|
stream_pos BIGINT NOT NULL,
|
||||||
|
@ -52,33 +53,33 @@ CREATE TABLE IF NOT EXISTS userapi_notifications (
|
||||||
read BOOLEAN NOT NULL DEFAULT FALSE
|
read BOOLEAN NOT NULL DEFAULT FALSE
|
||||||
);
|
);
|
||||||
|
|
||||||
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, room_id, event_id);
|
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, server_name, room_id, event_id);
|
||||||
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, room_id, id);
|
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, server_name, room_id, id);
|
||||||
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, id);
|
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, server_name, id);
|
||||||
`
|
`
|
||||||
|
|
||||||
const insertNotificationSQL = "" +
|
const insertNotificationSQL = "" +
|
||||||
"INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7)"
|
"INSERT INTO userapi_notifications (localpart, server_name, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
|
||||||
|
|
||||||
const deleteNotificationsUpToSQL = "" +
|
const deleteNotificationsUpToSQL = "" +
|
||||||
"DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND stream_pos <= $3"
|
"DELETE FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND stream_pos <= $4"
|
||||||
|
|
||||||
const updateNotificationReadSQL = "" +
|
const updateNotificationReadSQL = "" +
|
||||||
"UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND stream_pos <= $4 AND read <> $1"
|
"UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND server_name = $3 AND room_id = $4 AND stream_pos <= $5 AND read <> $1"
|
||||||
|
|
||||||
const selectNotificationSQL = "" +
|
const selectNotificationSQL = "" +
|
||||||
"SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" +
|
"SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND id > $3 AND (" +
|
||||||
"(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" +
|
"(($4 & 1) <> 0 AND highlight) OR (($4 & 2) <> 0 AND NOT highlight)" +
|
||||||
") AND NOT read ORDER BY localpart, id LIMIT $4"
|
") AND NOT read ORDER BY localpart, id LIMIT $5"
|
||||||
|
|
||||||
const selectNotificationCountSQL = "" +
|
const selectNotificationCountSQL = "" +
|
||||||
"SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND (" +
|
"SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND (" +
|
||||||
"(($2 & 1) <> 0 AND highlight) OR (($2 & 2) <> 0 AND NOT highlight)" +
|
"(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" +
|
||||||
") AND NOT read"
|
") AND NOT read"
|
||||||
|
|
||||||
const selectRoomNotificationCountsSQL = "" +
|
const selectRoomNotificationCountsSQL = "" +
|
||||||
"SELECT COUNT(*), COUNT(*) FILTER (WHERE highlight) FROM userapi_notifications " +
|
"SELECT COUNT(*), COUNT(*) FILTER (WHERE highlight) FROM userapi_notifications " +
|
||||||
"WHERE localpart = $1 AND room_id = $2 AND NOT read"
|
"WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND NOT read"
|
||||||
|
|
||||||
const cleanNotificationsSQL = "" +
|
const cleanNotificationsSQL = "" +
|
||||||
"DELETE FROM userapi_notifications WHERE" +
|
"DELETE FROM userapi_notifications WHERE" +
|
||||||
|
@ -111,7 +112,7 @@ func (s *notificationsStatements) Clean(ctx context.Context, txn *sql.Tx) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// Insert inserts a notification into the database.
|
// Insert inserts a notification into the database.
|
||||||
func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error {
|
func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, highlight bool, n *api.Notification) error {
|
||||||
roomID, tsMS := n.RoomID, n.TS
|
roomID, tsMS := n.RoomID, n.TS
|
||||||
nn := *n
|
nn := *n
|
||||||
// Clears out fields that have their own columns to (1) shrink the
|
// Clears out fields that have their own columns to (1) shrink the
|
||||||
|
@ -122,13 +123,13 @@ func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, local
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, pos, tsMS, highlight, string(bs))
|
_, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, serverName, roomID, eventID, pos, tsMS, highlight, string(bs))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteUpTo deletes all previous notifications, up to and including the event.
|
// DeleteUpTo deletes all previous notifications, up to and including the event.
|
||||||
func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error) {
|
func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, _ error) {
|
||||||
res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos)
|
res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, serverName, roomID, pos)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
@ -141,8 +142,8 @@ func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, l
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateRead updates the "read" value for an event.
|
// UpdateRead updates the "read" value for an event.
|
||||||
func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error) {
|
func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, v bool) (affected bool, _ error) {
|
||||||
res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos)
|
res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, serverName, roomID, pos)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
@ -154,8 +155,8 @@ func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, l
|
||||||
return nrows > 0, nil
|
return nrows > 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
|
func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
|
||||||
rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, fromID, uint32(filter), limit)
|
rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, serverName, fromID, uint32(filter), limit)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
|
@ -197,12 +198,12 @@ func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, local
|
||||||
return notifs, maxID, rows.Err()
|
return notifs, maxID, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (count int64, err error) {
|
func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, filter tables.NotificationFilter) (count int64, err error) {
|
||||||
err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, uint32(filter)).Scan(&count)
|
err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, serverName, uint32(filter)).Scan(&count)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, err error) {
|
func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, err error) {
|
||||||
err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, roomID).Scan(&total, &highlight)
|
err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, serverName, roomID).Scan(&total, &highlight)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,6 +3,7 @@ package sqlite3
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
@ -18,16 +19,17 @@ CREATE TABLE IF NOT EXISTS userapi_openid_tokens (
|
||||||
token TEXT NOT NULL PRIMARY KEY,
|
token TEXT NOT NULL PRIMARY KEY,
|
||||||
-- The Matrix user ID for this account
|
-- The Matrix user ID for this account
|
||||||
localpart TEXT NOT NULL,
|
localpart TEXT NOT NULL,
|
||||||
|
server_name TEXT NOT NULL,
|
||||||
-- When the token expires, as a unix timestamp (ms resolution).
|
-- When the token expires, as a unix timestamp (ms resolution).
|
||||||
token_expires_at_ms BIGINT NOT NULL
|
token_expires_at_ms BIGINT NOT NULL
|
||||||
);
|
);
|
||||||
`
|
`
|
||||||
|
|
||||||
const insertOpenIDTokenSQL = "" +
|
const insertOpenIDTokenSQL = "" +
|
||||||
"INSERT INTO userapi_openid_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
|
"INSERT INTO userapi_openid_tokens(token, localpart, server_name, token_expires_at_ms) VALUES ($1, $2, $3, $4)"
|
||||||
|
|
||||||
const selectOpenIDTokenSQL = "" +
|
const selectOpenIDTokenSQL = "" +
|
||||||
"SELECT localpart, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1"
|
"SELECT localpart, server_name, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1"
|
||||||
|
|
||||||
type openIDTokenStatements struct {
|
type openIDTokenStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
@ -56,11 +58,11 @@ func NewSQLiteOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (
|
||||||
func (s *openIDTokenStatements) InsertOpenIDToken(
|
func (s *openIDTokenStatements) InsertOpenIDToken(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
txn *sql.Tx,
|
txn *sql.Tx,
|
||||||
token, localpart string,
|
token, localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
expiresAtMS int64,
|
expiresAtMS int64,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
stmt := sqlutil.TxStmt(txn, s.insertTokenStmt)
|
stmt := sqlutil.TxStmt(txn, s.insertTokenStmt)
|
||||||
_, err = stmt.ExecContext(ctx, token, localpart, expiresAtMS)
|
_, err = stmt.ExecContext(ctx, token, localpart, serverName, expiresAtMS)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -71,10 +73,13 @@ func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes(
|
||||||
token string,
|
token string,
|
||||||
) (*api.OpenIDTokenAttributes, error) {
|
) (*api.OpenIDTokenAttributes, error) {
|
||||||
var openIDTokenAttrs api.OpenIDTokenAttributes
|
var openIDTokenAttrs api.OpenIDTokenAttributes
|
||||||
|
var localpart string
|
||||||
|
var serverName gomatrixserverlib.ServerName
|
||||||
err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan(
|
err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan(
|
||||||
&openIDTokenAttrs.UserID,
|
&localpart, &serverName,
|
||||||
&openIDTokenAttrs.ExpiresAtMS,
|
&openIDTokenAttrs.ExpiresAtMS,
|
||||||
)
|
)
|
||||||
|
openIDTokenAttrs.UserID = fmt.Sprintf("@%s:%s", localpart, serverName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err != sql.ErrNoRows {
|
if err != sql.ErrNoRows {
|
||||||
log.WithError(err).Error("Unable to retrieve token from the db")
|
log.WithError(err).Error("Unable to retrieve token from the db")
|
||||||
|
|
|
@ -23,36 +23,40 @@ import (
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
const profilesSchema = `
|
const profilesSchema = `
|
||||||
-- Stores data about accounts profiles.
|
-- Stores data about accounts profiles.
|
||||||
CREATE TABLE IF NOT EXISTS userapi_profiles (
|
CREATE TABLE IF NOT EXISTS userapi_profiles (
|
||||||
-- The Matrix user ID localpart for this account
|
-- The Matrix user ID localpart for this account
|
||||||
localpart TEXT NOT NULL PRIMARY KEY,
|
localpart TEXT NOT NULL,
|
||||||
|
server_name TEXT NOT NULL,
|
||||||
-- The display name for this account
|
-- The display name for this account
|
||||||
display_name TEXT,
|
display_name TEXT,
|
||||||
-- The URL of the avatar for this account
|
-- The URL of the avatar for this account
|
||||||
avatar_url TEXT
|
avatar_url TEXT
|
||||||
);
|
);
|
||||||
|
|
||||||
|
CREATE UNIQUE INDEX IF NOT EXISTS userapi_profiles_idx ON userapi_profiles(localpart, server_name);
|
||||||
`
|
`
|
||||||
|
|
||||||
const insertProfileSQL = "" +
|
const insertProfileSQL = "" +
|
||||||
"INSERT INTO userapi_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)"
|
"INSERT INTO userapi_profiles(localpart, server_name, display_name, avatar_url) VALUES ($1, $2, $3, $4)"
|
||||||
|
|
||||||
const selectProfileByLocalpartSQL = "" +
|
const selectProfileByLocalpartSQL = "" +
|
||||||
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1"
|
"SELECT localpart, server_name, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1 AND server_name = $2"
|
||||||
|
|
||||||
const setAvatarURLSQL = "" +
|
const setAvatarURLSQL = "" +
|
||||||
"UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2" +
|
"UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2 AND server_name = $3" +
|
||||||
" RETURNING display_name"
|
" RETURNING display_name"
|
||||||
|
|
||||||
const setDisplayNameSQL = "" +
|
const setDisplayNameSQL = "" +
|
||||||
"UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2" +
|
"UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2 AND server_name = $3" +
|
||||||
" RETURNING avatar_url"
|
" RETURNING avatar_url"
|
||||||
|
|
||||||
const selectProfilesBySearchSQL = "" +
|
const selectProfilesBySearchSQL = "" +
|
||||||
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
|
"SELECT localpart, server_name, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
|
||||||
|
|
||||||
type profilesStatements struct {
|
type profilesStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
@ -83,18 +87,20 @@ func NewSQLiteProfilesTable(db *sql.DB, serverNoticesLocalpart string) (tables.P
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *profilesStatements) InsertProfile(
|
func (s *profilesStatements) InsertProfile(
|
||||||
ctx context.Context, txn *sql.Tx, localpart string,
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) error {
|
) error {
|
||||||
_, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
|
_, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, serverName, "", "")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *profilesStatements) SelectProfileByLocalpart(
|
func (s *profilesStatements) SelectProfileByLocalpart(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) (*authtypes.Profile, error) {
|
) (*authtypes.Profile, error) {
|
||||||
var profile authtypes.Profile
|
var profile authtypes.Profile
|
||||||
err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart).Scan(
|
err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart, serverName).Scan(
|
||||||
&profile.Localpart, &profile.DisplayName, &profile.AvatarURL,
|
&profile.Localpart, &profile.ServerName, &profile.DisplayName, &profile.AvatarURL,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -103,13 +109,16 @@ func (s *profilesStatements) SelectProfileByLocalpart(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *profilesStatements) SetAvatarURL(
|
func (s *profilesStatements) SetAvatarURL(
|
||||||
ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
avatarURL string,
|
||||||
) (*authtypes.Profile, bool, error) {
|
) (*authtypes.Profile, bool, error) {
|
||||||
profile := &authtypes.Profile{
|
profile := &authtypes.Profile{
|
||||||
Localpart: localpart,
|
Localpart: localpart,
|
||||||
AvatarURL: avatarURL,
|
ServerName: string(serverName),
|
||||||
|
AvatarURL: avatarURL,
|
||||||
}
|
}
|
||||||
old, err := s.SelectProfileByLocalpart(ctx, localpart)
|
old, err := s.SelectProfileByLocalpart(ctx, localpart, serverName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return old, false, err
|
return old, false, err
|
||||||
}
|
}
|
||||||
|
@ -117,18 +126,21 @@ func (s *profilesStatements) SetAvatarURL(
|
||||||
return old, false, nil
|
return old, false, nil
|
||||||
}
|
}
|
||||||
stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
|
stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
|
||||||
err = stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName)
|
err = stmt.QueryRowContext(ctx, avatarURL, localpart, serverName).Scan(&profile.DisplayName)
|
||||||
return profile, true, err
|
return profile, true, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *profilesStatements) SetDisplayName(
|
func (s *profilesStatements) SetDisplayName(
|
||||||
ctx context.Context, txn *sql.Tx, localpart string, displayName string,
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
|
displayName string,
|
||||||
) (*authtypes.Profile, bool, error) {
|
) (*authtypes.Profile, bool, error) {
|
||||||
profile := &authtypes.Profile{
|
profile := &authtypes.Profile{
|
||||||
Localpart: localpart,
|
Localpart: localpart,
|
||||||
|
ServerName: string(serverName),
|
||||||
DisplayName: displayName,
|
DisplayName: displayName,
|
||||||
}
|
}
|
||||||
old, err := s.SelectProfileByLocalpart(ctx, localpart)
|
old, err := s.SelectProfileByLocalpart(ctx, localpart, serverName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return old, false, err
|
return old, false, err
|
||||||
}
|
}
|
||||||
|
@ -136,7 +148,7 @@ func (s *profilesStatements) SetDisplayName(
|
||||||
return old, false, nil
|
return old, false, nil
|
||||||
}
|
}
|
||||||
stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
|
stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
|
||||||
err = stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL)
|
err = stmt.QueryRowContext(ctx, displayName, localpart, serverName).Scan(&profile.AvatarURL)
|
||||||
return profile, true, err
|
return profile, true, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -154,7 +166,7 @@ func (s *profilesStatements) SelectProfilesBySearch(
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "selectProfilesBySearch: rows.close() failed")
|
defer internal.CloseAndLogIfError(ctx, rows, "selectProfilesBySearch: rows.close() failed")
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var profile authtypes.Profile
|
var profile authtypes.Profile
|
||||||
if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL); err != nil {
|
if err := rows.Scan(&profile.Localpart, &profile.ServerName, &profile.DisplayName, &profile.AvatarURL); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if profile.Localpart != s.serverNoticesLocalpart {
|
if profile.Localpart != s.serverNoticesLocalpart {
|
||||||
|
|
|
@ -25,6 +25,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
// See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers
|
// See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers
|
||||||
|
@ -33,6 +34,7 @@ CREATE TABLE IF NOT EXISTS userapi_pushers (
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
-- The Matrix user ID localpart for this pusher
|
-- The Matrix user ID localpart for this pusher
|
||||||
localpart TEXT NOT NULL,
|
localpart TEXT NOT NULL,
|
||||||
|
server_name TEXT NOT NULL,
|
||||||
session_id BIGINT DEFAULT NULL,
|
session_id BIGINT DEFAULT NULL,
|
||||||
profile_tag TEXT,
|
profile_tag TEXT,
|
||||||
kind TEXT NOT NULL,
|
kind TEXT NOT NULL,
|
||||||
|
@ -49,22 +51,22 @@ CREATE TABLE IF NOT EXISTS userapi_pushers (
|
||||||
CREATE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey);
|
CREATE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey);
|
||||||
|
|
||||||
-- For faster retrieving by localpart.
|
-- For faster retrieving by localpart.
|
||||||
CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart);
|
CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart, server_name);
|
||||||
|
|
||||||
-- Pushkey must be unique for a given user and app.
|
-- Pushkey must be unique for a given user and app.
|
||||||
CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart);
|
CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart, server_name);
|
||||||
`
|
`
|
||||||
|
|
||||||
const insertPusherSQL = "" +
|
const insertPusherSQL = "" +
|
||||||
"INSERT INTO userapi_pushers (localpart, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" +
|
"INSERT INTO userapi_pushers (localpart, server_name, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" +
|
||||||
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" +
|
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)" +
|
||||||
"ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11"
|
"ON CONFLICT (app_id, pushkey, localpart, server_name) DO UPDATE SET session_id = $3, pushkey_ts_ms = $5, kind = $6, app_display_name = $8, device_display_name = $9, profile_tag = $10, lang = $11, data = $12"
|
||||||
|
|
||||||
const selectPushersSQL = "" +
|
const selectPushersSQL = "" +
|
||||||
"SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1"
|
"SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1 AND server_name = $2"
|
||||||
|
|
||||||
const deletePusherSQL = "" +
|
const deletePusherSQL = "" +
|
||||||
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3"
|
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3 AND server_name = $4"
|
||||||
|
|
||||||
const deletePushersByAppIdAndPushKeySQL = "" +
|
const deletePushersByAppIdAndPushKeySQL = "" +
|
||||||
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2"
|
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2"
|
||||||
|
@ -95,18 +97,19 @@ type pushersStatements struct {
|
||||||
// Returns nil error success.
|
// Returns nil error success.
|
||||||
func (s *pushersStatements) InsertPusher(
|
func (s *pushersStatements) InsertPusher(
|
||||||
ctx context.Context, txn *sql.Tx, session_id int64,
|
ctx context.Context, txn *sql.Tx, session_id int64,
|
||||||
pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string,
|
pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) error {
|
) error {
|
||||||
_, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
|
_, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, serverName, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
|
||||||
logrus.Debugf("Created pusher %d", session_id)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *pushersStatements) SelectPushers(
|
func (s *pushersStatements) SelectPushers(
|
||||||
ctx context.Context, txn *sql.Tx, localpart string,
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) ([]api.Pusher, error) {
|
) ([]api.Pusher, error) {
|
||||||
pushers := []api.Pusher{}
|
pushers := []api.Pusher{}
|
||||||
rows, err := s.selectPushersStmt.QueryContext(ctx, localpart)
|
rows, err := s.selectPushersStmt.QueryContext(ctx, localpart, serverName)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return pushers, err
|
return pushers, err
|
||||||
|
@ -143,9 +146,10 @@ func (s *pushersStatements) SelectPushers(
|
||||||
|
|
||||||
// deletePusher removes a single pusher by pushkey and user localpart.
|
// deletePusher removes a single pusher by pushkey and user localpart.
|
||||||
func (s *pushersStatements) DeletePusher(
|
func (s *pushersStatements) DeletePusher(
|
||||||
ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string,
|
ctx context.Context, txn *sql.Tx, appid, pushkey,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) error {
|
) error {
|
||||||
_, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart)
|
_, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart, serverName)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -15,6 +15,8 @@
|
||||||
package sqlite3
|
package sqlite3
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -41,18 +43,24 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
|
||||||
Up: deltas.UpRenameTables,
|
Up: deltas.UpRenameTables,
|
||||||
Down: deltas.DownRenameTables,
|
Down: deltas.DownRenameTables,
|
||||||
})
|
})
|
||||||
|
m.AddMigrations(sqlutil.Migration{
|
||||||
|
Version: "userapi: server names",
|
||||||
|
Up: func(ctx context.Context, txn *sql.Tx) error {
|
||||||
|
return deltas.UpServerNames(ctx, txn, serverName)
|
||||||
|
},
|
||||||
|
})
|
||||||
if err = m.Up(base.Context()); err != nil {
|
if err = m.Up(base.Context()); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
accountDataTable, err := NewSQLiteAccountDataTable(db)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("NewSQLiteAccountDataTable: %w", err)
|
|
||||||
}
|
|
||||||
accountsTable, err := NewSQLiteAccountsTable(db, serverName)
|
accountsTable, err := NewSQLiteAccountsTable(db, serverName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("NewSQLiteAccountsTable: %w", err)
|
return nil, fmt.Errorf("NewSQLiteAccountsTable: %w", err)
|
||||||
}
|
}
|
||||||
|
accountDataTable, err := NewSQLiteAccountDataTable(db)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("NewSQLiteAccountDataTable: %w", err)
|
||||||
|
}
|
||||||
devicesTable, err := NewSQLiteDevicesTable(db, serverName)
|
devicesTable, err := NewSQLiteDevicesTable(db, serverName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("NewSQLiteDevicesTable: %w", err)
|
return nil, fmt.Errorf("NewSQLiteDevicesTable: %w", err)
|
||||||
|
@ -93,6 +101,18 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("NewSQLiteStatsTable: %w", err)
|
return nil, fmt.Errorf("NewSQLiteStatsTable: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m = sqlutil.NewMigrator(db)
|
||||||
|
m.AddMigrations(sqlutil.Migration{
|
||||||
|
Version: "userapi: server names populate",
|
||||||
|
Up: func(ctx context.Context, txn *sql.Tx) error {
|
||||||
|
return deltas.UpServerNamesPopulate(ctx, txn, serverName)
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err = m.Up(base.Context()); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
return &shared.Database{
|
return &shared.Database{
|
||||||
AccountDatas: accountDataTable,
|
AccountDatas: accountDataTable,
|
||||||
Accounts: accountsTable,
|
Accounts: accountsTable,
|
||||||
|
|
|
@ -21,6 +21,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
)
|
)
|
||||||
|
@ -34,21 +35,22 @@ CREATE TABLE IF NOT EXISTS userapi_threepids (
|
||||||
medium TEXT NOT NULL DEFAULT 'email',
|
medium TEXT NOT NULL DEFAULT 'email',
|
||||||
-- The localpart of the Matrix user ID associated to this 3PID
|
-- The localpart of the Matrix user ID associated to this 3PID
|
||||||
localpart TEXT NOT NULL,
|
localpart TEXT NOT NULL,
|
||||||
|
server_name TEXT NOT NULL,
|
||||||
|
|
||||||
PRIMARY KEY(threepid, medium)
|
PRIMARY KEY(threepid, medium)
|
||||||
);
|
);
|
||||||
|
|
||||||
CREATE INDEX IF NOT EXISTS account_threepid_localpart ON userapi_threepids(localpart);
|
CREATE INDEX IF NOT EXISTS account_threepid_localpart ON userapi_threepids(localpart, server_name);
|
||||||
`
|
`
|
||||||
|
|
||||||
const selectLocalpartForThreePIDSQL = "" +
|
const selectLocalpartForThreePIDSQL = "" +
|
||||||
"SELECT localpart FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
|
"SELECT localpart, server_name FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
|
||||||
|
|
||||||
const selectThreePIDsForLocalpartSQL = "" +
|
const selectThreePIDsForLocalpartSQL = "" +
|
||||||
"SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1"
|
"SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1 AND server_name = $2"
|
||||||
|
|
||||||
const insertThreePIDSQL = "" +
|
const insertThreePIDSQL = "" +
|
||||||
"INSERT INTO userapi_threepids (threepid, medium, localpart) VALUES ($1, $2, $3)"
|
"INSERT INTO userapi_threepids (threepid, medium, localpart, server_name) VALUES ($1, $2, $3, $4)"
|
||||||
|
|
||||||
const deleteThreePIDSQL = "" +
|
const deleteThreePIDSQL = "" +
|
||||||
"DELETE FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
|
"DELETE FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
|
||||||
|
@ -79,19 +81,20 @@ func NewSQLiteThreePIDTable(db *sql.DB) (tables.ThreePIDTable, error) {
|
||||||
|
|
||||||
func (s *threepidStatements) SelectLocalpartForThreePID(
|
func (s *threepidStatements) SelectLocalpartForThreePID(
|
||||||
ctx context.Context, txn *sql.Tx, threepid string, medium string,
|
ctx context.Context, txn *sql.Tx, threepid string, medium string,
|
||||||
) (localpart string, err error) {
|
) (localpart string, serverName gomatrixserverlib.ServerName, err error) {
|
||||||
stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt)
|
stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt)
|
||||||
err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart)
|
err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart, &serverName)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return "", nil
|
return "", "", nil
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *threepidStatements) SelectThreePIDsForLocalpart(
|
func (s *threepidStatements) SelectThreePIDsForLocalpart(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) (threepids []authtypes.ThreePID, err error) {
|
) (threepids []authtypes.ThreePID, err error) {
|
||||||
rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart)
|
rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart, serverName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -113,10 +116,11 @@ func (s *threepidStatements) SelectThreePIDsForLocalpart(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *threepidStatements) InsertThreePID(
|
func (s *threepidStatements) InsertThreePID(
|
||||||
ctx context.Context, txn *sql.Tx, threepid, medium, localpart string,
|
ctx context.Context, txn *sql.Tx, threepid, medium,
|
||||||
|
localpart string, serverName gomatrixserverlib.ServerName,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt)
|
stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt)
|
||||||
_, err = stmt.ExecContext(ctx, threepid, medium, localpart)
|
_, err = stmt.ExecContext(ctx, threepid, medium, localpart, serverName)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -50,25 +50,25 @@ func Test_AccountData(t *testing.T) {
|
||||||
db, close := mustCreateDatabase(t, dbType)
|
db, close := mustCreateDatabase(t, dbType)
|
||||||
defer close()
|
defer close()
|
||||||
alice := test.NewUser(t)
|
alice := test.NewUser(t)
|
||||||
localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
localpart, domain, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
room := test.NewRoom(t, alice)
|
room := test.NewRoom(t, alice)
|
||||||
events := room.Events()
|
events := room.Events()
|
||||||
|
|
||||||
contentRoom := json.RawMessage(fmt.Sprintf(`{"event_id":"%s"}`, events[len(events)-1].EventID()))
|
contentRoom := json.RawMessage(fmt.Sprintf(`{"event_id":"%s"}`, events[len(events)-1].EventID()))
|
||||||
err = db.SaveAccountData(ctx, localpart, room.ID, "m.fully_read", contentRoom)
|
err = db.SaveAccountData(ctx, localpart, domain, room.ID, "m.fully_read", contentRoom)
|
||||||
assert.NoError(t, err, "unable to save account data")
|
assert.NoError(t, err, "unable to save account data")
|
||||||
|
|
||||||
contentGlobal := json.RawMessage(fmt.Sprintf(`{"recent_rooms":["%s"]}`, room.ID))
|
contentGlobal := json.RawMessage(fmt.Sprintf(`{"recent_rooms":["%s"]}`, room.ID))
|
||||||
err = db.SaveAccountData(ctx, localpart, "", "im.vector.setting.breadcrumbs", contentGlobal)
|
err = db.SaveAccountData(ctx, localpart, domain, "", "im.vector.setting.breadcrumbs", contentGlobal)
|
||||||
assert.NoError(t, err, "unable to save account data")
|
assert.NoError(t, err, "unable to save account data")
|
||||||
|
|
||||||
accountData, err := db.GetAccountDataByType(ctx, localpart, room.ID, "m.fully_read")
|
accountData, err := db.GetAccountDataByType(ctx, localpart, domain, room.ID, "m.fully_read")
|
||||||
assert.NoError(t, err, "unable to get account data by type")
|
assert.NoError(t, err, "unable to get account data by type")
|
||||||
assert.Equal(t, contentRoom, accountData)
|
assert.Equal(t, contentRoom, accountData)
|
||||||
|
|
||||||
globalData, roomData, err := db.GetAccountData(ctx, localpart)
|
globalData, roomData, err := db.GetAccountData(ctx, localpart, domain)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, contentRoom, roomData[room.ID]["m.fully_read"])
|
assert.Equal(t, contentRoom, roomData[room.ID]["m.fully_read"])
|
||||||
assert.Equal(t, contentGlobal, globalData["im.vector.setting.breadcrumbs"])
|
assert.Equal(t, contentGlobal, globalData["im.vector.setting.breadcrumbs"])
|
||||||
|
@ -81,78 +81,78 @@ func Test_Accounts(t *testing.T) {
|
||||||
db, close := mustCreateDatabase(t, dbType)
|
db, close := mustCreateDatabase(t, dbType)
|
||||||
defer close()
|
defer close()
|
||||||
alice := test.NewUser(t)
|
alice := test.NewUser(t)
|
||||||
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
accAlice, err := db.CreateAccount(ctx, aliceLocalpart, "testing", "", api.AccountTypeAdmin)
|
accAlice, err := db.CreateAccount(ctx, aliceLocalpart, aliceDomain, "testing", "", api.AccountTypeAdmin)
|
||||||
assert.NoError(t, err, "failed to create account")
|
assert.NoError(t, err, "failed to create account")
|
||||||
// verify the newly create account is the same as returned by CreateAccount
|
// verify the newly create account is the same as returned by CreateAccount
|
||||||
var accGet *api.Account
|
var accGet *api.Account
|
||||||
accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, "testing")
|
accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, aliceDomain, "testing")
|
||||||
assert.NoError(t, err, "failed to get account by password")
|
assert.NoError(t, err, "failed to get account by password")
|
||||||
assert.Equal(t, accAlice, accGet)
|
assert.Equal(t, accAlice, accGet)
|
||||||
accGet, err = db.GetAccountByLocalpart(ctx, aliceLocalpart)
|
accGet, err = db.GetAccountByLocalpart(ctx, aliceLocalpart, aliceDomain)
|
||||||
assert.NoError(t, err, "failed to get account by localpart")
|
assert.NoError(t, err, "failed to get account by localpart")
|
||||||
assert.Equal(t, accAlice, accGet)
|
assert.Equal(t, accAlice, accGet)
|
||||||
|
|
||||||
// check account availability
|
// check account availability
|
||||||
available, err := db.CheckAccountAvailability(ctx, aliceLocalpart)
|
available, err := db.CheckAccountAvailability(ctx, aliceLocalpart, aliceDomain)
|
||||||
assert.NoError(t, err, "failed to checkout account availability")
|
assert.NoError(t, err, "failed to checkout account availability")
|
||||||
assert.Equal(t, false, available)
|
assert.Equal(t, false, available)
|
||||||
|
|
||||||
available, err = db.CheckAccountAvailability(ctx, "unusedname")
|
available, err = db.CheckAccountAvailability(ctx, "unusedname", aliceDomain)
|
||||||
assert.NoError(t, err, "failed to checkout account availability")
|
assert.NoError(t, err, "failed to checkout account availability")
|
||||||
assert.Equal(t, true, available)
|
assert.Equal(t, true, available)
|
||||||
|
|
||||||
// get guest account numeric aliceLocalpart
|
// get guest account numeric aliceLocalpart
|
||||||
first, err := db.GetNewNumericLocalpart(ctx)
|
first, err := db.GetNewNumericLocalpart(ctx, aliceDomain)
|
||||||
assert.NoError(t, err, "failed to get new numeric localpart")
|
assert.NoError(t, err, "failed to get new numeric localpart")
|
||||||
// Create a new account to verify the numeric localpart is updated
|
// Create a new account to verify the numeric localpart is updated
|
||||||
_, err = db.CreateAccount(ctx, "", "testing", "", api.AccountTypeGuest)
|
_, err = db.CreateAccount(ctx, "", aliceDomain, "testing", "", api.AccountTypeGuest)
|
||||||
assert.NoError(t, err, "failed to create account")
|
assert.NoError(t, err, "failed to create account")
|
||||||
second, err := db.GetNewNumericLocalpart(ctx)
|
second, err := db.GetNewNumericLocalpart(ctx, aliceDomain)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Greater(t, second, first)
|
assert.Greater(t, second, first)
|
||||||
|
|
||||||
// update password for alice
|
// update password for alice
|
||||||
err = db.SetPassword(ctx, aliceLocalpart, "newPassword")
|
err = db.SetPassword(ctx, aliceLocalpart, aliceDomain, "newPassword")
|
||||||
assert.NoError(t, err, "failed to update password")
|
assert.NoError(t, err, "failed to update password")
|
||||||
accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, "newPassword")
|
accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, aliceDomain, "newPassword")
|
||||||
assert.NoError(t, err, "failed to get account by new password")
|
assert.NoError(t, err, "failed to get account by new password")
|
||||||
assert.Equal(t, accAlice, accGet)
|
assert.Equal(t, accAlice, accGet)
|
||||||
|
|
||||||
// deactivate account
|
// deactivate account
|
||||||
err = db.DeactivateAccount(ctx, aliceLocalpart)
|
err = db.DeactivateAccount(ctx, aliceLocalpart, aliceDomain)
|
||||||
assert.NoError(t, err, "failed to deactivate account")
|
assert.NoError(t, err, "failed to deactivate account")
|
||||||
// This should fail now, as the account is deactivated
|
// This should fail now, as the account is deactivated
|
||||||
_, err = db.GetAccountByPassword(ctx, aliceLocalpart, "newPassword")
|
_, err = db.GetAccountByPassword(ctx, aliceLocalpart, aliceDomain, "newPassword")
|
||||||
assert.Error(t, err, "expected an error, got none")
|
assert.Error(t, err, "expected an error, got none")
|
||||||
|
|
||||||
_, err = db.GetAccountByLocalpart(ctx, "unusename")
|
_, err = db.GetAccountByLocalpart(ctx, "unusename", aliceDomain)
|
||||||
assert.Error(t, err, "expected an error for non existent localpart")
|
assert.Error(t, err, "expected an error for non existent localpart")
|
||||||
|
|
||||||
// create an empty localpart; this should never happen, but is required to test getting a numeric localpart
|
// create an empty localpart; this should never happen, but is required to test getting a numeric localpart
|
||||||
// if there's already a user without a localpart in the database
|
// if there's already a user without a localpart in the database
|
||||||
_, err = db.CreateAccount(ctx, "", "", "", api.AccountTypeUser)
|
_, err = db.CreateAccount(ctx, "", aliceDomain, "", "", api.AccountTypeUser)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
// test getting a numeric localpart, with an existing user without a localpart
|
// test getting a numeric localpart, with an existing user without a localpart
|
||||||
_, err = db.CreateAccount(ctx, "", "", "", api.AccountTypeGuest)
|
_, err = db.CreateAccount(ctx, "", aliceDomain, "", "", api.AccountTypeGuest)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
// Create a user with a high numeric localpart, out of range for the Postgres integer (2147483647) type
|
// Create a user with a high numeric localpart, out of range for the Postgres integer (2147483647) type
|
||||||
_, err = db.CreateAccount(ctx, "2147483650", "", "", api.AccountTypeUser)
|
_, err = db.CreateAccount(ctx, "2147483650", aliceDomain, "", "", api.AccountTypeUser)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
// Now try to create a new guest user
|
// Now try to create a new guest user
|
||||||
_, err = db.CreateAccount(ctx, "", "", "", api.AccountTypeGuest)
|
_, err = db.CreateAccount(ctx, "", aliceDomain, "", "", api.AccountTypeGuest)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_Devices(t *testing.T) {
|
func Test_Devices(t *testing.T) {
|
||||||
alice := test.NewUser(t)
|
alice := test.NewUser(t)
|
||||||
localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
localpart, domain, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
deviceID := util.RandomString(8)
|
deviceID := util.RandomString(8)
|
||||||
accessToken := util.RandomString(16)
|
accessToken := util.RandomString(16)
|
||||||
|
@ -161,10 +161,10 @@ func Test_Devices(t *testing.T) {
|
||||||
db, close := mustCreateDatabase(t, dbType)
|
db, close := mustCreateDatabase(t, dbType)
|
||||||
defer close()
|
defer close()
|
||||||
|
|
||||||
deviceWithID, err := db.CreateDevice(ctx, localpart, &deviceID, accessToken, nil, "", "")
|
deviceWithID, err := db.CreateDevice(ctx, localpart, domain, &deviceID, accessToken, nil, "", "")
|
||||||
assert.NoError(t, err, "unable to create deviceWithoutID")
|
assert.NoError(t, err, "unable to create deviceWithoutID")
|
||||||
|
|
||||||
gotDevice, err := db.GetDeviceByID(ctx, localpart, deviceID)
|
gotDevice, err := db.GetDeviceByID(ctx, localpart, domain, deviceID)
|
||||||
assert.NoError(t, err, "unable to get device by id")
|
assert.NoError(t, err, "unable to get device by id")
|
||||||
assert.Equal(t, deviceWithID.ID, gotDevice.ID) // GetDeviceByID doesn't populate all fields
|
assert.Equal(t, deviceWithID.ID, gotDevice.ID) // GetDeviceByID doesn't populate all fields
|
||||||
|
|
||||||
|
@ -174,14 +174,14 @@ func Test_Devices(t *testing.T) {
|
||||||
|
|
||||||
// create a device without existing device ID
|
// create a device without existing device ID
|
||||||
accessToken = util.RandomString(16)
|
accessToken = util.RandomString(16)
|
||||||
deviceWithoutID, err := db.CreateDevice(ctx, localpart, nil, accessToken, nil, "", "")
|
deviceWithoutID, err := db.CreateDevice(ctx, localpart, domain, nil, accessToken, nil, "", "")
|
||||||
assert.NoError(t, err, "unable to create deviceWithoutID")
|
assert.NoError(t, err, "unable to create deviceWithoutID")
|
||||||
gotDeviceWithoutID, err := db.GetDeviceByID(ctx, localpart, deviceWithoutID.ID)
|
gotDeviceWithoutID, err := db.GetDeviceByID(ctx, localpart, domain, deviceWithoutID.ID)
|
||||||
assert.NoError(t, err, "unable to get device by id")
|
assert.NoError(t, err, "unable to get device by id")
|
||||||
assert.Equal(t, deviceWithoutID.ID, gotDeviceWithoutID.ID) // GetDeviceByID doesn't populate all fields
|
assert.Equal(t, deviceWithoutID.ID, gotDeviceWithoutID.ID) // GetDeviceByID doesn't populate all fields
|
||||||
|
|
||||||
// Get devices
|
// Get devices
|
||||||
devices, err := db.GetDevicesByLocalpart(ctx, localpart)
|
devices, err := db.GetDevicesByLocalpart(ctx, localpart, domain)
|
||||||
assert.NoError(t, err, "unable to get devices by localpart")
|
assert.NoError(t, err, "unable to get devices by localpart")
|
||||||
assert.Equal(t, 2, len(devices))
|
assert.Equal(t, 2, len(devices))
|
||||||
deviceIDs := make([]string, 0, len(devices))
|
deviceIDs := make([]string, 0, len(devices))
|
||||||
|
@ -195,15 +195,15 @@ func Test_Devices(t *testing.T) {
|
||||||
|
|
||||||
// Update device
|
// Update device
|
||||||
newName := "new display name"
|
newName := "new display name"
|
||||||
err = db.UpdateDevice(ctx, localpart, deviceWithID.ID, &newName)
|
err = db.UpdateDevice(ctx, localpart, domain, deviceWithID.ID, &newName)
|
||||||
assert.NoError(t, err, "unable to update device displayname")
|
assert.NoError(t, err, "unable to update device displayname")
|
||||||
updatedAfterTimestamp := time.Now().Unix()
|
updatedAfterTimestamp := time.Now().Unix()
|
||||||
err = db.UpdateDeviceLastSeen(ctx, localpart, deviceWithID.ID, "127.0.0.1", "Element Web")
|
err = db.UpdateDeviceLastSeen(ctx, localpart, domain, deviceWithID.ID, "127.0.0.1", "Element Web")
|
||||||
assert.NoError(t, err, "unable to update device last seen")
|
assert.NoError(t, err, "unable to update device last seen")
|
||||||
|
|
||||||
deviceWithID.DisplayName = newName
|
deviceWithID.DisplayName = newName
|
||||||
deviceWithID.LastSeenIP = "127.0.0.1"
|
deviceWithID.LastSeenIP = "127.0.0.1"
|
||||||
gotDevice, err = db.GetDeviceByID(ctx, localpart, deviceWithID.ID)
|
gotDevice, err = db.GetDeviceByID(ctx, localpart, domain, deviceWithID.ID)
|
||||||
assert.NoError(t, err, "unable to get device by id")
|
assert.NoError(t, err, "unable to get device by id")
|
||||||
assert.Equal(t, 2, len(devices))
|
assert.Equal(t, 2, len(devices))
|
||||||
assert.Equal(t, deviceWithID.DisplayName, gotDevice.DisplayName)
|
assert.Equal(t, deviceWithID.DisplayName, gotDevice.DisplayName)
|
||||||
|
@ -213,20 +213,20 @@ func Test_Devices(t *testing.T) {
|
||||||
// create one more device and remove the devices step by step
|
// create one more device and remove the devices step by step
|
||||||
newDeviceID := util.RandomString(16)
|
newDeviceID := util.RandomString(16)
|
||||||
accessToken = util.RandomString(16)
|
accessToken = util.RandomString(16)
|
||||||
_, err = db.CreateDevice(ctx, localpart, &newDeviceID, accessToken, nil, "", "")
|
_, err = db.CreateDevice(ctx, localpart, domain, &newDeviceID, accessToken, nil, "", "")
|
||||||
assert.NoError(t, err, "unable to create new device")
|
assert.NoError(t, err, "unable to create new device")
|
||||||
|
|
||||||
devices, err = db.GetDevicesByLocalpart(ctx, localpart)
|
devices, err = db.GetDevicesByLocalpart(ctx, localpart, domain)
|
||||||
assert.NoError(t, err, "unable to get device by id")
|
assert.NoError(t, err, "unable to get device by id")
|
||||||
assert.Equal(t, 3, len(devices))
|
assert.Equal(t, 3, len(devices))
|
||||||
|
|
||||||
err = db.RemoveDevices(ctx, localpart, deviceIDs)
|
err = db.RemoveDevices(ctx, localpart, domain, deviceIDs)
|
||||||
assert.NoError(t, err, "unable to remove devices")
|
assert.NoError(t, err, "unable to remove devices")
|
||||||
devices, err = db.GetDevicesByLocalpart(ctx, localpart)
|
devices, err = db.GetDevicesByLocalpart(ctx, localpart, domain)
|
||||||
assert.NoError(t, err, "unable to get device by id")
|
assert.NoError(t, err, "unable to get device by id")
|
||||||
assert.Equal(t, 1, len(devices))
|
assert.Equal(t, 1, len(devices))
|
||||||
|
|
||||||
deleted, err := db.RemoveAllDevices(ctx, localpart, "")
|
deleted, err := db.RemoveAllDevices(ctx, localpart, domain, "")
|
||||||
assert.NoError(t, err, "unable to remove all devices")
|
assert.NoError(t, err, "unable to remove all devices")
|
||||||
assert.Equal(t, 1, len(deleted))
|
assert.Equal(t, 1, len(deleted))
|
||||||
assert.Equal(t, newDeviceID, deleted[0].ID)
|
assert.Equal(t, newDeviceID, deleted[0].ID)
|
||||||
|
@ -364,7 +364,7 @@ func Test_OpenID(t *testing.T) {
|
||||||
|
|
||||||
func Test_Profile(t *testing.T) {
|
func Test_Profile(t *testing.T) {
|
||||||
alice := test.NewUser(t)
|
alice := test.NewUser(t)
|
||||||
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
@ -372,30 +372,33 @@ func Test_Profile(t *testing.T) {
|
||||||
defer close()
|
defer close()
|
||||||
|
|
||||||
// create account, which also creates a profile
|
// create account, which also creates a profile
|
||||||
_, err = db.CreateAccount(ctx, aliceLocalpart, "testing", "", api.AccountTypeAdmin)
|
_, err = db.CreateAccount(ctx, aliceLocalpart, aliceDomain, "testing", "", api.AccountTypeAdmin)
|
||||||
assert.NoError(t, err, "failed to create account")
|
assert.NoError(t, err, "failed to create account")
|
||||||
|
|
||||||
gotProfile, err := db.GetProfileByLocalpart(ctx, aliceLocalpart)
|
gotProfile, err := db.GetProfileByLocalpart(ctx, aliceLocalpart, aliceDomain)
|
||||||
assert.NoError(t, err, "unable to get profile by localpart")
|
assert.NoError(t, err, "unable to get profile by localpart")
|
||||||
wantProfile := &authtypes.Profile{Localpart: aliceLocalpart}
|
wantProfile := &authtypes.Profile{
|
||||||
|
Localpart: aliceLocalpart,
|
||||||
|
ServerName: string(aliceDomain),
|
||||||
|
}
|
||||||
assert.Equal(t, wantProfile, gotProfile)
|
assert.Equal(t, wantProfile, gotProfile)
|
||||||
|
|
||||||
// set avatar & displayname
|
// set avatar & displayname
|
||||||
wantProfile.DisplayName = "Alice"
|
wantProfile.DisplayName = "Alice"
|
||||||
gotProfile, changed, err := db.SetDisplayName(ctx, aliceLocalpart, "Alice")
|
gotProfile, changed, err := db.SetDisplayName(ctx, aliceLocalpart, aliceDomain, "Alice")
|
||||||
assert.Equal(t, wantProfile, gotProfile)
|
assert.Equal(t, wantProfile, gotProfile)
|
||||||
assert.NoError(t, err, "unable to set displayname")
|
assert.NoError(t, err, "unable to set displayname")
|
||||||
assert.True(t, changed)
|
assert.True(t, changed)
|
||||||
|
|
||||||
wantProfile.AvatarURL = "mxc://aliceAvatar"
|
wantProfile.AvatarURL = "mxc://aliceAvatar"
|
||||||
gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar")
|
gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, aliceDomain, "mxc://aliceAvatar")
|
||||||
assert.NoError(t, err, "unable to set avatar url")
|
assert.NoError(t, err, "unable to set avatar url")
|
||||||
assert.Equal(t, wantProfile, gotProfile)
|
assert.Equal(t, wantProfile, gotProfile)
|
||||||
assert.True(t, changed)
|
assert.True(t, changed)
|
||||||
|
|
||||||
// Setting the same avatar again doesn't change anything
|
// Setting the same avatar again doesn't change anything
|
||||||
wantProfile.AvatarURL = "mxc://aliceAvatar"
|
wantProfile.AvatarURL = "mxc://aliceAvatar"
|
||||||
gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar")
|
gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, aliceDomain, "mxc://aliceAvatar")
|
||||||
assert.NoError(t, err, "unable to set avatar url")
|
assert.NoError(t, err, "unable to set avatar url")
|
||||||
assert.Equal(t, wantProfile, gotProfile)
|
assert.Equal(t, wantProfile, gotProfile)
|
||||||
assert.False(t, changed)
|
assert.False(t, changed)
|
||||||
|
@ -410,7 +413,7 @@ func Test_Profile(t *testing.T) {
|
||||||
|
|
||||||
func Test_Pusher(t *testing.T) {
|
func Test_Pusher(t *testing.T) {
|
||||||
alice := test.NewUser(t)
|
alice := test.NewUser(t)
|
||||||
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
@ -432,11 +435,11 @@ func Test_Pusher(t *testing.T) {
|
||||||
ProfileTag: util.RandomString(8),
|
ProfileTag: util.RandomString(8),
|
||||||
Language: util.RandomString(2),
|
Language: util.RandomString(2),
|
||||||
}
|
}
|
||||||
err = db.UpsertPusher(ctx, wantPusher, aliceLocalpart)
|
err = db.UpsertPusher(ctx, wantPusher, aliceLocalpart, aliceDomain)
|
||||||
assert.NoError(t, err, "unable to upsert pusher")
|
assert.NoError(t, err, "unable to upsert pusher")
|
||||||
|
|
||||||
// check it was actually persisted
|
// check it was actually persisted
|
||||||
gotPushers, err = db.GetPushers(ctx, aliceLocalpart)
|
gotPushers, err = db.GetPushers(ctx, aliceLocalpart, aliceDomain)
|
||||||
assert.NoError(t, err, "unable to get pushers")
|
assert.NoError(t, err, "unable to get pushers")
|
||||||
assert.Equal(t, i+1, len(gotPushers))
|
assert.Equal(t, i+1, len(gotPushers))
|
||||||
assert.Equal(t, wantPusher, gotPushers[i])
|
assert.Equal(t, wantPusher, gotPushers[i])
|
||||||
|
@ -444,16 +447,16 @@ func Test_Pusher(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// remove single pusher
|
// remove single pusher
|
||||||
err = db.RemovePusher(ctx, appID, pushKeys[0], aliceLocalpart)
|
err = db.RemovePusher(ctx, appID, pushKeys[0], aliceLocalpart, aliceDomain)
|
||||||
assert.NoError(t, err, "unable to remove pusher")
|
assert.NoError(t, err, "unable to remove pusher")
|
||||||
gotPushers, err := db.GetPushers(ctx, aliceLocalpart)
|
gotPushers, err := db.GetPushers(ctx, aliceLocalpart, aliceDomain)
|
||||||
assert.NoError(t, err, "unable to get pushers")
|
assert.NoError(t, err, "unable to get pushers")
|
||||||
assert.Equal(t, 1, len(gotPushers))
|
assert.Equal(t, 1, len(gotPushers))
|
||||||
|
|
||||||
// remove last pusher
|
// remove last pusher
|
||||||
err = db.RemovePushers(ctx, appID, pushKeys[1])
|
err = db.RemovePushers(ctx, appID, pushKeys[1])
|
||||||
assert.NoError(t, err, "unable to remove pusher")
|
assert.NoError(t, err, "unable to remove pusher")
|
||||||
gotPushers, err = db.GetPushers(ctx, aliceLocalpart)
|
gotPushers, err = db.GetPushers(ctx, aliceLocalpart, aliceDomain)
|
||||||
assert.NoError(t, err, "unable to get pushers")
|
assert.NoError(t, err, "unable to get pushers")
|
||||||
assert.Equal(t, 0, len(gotPushers))
|
assert.Equal(t, 0, len(gotPushers))
|
||||||
})
|
})
|
||||||
|
@ -461,7 +464,7 @@ func Test_Pusher(t *testing.T) {
|
||||||
|
|
||||||
func Test_ThreePID(t *testing.T) {
|
func Test_ThreePID(t *testing.T) {
|
||||||
alice := test.NewUser(t)
|
alice := test.NewUser(t)
|
||||||
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
@ -469,15 +472,16 @@ func Test_ThreePID(t *testing.T) {
|
||||||
defer close()
|
defer close()
|
||||||
threePID := util.RandomString(8)
|
threePID := util.RandomString(8)
|
||||||
medium := util.RandomString(8)
|
medium := util.RandomString(8)
|
||||||
err = db.SaveThreePIDAssociation(ctx, threePID, aliceLocalpart, medium)
|
err = db.SaveThreePIDAssociation(ctx, threePID, aliceLocalpart, aliceDomain, medium)
|
||||||
assert.NoError(t, err, "unable to save threepid association")
|
assert.NoError(t, err, "unable to save threepid association")
|
||||||
|
|
||||||
// get the stored threepid
|
// get the stored threepid
|
||||||
gotLocalpart, err := db.GetLocalpartForThreePID(ctx, threePID, medium)
|
gotLocalpart, gotDomain, err := db.GetLocalpartForThreePID(ctx, threePID, medium)
|
||||||
assert.NoError(t, err, "unable to get localpart for threepid")
|
assert.NoError(t, err, "unable to get localpart for threepid")
|
||||||
assert.Equal(t, aliceLocalpart, gotLocalpart)
|
assert.Equal(t, aliceLocalpart, gotLocalpart)
|
||||||
|
assert.Equal(t, aliceDomain, gotDomain)
|
||||||
|
|
||||||
threepids, err := db.GetThreePIDsForLocalpart(ctx, aliceLocalpart)
|
threepids, err := db.GetThreePIDsForLocalpart(ctx, aliceLocalpart, aliceDomain)
|
||||||
assert.NoError(t, err, "unable to get threepids for localpart")
|
assert.NoError(t, err, "unable to get threepids for localpart")
|
||||||
assert.Equal(t, 1, len(threepids))
|
assert.Equal(t, 1, len(threepids))
|
||||||
assert.Equal(t, authtypes.ThreePID{
|
assert.Equal(t, authtypes.ThreePID{
|
||||||
|
@ -490,7 +494,7 @@ func Test_ThreePID(t *testing.T) {
|
||||||
assert.NoError(t, err, "unexpected error")
|
assert.NoError(t, err, "unexpected error")
|
||||||
|
|
||||||
// verify it was deleted
|
// verify it was deleted
|
||||||
threepids, err = db.GetThreePIDsForLocalpart(ctx, aliceLocalpart)
|
threepids, err = db.GetThreePIDsForLocalpart(ctx, aliceLocalpart, aliceDomain)
|
||||||
assert.NoError(t, err, "unable to get threepids for localpart")
|
assert.NoError(t, err, "unable to get threepids for localpart")
|
||||||
assert.Equal(t, 0, len(threepids))
|
assert.Equal(t, 0, len(threepids))
|
||||||
})
|
})
|
||||||
|
@ -498,7 +502,7 @@ func Test_ThreePID(t *testing.T) {
|
||||||
|
|
||||||
func Test_Notification(t *testing.T) {
|
func Test_Notification(t *testing.T) {
|
||||||
alice := test.NewUser(t)
|
alice := test.NewUser(t)
|
||||||
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
|
aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
room := test.NewRoom(t, alice)
|
room := test.NewRoom(t, alice)
|
||||||
room2 := test.NewRoom(t, alice)
|
room2 := test.NewRoom(t, alice)
|
||||||
|
@ -526,34 +530,34 @@ func Test_Notification(t *testing.T) {
|
||||||
RoomID: roomID,
|
RoomID: roomID,
|
||||||
TS: gomatrixserverlib.AsTimestamp(ts),
|
TS: gomatrixserverlib.AsTimestamp(ts),
|
||||||
}
|
}
|
||||||
err = db.InsertNotification(ctx, aliceLocalpart, eventID, uint64(i+1), nil, notification)
|
err = db.InsertNotification(ctx, aliceLocalpart, aliceDomain, eventID, uint64(i+1), nil, notification)
|
||||||
assert.NoError(t, err, "unable to insert notification")
|
assert.NoError(t, err, "unable to insert notification")
|
||||||
}
|
}
|
||||||
|
|
||||||
// get notifications
|
// get notifications
|
||||||
count, err := db.GetNotificationCount(ctx, aliceLocalpart, tables.AllNotifications)
|
count, err := db.GetNotificationCount(ctx, aliceLocalpart, aliceDomain, tables.AllNotifications)
|
||||||
assert.NoError(t, err, "unable to get notification count")
|
assert.NoError(t, err, "unable to get notification count")
|
||||||
assert.Equal(t, int64(10), count)
|
assert.Equal(t, int64(10), count)
|
||||||
notifs, count, err := db.GetNotifications(ctx, aliceLocalpart, 0, 15, tables.AllNotifications)
|
notifs, count, err := db.GetNotifications(ctx, aliceLocalpart, aliceDomain, 0, 15, tables.AllNotifications)
|
||||||
assert.NoError(t, err, "unable to get notifications")
|
assert.NoError(t, err, "unable to get notifications")
|
||||||
assert.Equal(t, int64(10), count)
|
assert.Equal(t, int64(10), count)
|
||||||
assert.Equal(t, 10, len(notifs))
|
assert.Equal(t, 10, len(notifs))
|
||||||
// ... for a specific room
|
// ... for a specific room
|
||||||
total, _, err := db.GetRoomNotificationCounts(ctx, aliceLocalpart, room2.ID)
|
total, _, err := db.GetRoomNotificationCounts(ctx, aliceLocalpart, aliceDomain, room2.ID)
|
||||||
assert.NoError(t, err, "unable to get notifications for room")
|
assert.NoError(t, err, "unable to get notifications for room")
|
||||||
assert.Equal(t, int64(4), total)
|
assert.Equal(t, int64(4), total)
|
||||||
|
|
||||||
// mark notification as read
|
// mark notification as read
|
||||||
affected, err := db.SetNotificationsRead(ctx, aliceLocalpart, room2.ID, 7, true)
|
affected, err := db.SetNotificationsRead(ctx, aliceLocalpart, aliceDomain, room2.ID, 7, true)
|
||||||
assert.NoError(t, err, "unable to set notifications read")
|
assert.NoError(t, err, "unable to set notifications read")
|
||||||
assert.True(t, affected)
|
assert.True(t, affected)
|
||||||
|
|
||||||
// this should delete 2 notifications
|
// this should delete 2 notifications
|
||||||
affected, err = db.DeleteNotificationsUpTo(ctx, aliceLocalpart, room2.ID, 8)
|
affected, err = db.DeleteNotificationsUpTo(ctx, aliceLocalpart, aliceDomain, room2.ID, 8)
|
||||||
assert.NoError(t, err, "unable to set notifications read")
|
assert.NoError(t, err, "unable to set notifications read")
|
||||||
assert.True(t, affected)
|
assert.True(t, affected)
|
||||||
|
|
||||||
total, _, err = db.GetRoomNotificationCounts(ctx, aliceLocalpart, room2.ID)
|
total, _, err = db.GetRoomNotificationCounts(ctx, aliceLocalpart, aliceDomain, room2.ID)
|
||||||
assert.NoError(t, err, "unable to get notifications for room")
|
assert.NoError(t, err, "unable to get notifications for room")
|
||||||
assert.Equal(t, int64(2), total)
|
assert.Equal(t, int64(2), total)
|
||||||
|
|
||||||
|
@ -562,7 +566,7 @@ func Test_Notification(t *testing.T) {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
// this should now return 0 notifications
|
// this should now return 0 notifications
|
||||||
total, _, err = db.GetRoomNotificationCounts(ctx, aliceLocalpart, room2.ID)
|
total, _, err = db.GetRoomNotificationCounts(ctx, aliceLocalpart, aliceDomain, room2.ID)
|
||||||
assert.NoError(t, err, "unable to get notifications for room")
|
assert.NoError(t, err, "unable to get notifications for room")
|
||||||
assert.Equal(t, int64(0), total)
|
assert.Equal(t, int64(0), total)
|
||||||
})
|
})
|
||||||
|
|
|
@ -28,31 +28,31 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type AccountDataTable interface {
|
type AccountDataTable interface {
|
||||||
InsertAccountData(ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage) error
|
InsertAccountData(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID, dataType string, content json.RawMessage) error
|
||||||
SelectAccountData(ctx context.Context, localpart string) (map[string]json.RawMessage, map[string]map[string]json.RawMessage, error)
|
SelectAccountData(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (map[string]json.RawMessage, map[string]map[string]json.RawMessage, error)
|
||||||
SelectAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data json.RawMessage, err error)
|
SelectAccountDataByType(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID, dataType string) (data json.RawMessage, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type AccountsTable interface {
|
type AccountsTable interface {
|
||||||
InsertAccount(ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType) (*api.Account, error)
|
InsertAccount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, hash, appserviceID string, accountType api.AccountType) (*api.Account, error)
|
||||||
UpdatePassword(ctx context.Context, localpart, passwordHash string) (err error)
|
UpdatePassword(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, passwordHash string) (err error)
|
||||||
DeactivateAccount(ctx context.Context, localpart string) (err error)
|
DeactivateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (err error)
|
||||||
SelectPasswordHash(ctx context.Context, localpart string) (hash string, err error)
|
SelectPasswordHash(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (hash string, err error)
|
||||||
SelectAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
|
SelectAccountByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*api.Account, error)
|
||||||
SelectNewNumericLocalpart(ctx context.Context, txn *sql.Tx) (id int64, err error)
|
SelectNewNumericLocalpart(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (id int64, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type DevicesTable interface {
|
type DevicesTable interface {
|
||||||
InsertDevice(ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, displayName *string, ipAddr, userAgent string) (*api.Device, error)
|
InsertDevice(ctx context.Context, txn *sql.Tx, id, localpart string, serverName gomatrixserverlib.ServerName, accessToken string, displayName *string, ipAddr, userAgent string) (*api.Device, error)
|
||||||
DeleteDevice(ctx context.Context, txn *sql.Tx, id, localpart string) error
|
DeleteDevice(ctx context.Context, txn *sql.Tx, id, localpart string, serverName gomatrixserverlib.ServerName) error
|
||||||
DeleteDevices(ctx context.Context, txn *sql.Tx, localpart string, devices []string) error
|
DeleteDevices(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, devices []string) error
|
||||||
DeleteDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string) error
|
DeleteDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, exceptDeviceID string) error
|
||||||
UpdateDeviceName(ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string) error
|
UpdateDeviceName(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID string, displayName *string) error
|
||||||
SelectDeviceByToken(ctx context.Context, accessToken string) (*api.Device, error)
|
SelectDeviceByToken(ctx context.Context, accessToken string) (*api.Device, error)
|
||||||
SelectDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error)
|
SelectDeviceByID(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID string) (*api.Device, error)
|
||||||
SelectDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string) ([]api.Device, error)
|
SelectDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, exceptDeviceID string) ([]api.Device, error)
|
||||||
SelectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error)
|
SelectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error)
|
||||||
UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr, userAgent string) error
|
UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type KeyBackupTable interface {
|
type KeyBackupTable interface {
|
||||||
|
@ -79,40 +79,40 @@ type LoginTokenTable interface {
|
||||||
}
|
}
|
||||||
|
|
||||||
type OpenIDTable interface {
|
type OpenIDTable interface {
|
||||||
InsertOpenIDToken(ctx context.Context, txn *sql.Tx, token, localpart string, expiresAtMS int64) (err error)
|
InsertOpenIDToken(ctx context.Context, txn *sql.Tx, token, localpart string, serverName gomatrixserverlib.ServerName, expiresAtMS int64) (err error)
|
||||||
SelectOpenIDTokenAtrributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error)
|
SelectOpenIDTokenAtrributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type ProfileTable interface {
|
type ProfileTable interface {
|
||||||
InsertProfile(ctx context.Context, txn *sql.Tx, localpart string) error
|
InsertProfile(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName) error
|
||||||
SelectProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
|
SelectProfileByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*authtypes.Profile, error)
|
||||||
SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, avatarURL string) (*authtypes.Profile, bool, error)
|
SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, avatarURL string) (*authtypes.Profile, bool, error)
|
||||||
SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, displayName string) (*authtypes.Profile, bool, error)
|
SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, displayName string) (*authtypes.Profile, bool, error)
|
||||||
SelectProfilesBySearch(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
|
SelectProfilesBySearch(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type ThreePIDTable interface {
|
type ThreePIDTable interface {
|
||||||
SelectLocalpartForThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (localpart string, err error)
|
SelectLocalpartForThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (localpart string, serverName gomatrixserverlib.ServerName, err error)
|
||||||
SelectThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error)
|
SelectThreePIDsForLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (threepids []authtypes.ThreePID, err error)
|
||||||
InsertThreePID(ctx context.Context, txn *sql.Tx, threepid, medium, localpart string) (err error)
|
InsertThreePID(ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, serverName gomatrixserverlib.ServerName) (err error)
|
||||||
DeleteThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error)
|
DeleteThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type PusherTable interface {
|
type PusherTable interface {
|
||||||
InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string) error
|
InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string, serverName gomatrixserverlib.ServerName) error
|
||||||
SelectPushers(ctx context.Context, txn *sql.Tx, localpart string) ([]api.Pusher, error)
|
SelectPushers(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName) ([]api.Pusher, error)
|
||||||
DeletePusher(ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string) error
|
DeletePusher(ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string, serverName gomatrixserverlib.ServerName) error
|
||||||
DeletePushers(ctx context.Context, txn *sql.Tx, appid, pushkey string) error
|
DeletePushers(ctx context.Context, txn *sql.Tx, appid, pushkey string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type NotificationTable interface {
|
type NotificationTable interface {
|
||||||
Clean(ctx context.Context, txn *sql.Tx) error
|
Clean(ctx context.Context, txn *sql.Tx) error
|
||||||
Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error
|
Insert(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, highlight bool, n *api.Notification) error
|
||||||
DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error)
|
DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, _ error)
|
||||||
UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error)
|
UpdateRead(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, v bool) (affected bool, _ error)
|
||||||
Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter NotificationFilter) ([]*api.Notification, int64, error)
|
Select(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter NotificationFilter) ([]*api.Notification, int64, error)
|
||||||
SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter NotificationFilter) (int64, error)
|
SelectCount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, filter NotificationFilter) (int64, error)
|
||||||
SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error)
|
SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, _ error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type StatsTable interface {
|
type StatsTable interface {
|
||||||
|
|
|
@ -79,6 +79,7 @@ func mustMakeAccountAndDevice(
|
||||||
accDB tables.AccountsTable,
|
accDB tables.AccountsTable,
|
||||||
devDB tables.DevicesTable,
|
devDB tables.DevicesTable,
|
||||||
localpart string,
|
localpart string,
|
||||||
|
serverName gomatrixserverlib.ServerName, // nolint:unparam
|
||||||
accType api.AccountType,
|
accType api.AccountType,
|
||||||
userAgent string,
|
userAgent string,
|
||||||
) {
|
) {
|
||||||
|
@ -89,11 +90,11 @@ func mustMakeAccountAndDevice(
|
||||||
appServiceID = util.RandomString(16)
|
appServiceID = util.RandomString(16)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := accDB.InsertAccount(ctx, nil, localpart, "", appServiceID, accType)
|
_, err := accDB.InsertAccount(ctx, nil, localpart, serverName, "", appServiceID, accType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to create account: %v", err)
|
t.Fatalf("unable to create account: %v", err)
|
||||||
}
|
}
|
||||||
_, err = devDB.InsertDevice(ctx, nil, "deviceID", localpart, util.RandomString(16), nil, "", userAgent)
|
_, err = devDB.InsertDevice(ctx, nil, "deviceID", localpart, serverName, util.RandomString(16), nil, "", userAgent)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to create device: %v", err)
|
t.Fatalf("unable to create device: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -150,12 +151,12 @@ func Test_UserStatistics(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("Want Users", func(t *testing.T) {
|
t.Run("Want Users", func(t *testing.T) {
|
||||||
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user1", api.AccountTypeUser, "Element Android")
|
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user1", "localhost", api.AccountTypeUser, "Element Android")
|
||||||
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user2", api.AccountTypeUser, "Element iOS")
|
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user2", "localhost", api.AccountTypeUser, "Element iOS")
|
||||||
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user3", api.AccountTypeUser, "Element web")
|
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user3", "localhost", api.AccountTypeUser, "Element web")
|
||||||
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user4", api.AccountTypeGuest, "Element Electron")
|
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user4", "localhost", api.AccountTypeGuest, "Element Electron")
|
||||||
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user5", api.AccountTypeAdmin, "gecko")
|
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user5", "localhost", api.AccountTypeAdmin, "gecko")
|
||||||
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user6", api.AccountTypeAppService, "gecko")
|
mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user6", "localhost", api.AccountTypeAppService, "gecko")
|
||||||
gotStats, _, err := statsDB.UserStatistics(ctx, nil)
|
gotStats, _, err := statsDB.UserStatistics(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
|
|
@ -80,14 +80,14 @@ func TestQueryProfile(t *testing.T) {
|
||||||
// only one DBType, since userapi.AddInternalRoutes complains about multiple prometheus counters added
|
// only one DBType, since userapi.AddInternalRoutes complains about multiple prometheus counters added
|
||||||
userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, test.DBTypeSQLite)
|
userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, test.DBTypeSQLite)
|
||||||
defer close()
|
defer close()
|
||||||
_, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "", api.AccountTypeUser)
|
_, err := accountDB.CreateAccount(context.TODO(), "alice", serverName, "foobar", "", api.AccountTypeUser)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to make account: %s", err)
|
t.Fatalf("failed to make account: %s", err)
|
||||||
}
|
}
|
||||||
if _, _, err := accountDB.SetAvatarURL(context.TODO(), "alice", aliceAvatarURL); err != nil {
|
if _, _, err := accountDB.SetAvatarURL(context.TODO(), "alice", serverName, aliceAvatarURL); err != nil {
|
||||||
t.Fatalf("failed to set avatar url: %s", err)
|
t.Fatalf("failed to set avatar url: %s", err)
|
||||||
}
|
}
|
||||||
if _, _, err := accountDB.SetDisplayName(context.TODO(), "alice", aliceDisplayName); err != nil {
|
if _, _, err := accountDB.SetDisplayName(context.TODO(), "alice", serverName, aliceDisplayName); err != nil {
|
||||||
t.Fatalf("failed to set display name: %s", err)
|
t.Fatalf("failed to set display name: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -164,7 +164,7 @@ func TestPasswordlessLoginFails(t *testing.T) {
|
||||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType)
|
userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType)
|
||||||
defer close()
|
defer close()
|
||||||
_, err := accountDB.CreateAccount(ctx, "auser", "", "", api.AccountTypeAppService)
|
_, err := accountDB.CreateAccount(ctx, "auser", serverName, "", "", api.AccountTypeAppService)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to make account: %s", err)
|
t.Fatalf("failed to make account: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -190,7 +190,7 @@ func TestLoginToken(t *testing.T) {
|
||||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType)
|
userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType)
|
||||||
defer close()
|
defer close()
|
||||||
_, err := accountDB.CreateAccount(ctx, "auser", "apassword", "", api.AccountTypeUser)
|
_, err := accountDB.CreateAccount(ctx, "auser", serverName, "apassword", "", api.AccountTypeUser)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to make account: %s", err)
|
t.Fatalf("failed to make account: %s", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,10 +2,12 @@ package util
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/pushgateway"
|
"github.com/matrix-org/dendrite/internal/pushgateway"
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage"
|
"github.com/matrix-org/dendrite/userapi/storage"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -17,10 +19,10 @@ type PusherDevice struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPushDevices pushes to the configured devices of a local user.
|
// GetPushDevices pushes to the configured devices of a local user.
|
||||||
func GetPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}, db storage.Database) ([]*PusherDevice, error) {
|
func GetPushDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, tweaks map[string]interface{}, db storage.Database) ([]*PusherDevice, error) {
|
||||||
pushers, err := db.GetPushers(ctx, localpart)
|
pushers, err := db.GetPushers(ctx, localpart, serverName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("db.GetPushers: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
devices := make([]*PusherDevice, 0, len(pushers))
|
devices := make([]*PusherDevice, 0, len(pushers))
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/internal/pushgateway"
|
"github.com/matrix-org/dendrite/internal/pushgateway"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage"
|
"github.com/matrix-org/dendrite/userapi/storage"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -16,8 +17,8 @@ import (
|
||||||
// a single goroutine is started when talking to the Push
|
// a single goroutine is started when talking to the Push
|
||||||
// gateways. There is no way to know when the background goroutine has
|
// gateways. There is no way to know when the background goroutine has
|
||||||
// finished.
|
// finished.
|
||||||
func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, db storage.Database) error {
|
func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, serverName gomatrixserverlib.ServerName, db storage.Database) error {
|
||||||
pusherDevices, err := GetPushDevices(ctx, localpart, nil, db)
|
pusherDevices, err := GetPushDevices(ctx, localpart, serverName, nil, db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -26,7 +27,7 @@ func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, loc
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
userNumUnreadNotifs, err := db.GetNotificationCount(ctx, localpart, tables.AllNotifications)
|
userNumUnreadNotifs, err := db.GetNotificationCount(ctx, localpart, serverName, tables.AllNotifications)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue