Test 7 Pushers created with a the same access token are not deleted on password change

This commit is contained in:
Dan Peleg 2021-05-03 21:33:38 +03:00
parent 4c4cf8020a
commit 763354e371
11 changed files with 335 additions and 71 deletions

View file

@ -13,6 +13,7 @@ import (
"github.com/matrix-org/dendrite/userapi/storage/accounts"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
)
type newPasswordRequest struct {
@ -38,6 +39,11 @@ func Password(
var r newPasswordRequest
r.LogoutDevices = true
logrus.WithFields(logrus.Fields{
"sessionId": device.SessionID,
"userId": device.UserID,
}).Debug("Changing password")
// Unmarshal the request.
resErr := httputil.UnmarshalJSONRequest(req, &r)
if resErr != nil {
@ -107,6 +113,7 @@ func Password(
// If the request asks us to log out all other devices then
// ask the user API to do that.
if r.LogoutDevices {
logrus.Debug("Logging out devices...")
logoutReq := &userapi.PerformDeviceDeletionRequest{
UserID: device.UserID,
DeviceIDs: nil,
@ -117,6 +124,41 @@ func Password(
util.GetLogger(req.Context()).WithError(err).Error("PerformDeviceDeletion failed")
return jsonerror.InternalServerError()
}
pushersReq := &userapi.QueryPushersRequest{
UserID: device.UserID,
}
pushersRes := &userapi.QueryPushersResponse{}
if err := userAPI.QueryPushers(req.Context(), pushersReq, pushersRes); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("QueryPushers failed")
return jsonerror.InternalServerError()
}
var deleted = len(pushersRes.Pushers)
// Deletes all pushers from other devices
for _, pusher := range pushersRes.Pushers {
if pusher.SessionID == device.SessionID {
logrus.Debugf("✅ Skipping pusher %d", pusher.SessionID)
continue
}
if pusher.UserID == device.UserID {
deletionRes := userapi.PerformPusherDeletionResponse{}
if err := userAPI.PerformPusherDeletion(req.Context(), &userapi.PerformPusherDeletionRequest{
AppID: pusher.AppID,
PushKey: pusher.PushKey,
UserID: pusher.UserID,
}, &deletionRes); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("PerformPusherDeletion failed")
return jsonerror.InternalServerError()
}
logrus.Debugf("💥 Successfully deleted pusher %d", pusher.SessionID)
}
}
if err := userAPI.QueryPushers(req.Context(), pushersReq, pushersRes); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("QueryPushers failed")
return jsonerror.InternalServerError()
}
logrus.Debugf("🗑 Deleted %d pushers...", deleted-len(pushersRes.Pushers))
}
// Return a success code.

View file

@ -22,6 +22,7 @@ import (
"github.com/matrix-org/dendrite/userapi/api"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
)
// https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers
@ -31,7 +32,7 @@ type pusherJSON struct {
AppID string `json:"app_id"`
AppDisplayName string `json:"app_display_name"`
DeviceDisplayName string `json:"device_display_name"`
ProfileTag *string `json:"profile_tag"`
ProfileTag string `json:"profile_tag"`
Language string `json:"lang"`
Data pusherDataJSON `json:"data"`
}
@ -67,12 +68,14 @@ func GetPushersByLocalpart(
AppID: pusher.AppID,
AppDisplayName: pusher.AppDisplayName,
DeviceDisplayName: pusher.DeviceDisplayName,
ProfileTag: &pusher.ProfileTag,
ProfileTag: pusher.ProfileTag,
Language: pusher.Language,
Data: pusherDataJSON(pusher.Data),
})
}
logrus.Debugf("😁 HTTP returning %d pushers", len(res.Pushers))
logrus.Debugf("🔮 Pushers %v", res.Pushers)
return util.JSONResponse{
Code: http.StatusOK,
JSON: res,
@ -85,21 +88,92 @@ func GetPushersByLocalpart(
func SetPusherByLocalpart(
req *http.Request, userAPI userapi.UserInternalAPI, device *api.Device,
) util.JSONResponse {
var deletionRes userapi.PerformPusherDeletionResponse
body := pusherJSON{}
if resErr := httputil.UnmarshalJSONRequest(req, &body); resErr != nil {
return *resErr
}
// TODO:
// 1. if kind == null, GetPusherByPushkey and delete it! 🗑
// 2. if GetPusherByPushkey returns existing Pusher, update it with the received body
// 3. if GetPusherByPushkey returns nothing, create a new Pusher with the received body
var queryRes userapi.QueryPushersResponse
err := userAPI.QueryPushers(req.Context(), &userapi.QueryPushersRequest{
UserID: device.UserID,
}, &queryRes)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("QueryPushers failed")
return jsonerror.InternalServerError()
}
var targetPusher *userapi.Pusher
for _, pusher := range queryRes.Pushers {
if pusher.PushKey == body.PushKey {
targetPusher = &pusher
break
}
}
// No Pusher exists with the given PushKey for current user
if targetPusher == nil {
// Create a new Pusher for current user
var pusherResponse userapi.PerformPusherCreationResponse
err = userAPI.PerformPusherCreation(req.Context(), &userapi.PerformPusherCreationRequest{
Device: device,
PushKey: body.PushKey,
Kind: body.Kind,
AppID: body.AppID,
AppDisplayName: body.AppDisplayName,
DeviceDisplayName: body.DeviceDisplayName,
ProfileTag: body.ProfileTag,
Language: body.Language,
URL: body.Data.URL,
Format: body.Data.Format,
}, &pusherResponse)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("PerformPusherCreation failed")
return jsonerror.InternalServerError()
}
} else if body.Kind == "" {
if targetPusher == nil {
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.NotFound("Unknown pusher"),
}
}
// if kind is null, delete the pusher! 🗑
err = userAPI.PerformPusherDeletion(req.Context(), &userapi.PerformPusherDeletionRequest{
AppID: targetPusher.AppID,
PushKey: targetPusher.PushKey,
UserID: targetPusher.UserID,
}, &deletionRes)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("PerformPusherDeletion failed")
return jsonerror.InternalServerError()
}
} else {
var pusherResponse userapi.PerformPusherUpdateResponse
err = userAPI.PerformPusherUpdate(req.Context(), &userapi.PerformPusherUpdateRequest{
PushKey: body.PushKey,
Kind: body.Kind,
AppID: body.AppID,
AppDisplayName: body.AppDisplayName,
DeviceDisplayName: body.DeviceDisplayName,
ProfileTag: body.ProfileTag,
Language: body.Language,
URL: body.Data.URL,
Format: body.Data.Format,
}, &pusherResponse)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("PerformPusherUpdate failed")
return jsonerror.InternalServerError()
}
}
res := body
return util.JSONResponse{
Code: http.StatusOK,
JSON: res,
JSON: struct{}{},
}
}

View file

@ -29,10 +29,11 @@ type UserInternalAPI interface {
PerformPasswordUpdate(ctx context.Context, req *PerformPasswordUpdateRequest, res *PerformPasswordUpdateResponse) error
PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error
PerformDeviceDeletion(ctx context.Context, req *PerformDeviceDeletionRequest, res *PerformDeviceDeletionResponse) error
PerformPusherCreation(ctx context.Context, req *PerformPusherCreationRequest, res *PerformPusherCreationResponse) error
PerformPusherDeletion(ctx context.Context, req *PerformPusherDeletionRequest, res *PerformPusherDeletionResponse) error
PerformLastSeenUpdate(ctx context.Context, req *PerformLastSeenUpdateRequest, res *PerformLastSeenUpdateResponse) error
PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error
PerformPusherCreation(ctx context.Context, req *PerformPusherCreationRequest, res *PerformPusherCreationResponse) error
PerformPusherDeletion(ctx context.Context, req *PerformPusherDeletionRequest, res *PerformPusherDeletionResponse) error
PerformPusherUpdate(ctx context.Context, req *PerformPusherUpdateRequest, res *PerformPusherUpdateResponse) error
PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error
PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error
QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error
@ -81,6 +82,7 @@ type PerformDeviceDeletionResponse struct {
}
type PerformPusherDeletionRequest struct {
AppID string
PushKey string
UserID string
}
@ -242,7 +244,7 @@ type PerformDeviceCreationResponse struct {
// PerformPusherCreationRequest is the request for PerformPusherCreation
type PerformPusherCreationRequest struct {
Localpart string
Device *Device
PushKey string
Kind string
AppID string
@ -258,6 +260,24 @@ type PerformPusherCreationRequest struct {
type PerformPusherCreationResponse struct {
}
// PerformPusherUpdateRequest is the request for PerformPusherUpdate
type PerformPusherUpdateRequest struct {
Device *Device
PushKey string
Kind string
AppID string
AppDisplayName string
DeviceDisplayName string
ProfileTag string
Language string
URL string
Format string
}
// PerformPusherUpdateResponse is the response for PerformPusherUpdate
type PerformPusherUpdateResponse struct {
}
// PerformAccountDeactivationRequest is the request for PerformAccountDeactivation
type PerformAccountDeactivationRequest struct {
Localpart string
@ -312,6 +332,7 @@ type Device struct {
// Pusher represents a push notification subscriber
type Pusher struct {
UserID string
SessionID int64
PushKey string
Kind string
AppID string

View file

@ -153,16 +153,34 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe
func (a *UserInternalAPI) PerformPusherCreation(ctx context.Context, req *api.PerformPusherCreationRequest, res *api.PerformPusherCreationResponse) error {
util.GetLogger(ctx).WithFields(logrus.Fields{
"localpart": req.Localpart,
"userId": req.Device.UserID,
"pushkey": req.PushKey,
"display_name": req.AppDisplayName,
}).Info("PerformPusherCreation")
err := a.PusherDB.CreatePusher(ctx, req.PushKey, req.Kind, req.AppID, req.AppDisplayName, req.DeviceDisplayName, req.ProfileTag, req.Language, req.URL, req.Format, req.Localpart)
local, _, err := gomatrixserverlib.SplitID('@', req.Device.UserID)
if err != nil {
return err
}
err = a.PusherDB.CreatePusher(ctx, req.Device.SessionID, req.PushKey, req.Kind, req.AppID, req.AppDisplayName, req.DeviceDisplayName, req.ProfileTag, req.Language, req.URL, req.Format, local)
return err
}
func (a *UserInternalAPI) PerformPusherUpdate(ctx context.Context, req *api.PerformPusherUpdateRequest, res *api.PerformPusherUpdateResponse) error {
util.GetLogger(ctx).WithFields(logrus.Fields{
"localpart": req.Device.UserID,
"pushkey": req.PushKey,
"display_name": req.AppDisplayName,
}).Info("PerformPusherUpdate")
local, _, err := gomatrixserverlib.SplitID('@', req.Device.UserID)
if err != nil {
return err
}
err = a.PusherDB.UpdatePusher(ctx, req.PushKey, req.Kind, req.AppID, req.AppDisplayName, req.DeviceDisplayName, req.ProfileTag, req.Language, req.URL, req.Format, local)
return err
}
func (a *UserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.PerformPusherDeletionRequest, res *api.PerformPusherDeletionResponse) error {
util.GetLogger(ctx).WithField("user_id", req.UserID).WithField("pushkey", req.PushKey).Info("PerformPusherDeletion")
util.GetLogger(ctx).WithField("user_id", req.UserID).WithField("pushkey", req.PushKey).WithField("app_id", req.AppID).Info("PerformPusherDeletion")
local, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
if err != nil {
return err
@ -170,11 +188,10 @@ func (a *UserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.Pe
if domain != a.ServerName {
return fmt.Errorf("cannot PerformPusherDeletion of remote users: got %s want %s", domain, a.ServerName)
}
err = a.PusherDB.RemovePusher(ctx, local, req.PushKey)
err = a.PusherDB.RemovePusher(ctx, req.AppID, req.PushKey, local)
if err != nil {
return err
}
// create empty device keys and upload them to delete what was once there and trigger device list changes
return nil
}

View file

@ -36,6 +36,7 @@ const (
PerformPusherDeletionPath = "/userapi/performPusherDeletion"
PerformLastSeenUpdatePath = "/userapi/performLastSeenUpdate"
PerformDeviceUpdatePath = "/userapi/performDeviceUpdate"
PerformPusherUpdatePath = "/userapi/performPusherUpdate"
PerformAccountDeactivationPath = "/userapi/performAccountDeactivation"
PerformOpenIDTokenCreationPath = "/userapi/performOpenIDTokenCreation"
@ -137,6 +138,18 @@ func (h *httpUserInternalAPI) PerformPusherCreation(
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
}
func (h *httpUserInternalAPI) PerformPusherUpdate(
ctx context.Context,
request *api.PerformPusherUpdateRequest,
response *api.PerformPusherUpdateResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPusherUpdate")
defer span.Finish()
apiURL := h.apiURL + PerformPusherUpdatePath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
}
func (h *httpUserInternalAPI) PerformPusherDeletion(
ctx context.Context,
request *api.PerformPusherDeletionRequest,

View file

@ -65,6 +65,19 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(PerformPusherUpdatePath,
httputil.MakeInternalAPI("performPusherUpdate", func(req *http.Request) util.JSONResponse {
request := api.PerformPusherUpdateRequest{}
response := api.PerformPusherUpdateResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformPusherUpdate(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(PerformDeviceCreationPath,
httputil.MakeInternalAPI("performDeviceCreation", func(req *http.Request) util.JSONResponse {
request := api.PerformDeviceCreationRequest{}

View file

@ -21,8 +21,9 @@ import (
)
type Database interface {
CreatePusher(ctd context.Context, pushkey, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, url, format, localpart string) error
CreatePusher(ctx context.Context, sessionId int64, pushkey, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, url, format, localpart string) error
GetPushersByLocalpart(ctx context.Context, localpart string) ([]api.Pusher, error)
GetPusherByPushkey(ctx context.Context, pushkey, localpart string) (*api.Pusher, error)
RemovePusher(ctx context.Context, pushkey, localpart string) error
UpdatePusher(ctx context.Context, pushkey, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, url, format, localpart string) error
RemovePusher(ctx context.Context, appId, pushkey, localpart string) error
}

View file

@ -23,19 +23,19 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
)
const pushersSchema = `
-- Stores data about pushers.
CREATE TABLE IF NOT EXISTS pusher_pushers (
id SERIAL PRIMARY KEY,
-- The Matrix user ID localpart for this pusher
localpart TEXT NOT NULL PRIMARY KEY,
-- This is a unique identifier for this pusher.
-- The value you should use for this is the routing or destination address information for the notification, for example,
-- the APNS token for APNS or the Registration ID for GCM. If your notification client has no such concept, use any unique identifier.
-- If the kind is "email", this is the email address to send notifications to.
-- Max length, 512 bytes.
pushkey VARCHAR(512) NOT NULL,
localpart TEXT NOT NULL,
-- The Session ID used to create the Pusher
session_id BIGINT DEFAULT NULL,
-- This string determines which set of device specific rules this pusher executes.
profile_tag TEXT NOT NULL,
-- The kind of pusher. "http" is a pusher that sends HTTP pokes.
kind TEXT,
-- This is a reverse-DNS style identifier for the application. Max length, 64 chars.
@ -44,8 +44,12 @@ CREATE TABLE IF NOT EXISTS pusher_pushers (
app_display_name TEXT,
-- A string that will allow the user to identify what device owns this pusher.
device_display_name TEXT,
-- This string determines which set of device specific rules this pusher executes.
profile_tag TEXT,
-- This is a unique identifier for this pusher.
-- The value you should use for this is the routing or destination address information for the notification, for example,
-- the APNS token for APNS or the Registration ID for GCM. If your notification client has no such concept, use any unique identifier.
-- If the kind is "email", this is the email address to send notifications to.
-- Max length, 512 bytes.
pushkey VARCHAR(512) NOT NULL,
-- The preferred language for receiving notifications (e.g. 'en' or 'en-US')
lang TEXT,
-- Required if kind is http. The URL to use to send notifications to.
@ -55,26 +59,29 @@ CREATE TABLE IF NOT EXISTS pusher_pushers (
);
-- Pushkey must be unique for a given user.
CREATE UNIQUE INDEX IF NOT EXISTS pusher_localpart_pushkey_idx ON pusher_pushers(localpart, pushkey);
CREATE UNIQUE INDEX IF NOT EXISTS pusher_app_id_pushkey_localpart_idx ON pusher_pushers(app_id, pushkey, localpart);
`
const insertPusherSQL = "" +
"INSERT INTO pusher_pushers(localpart, pushkey, kind, app_id, app_display_name, device_display_name, profile_tag, lang, url, format) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)"
"INSERT INTO pusher_pushers(localpart, session_id, pushkey, kind, app_id, app_display_name, device_display_name, profile_tag, lang, url, format) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)"
const selectPushersByLocalpartSQL = "" +
"SELECT pushkey, kind, app_id, app_display_name, device_display_name, profile_tag, lang, url, format FROM pusher_pushers WHERE localpart = $1"
"SELECT session_id, pushkey, kind, app_id, app_display_name, device_display_name, profile_tag, lang, url, format FROM pusher_pushers WHERE localpart = $1"
const selectPusherByPushkeySQL = "" +
"SELECT pushkey, kind, app_id, app_display_name, device_display_name, profile_tag, lang, url, format FROM pusher_pushers WHERE localpart = $1 AND pushkey = $2"
"SELECT session_id, pushkey, kind, app_id, app_display_name, device_display_name, profile_tag, lang, url, format FROM pusher_pushers WHERE localpart = $1 AND pushkey = $2"
const updatePusherSQL = "" +
"UPDATE pusher_pushers SET kind = $1, app_id = $2, app_display_name = $3, device_display_name = $4, profile_tag = $5, lang = $6, url = $7, format = $8 WHERE localpart = $9 AND pushkey = $10"
const deletePusherSQL = "" +
"DELETE FROM pusher_pushers WHERE pushkey = $1 AND localpart = $2"
"DELETE FROM pusher_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3"
type pushersStatements struct {
insertPusherStmt *sql.Stmt
selectPushersByLocalpartStmt *sql.Stmt
selectPusherByPushkeyStmt *sql.Stmt
updatePusherStmt *sql.Stmt
deletePusherStmt *sql.Stmt
serverName gomatrixserverlib.ServerName
}
@ -94,6 +101,9 @@ func (s *pushersStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN
if s.selectPusherByPushkeyStmt, err = db.Prepare(selectPusherByPushkeySQL); err != nil {
return
}
if s.updatePusherStmt, err = db.Prepare(updatePusherSQL); err != nil {
return
}
if s.deletePusherStmt, err = db.Prepare(deletePusherSQL); err != nil {
return
}
@ -105,19 +115,12 @@ func (s *pushersStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN
// Returns an error if the user already has a pusher with the given pusher pushkey.
// Returns nil error success.
func (s *pushersStatements) insertPusher(
ctx context.Context, txn *sql.Tx, pushkey, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, url, format, localpart string,
ctx context.Context, txn *sql.Tx, session_id int64,
pushkey, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, url, format, localpart string,
) error {
stmt := sqlutil.TxStmt(txn, s.insertPusherStmt)
_, err := stmt.ExecContext(ctx, localpart, pushkey, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, url, format)
return err
}
// deletePusher removes a single pusher by pushkey and user localpart.
func (s *pushersStatements) deletePusher(
ctx context.Context, txn *sql.Tx, pushkey, localpart string,
) error {
stmt := sqlutil.TxStmt(txn, s.deletePusherStmt)
_, err := stmt.ExecContext(ctx, pushkey, localpart)
_, err := stmt.ExecContext(ctx, localpart, session_id, pushkey, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, url, format)
logrus.Debugf("🥳 Created pusher %d", session_id)
return err
}
@ -134,11 +137,15 @@ func (s *pushersStatements) selectPushersByLocalpart(
for rows.Next() {
var pusher api.Pusher
var sessionid sql.NullInt64
var pushkey, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, url, format sql.NullString
err = rows.Scan(&pushkey, &kind, &appid, &appdisplayname, &devicedisplayname, &profiletag, &lang, &url, &format)
err = rows.Scan(&sessionid, &pushkey, &kind, &appid, &appdisplayname, &devicedisplayname, &profiletag, &lang, &url, &format)
if err != nil {
return pushers, err
}
if sessionid.Valid {
pusher.SessionID = sessionid.Int64
}
if pushkey.Valid {
pusher.PushKey = pushkey.String
}
@ -171,6 +178,7 @@ func (s *pushersStatements) selectPushersByLocalpart(
pushers = append(pushers, pusher)
}
logrus.Debugf("🤓 Database returned %d pushers", len(pushers))
return pushers, rows.Err()
}
@ -178,12 +186,16 @@ func (s *pushersStatements) selectPusherByPushkey(
ctx context.Context, localpart, pushkey string,
) (*api.Pusher, error) {
var pusher api.Pusher
var id, key, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, url, format sql.NullString
var sessionid sql.NullInt64
var key, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, url, format sql.NullString
stmt := s.selectPusherByPushkeyStmt
err := stmt.QueryRowContext(ctx, localpart, pushkey).Scan(&id, &key, &kind, &appid, &appdisplayname, &devicedisplayname, &profiletag, &lang, &url, &format)
err := stmt.QueryRowContext(ctx, localpart, pushkey).Scan(&sessionid, &key, &kind, &appid, &appdisplayname, &devicedisplayname, &profiletag, &lang, &url, &format)
if err == nil {
if sessionid.Valid {
pusher.SessionID = sessionid.Int64
}
if key.Valid {
pusher.PushKey = key.String
}
@ -217,3 +229,20 @@ func (s *pushersStatements) selectPusherByPushkey(
return &pusher, err
}
func (s *pushersStatements) updatePusher(
ctx context.Context, txn *sql.Tx, pushkey, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, url, format, localpart string,
) error {
stmt := sqlutil.TxStmt(txn, s.updatePusherStmt)
_, err := stmt.ExecContext(ctx, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, url, format, localpart, pushkey)
return err
}
// deletePusher removes a single pusher by pushkey and user localpart.
func (s *pushersStatements) deletePusher(
ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string,
) error {
stmt := sqlutil.TxStmt(txn, s.deletePusherStmt)
_, err := stmt.ExecContext(ctx, appid, pushkey, localpart)
return err
}

View file

@ -22,6 +22,7 @@ import (
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
)
// Database represents a pusher database.
@ -56,9 +57,10 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
}
func (d *Database) CreatePusher(
ctx context.Context, pushkey, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, url, format, localpart string,
ctx context.Context, session_id int64,
pushkey, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, url, format, localpart string,
) error {
return d.pushers.insertPusher(ctx, nil, pushkey, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, url, format, localpart)
return d.pushers.insertPusher(ctx, nil, session_id, pushkey, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, url, format, localpart)
}
// GetPushersByLocalpart returns the pushers matching the given localpart.
@ -75,16 +77,28 @@ func (d *Database) GetPusherByPushkey(
return d.pushers.selectPusherByPushkey(ctx, localpart, pushkey)
}
// UpdatePusher updates the given pusher with the display name.
// Returns SQL error if there are problems and nil on success.
func (d *Database) UpdatePusher(
ctx context.Context, pushkey, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, url, format, localpart string,
) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.pushers.updatePusher(ctx, txn, pushkey, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, url, format, localpart)
})
}
// RemovePusher revokes a pusher by deleting the entry in the database
// matching with the given pushkey and user ID localpart.
// If the pusher doesn't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemovePusher(
ctx context.Context, pushkey, localpart string,
ctx context.Context, appid, pushkey, localpart string,
) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.pushers.deletePusher(ctx, txn, pushkey, localpart); err != sql.ErrNoRows {
if err := d.pushers.deletePusher(ctx, txn, appid, pushkey, localpart); err != sql.ErrNoRows {
return err
} else {
logrus.WithError(err).Debug("RemovePusher Error")
}
return nil
})

View file

@ -20,6 +20,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/gomatrixserverlib"
@ -28,32 +29,36 @@ import (
const pushersSchema = `
-- Stores data about pushers.
CREATE TABLE IF NOT EXISTS pusher_pushers (
localpart TEXT PRIMARY KEY,
pushkey VARCHAR(512),
localpart TEXT,
session_id BIGINT,
profile_tag TEXT,
kind TEXT,
app_id VARCHAR(64),
app_display_name TEXT,
device_display_name TEXT,
profile_tag TEXT,
pushkey VARCHAR(512),
lang TEXT,
url TEXT,
format TEXT,
UNIQUE (localpart, pushkey)
UNIQUE (app_id, pushkey, localpart)
);
`
const insertPusherSQL = "" +
"INSERT INTO pusher_pushers (localpart, pushkey, kind, app_id, app_display_name, device_display_name, profile_tag, lang, url, format)" +
" VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9. $10)"
"INSERT INTO pusher_pushers (localpart, session_id, pushkey, kind, app_id, app_display_name, device_display_name, profile_tag, lang, url, format)" +
" VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)"
const selectPushersByLocalpartSQL = "" +
"SELECT pushkey, kind, app_id, app_display_name, device_display_name, profile_tag, lang, url, format FROM pusher_pushers WHERE localpart = $1"
"SELECT session_id, pushkey, kind, app_id, app_display_name, device_display_name, profile_tag, lang, url, format FROM pusher_pushers WHERE localpart = $1"
const selectPusherByPushkeySQL = "" +
"SELECT pushkey, kind, app_id, app_display_name, device_display_name, profile_tag, lang, url, format FROM pusher_pushers WHERE localpart = $1 AND pushkey = $2"
"SELECT session_id, pushkey, kind, app_id, app_display_name, device_display_name, profile_tag, lang, url, format FROM pusher_pushers WHERE localpart = $1 AND pushkey = $2"
const updatePusherSQL = "" +
"UPDATE pusher_pushers SET kind = $1, app_id = $2, app_display_name = $3, device_display_name = $4, profile_tag = $5, lang = $6, url = $7, format = $8 WHERE localpart = $9 AND pushkey = $10"
const deletePusherSQL = "" +
"DELETE FROM pusher_pushers WHERE pushkey = $1 AND localpart = $2"
"DELETE FROM pusher_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3"
type pushersStatements struct {
db *sql.DB
@ -61,6 +66,7 @@ type pushersStatements struct {
insertPusherStmt *sql.Stmt
selectPushersByLocalpartStmt *sql.Stmt
selectPusherByPushkeyStmt *sql.Stmt
updatePusherStmt *sql.Stmt
deletePusherStmt *sql.Stmt
serverName gomatrixserverlib.ServerName
}
@ -82,6 +88,9 @@ func (s *pushersStatements) prepare(db *sql.DB, writer sqlutil.Writer, server go
if s.selectPusherByPushkeyStmt, err = db.Prepare(selectPusherByPushkeySQL); err != nil {
return
}
if s.updatePusherStmt, err = db.Prepare(updatePusherSQL); err != nil {
return
}
if s.deletePusherStmt, err = db.Prepare(deletePusherSQL); err != nil {
return
}
@ -93,10 +102,12 @@ func (s *pushersStatements) prepare(db *sql.DB, writer sqlutil.Writer, server go
// Returns an error if the user already has a pusher with the given pusher pushkey.
// Returns nil error success.
func (s *pushersStatements) insertPusher(
ctx context.Context, txn *sql.Tx, pushkey, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, url, format, localpart string,
ctx context.Context, txn *sql.Tx, session_id int64,
pushkey, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, url, format, localpart string,
) error {
stmt := sqlutil.TxStmt(txn, s.insertPusherStmt)
_, err := stmt.ExecContext(ctx, localpart, pushkey, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, url, format)
_, err := stmt.ExecContext(ctx, localpart, session_id, pushkey, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, url, format)
logrus.Debugf("🥳 Created pusher %d", session_id)
return err
}
@ -111,12 +122,17 @@ func (s *pushersStatements) selectPushersByLocalpart(
}
for rows.Next() {
logrus.Debug("Next pusher row...")
var pusher api.Pusher
var sessionid sql.NullInt64
var pushkey, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, url, format sql.NullString
err = rows.Scan(&pushkey, &kind, &appid, &appdisplayname, &devicedisplayname, &profiletag, &lang, &url, &format)
err = rows.Scan(&sessionid, &pushkey, &kind, &appid, &appdisplayname, &devicedisplayname, &profiletag, &lang, &url, &format)
if err != nil {
return pushers, err
}
if sessionid.Valid {
pusher.SessionID = sessionid.Int64
}
if pushkey.Valid {
pusher.PushKey = pushkey.String
}
@ -149,6 +165,7 @@ func (s *pushersStatements) selectPushersByLocalpart(
pushers = append(pushers, pusher)
}
logrus.Debugf("🤓 Database returned %d pushers", len(pushers))
return pushers, nil
}
@ -194,10 +211,18 @@ func (s *pushersStatements) selectPusherByPushkey(
return &pusher, err
}
func (s *pushersStatements) deletePusher(
ctx context.Context, txn *sql.Tx, id, localpart string,
func (s *pushersStatements) updatePusher(
ctx context.Context, txn *sql.Tx, pushkey, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, url, format, localpart string,
) error {
stmt := sqlutil.TxStmt(txn, s.deletePusherStmt)
_, err := stmt.ExecContext(ctx, id, localpart)
stmt := sqlutil.TxStmt(txn, s.updatePusherStmt)
_, err := stmt.ExecContext(ctx, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, url, format, localpart, pushkey)
return err
}
func (s *pushersStatements) deletePusher(
ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string,
) error {
stmt := sqlutil.TxStmt(txn, s.deletePusherStmt)
_, err := stmt.ExecContext(ctx, appid, pushkey, localpart)
return err
}

View file

@ -22,6 +22,7 @@ import (
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
_ "github.com/mattn/go-sqlite3"
)
@ -58,9 +59,10 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
}
func (d *Database) CreatePusher(
ctx context.Context, pushkey, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, url, format, localpart string,
ctx context.Context, session_id int64,
pushkey, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, url, format, localpart string,
) error {
return d.pushers.insertPusher(ctx, nil, pushkey, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, url, format, localpart)
return d.pushers.insertPusher(ctx, nil, session_id, pushkey, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, url, format, localpart)
}
// GetPushersByLocalpart returns the pushers matching the given localpart.
@ -77,16 +79,29 @@ func (d *Database) GetPusherByPushkey(
return d.pushers.selectPusherByPushkey(ctx, pushkey, localpart)
}
// UpdatePusher updates the given device with the display name.
// Returns SQL error if there are problems and nil on success.
func (d *Database) UpdatePusher(
ctx context.Context, pushkey, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, url, format, localpart string,
) error {
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.pushers.updatePusher(ctx, txn, pushkey, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, url, format, localpart)
})
}
// RemovePusher revokes a pusher by deleting the entry in the database
// matching with the given pusher pushkey and user ID localpart.
// If the pusher doesn't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemovePusher(
ctx context.Context, pushkey, localpart string,
ctx context.Context, appid, pushkey, localpart string,
) error {
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
if err := d.pushers.deletePusher(ctx, txn, pushkey, localpart); err != sql.ErrNoRows {
if err := d.pushers.deletePusher(ctx, txn, appid, pushkey, localpart); err != sql.ErrNoRows {
logrus.WithError(err).Debug("RemovePusher Yes Error")
return err
} else {
logrus.WithError(err).Debug("RemovePusher No Error")
}
return nil
})