diff --git a/clientapi/routing/pusher.go b/clientapi/routing/pusher.go index e1550cda..26668218 100644 --- a/clientapi/routing/pusher.go +++ b/clientapi/routing/pusher.go @@ -31,7 +31,7 @@ type pusherJSON struct { AppID string `json:"app_id"` AppDisplayName string `json:"app_display_name"` DeviceDisplayName string `json:"device_display_name"` - ProfileTag string `json:"profile_tag"` + ProfileTag *string `json:"profile_tag"` Language string `json:"lang"` Data pusherDataJSON `json:"data"` } @@ -67,7 +67,7 @@ func GetPushersByLocalpart( AppID: pusher.AppID, AppDisplayName: pusher.AppDisplayName, DeviceDisplayName: pusher.DeviceDisplayName, - ProfileTag: pusher.ProfileTag, + ProfileTag: &pusher.ProfileTag, Language: pusher.Language, Data: pusherDataJSON(pusher.Data), }) diff --git a/userapi/api/api.go b/userapi/api/api.go index b2b9e98e..7aa66ecc 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -29,6 +29,7 @@ type UserInternalAPI interface { PerformPasswordUpdate(ctx context.Context, req *PerformPasswordUpdateRequest, res *PerformPasswordUpdateResponse) error PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error PerformDeviceDeletion(ctx context.Context, req *PerformDeviceDeletionRequest, res *PerformDeviceDeletionResponse) error + PerformPusherCreation(ctx context.Context, req *PerformPusherCreationRequest, res *PerformPusherCreationResponse) error PerformPusherDeletion(ctx context.Context, req *PerformPusherDeletionRequest, res *PerformPusherDeletionResponse) error PerformLastSeenUpdate(ctx context.Context, req *PerformLastSeenUpdateRequest, res *PerformLastSeenUpdateResponse) error PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error @@ -239,6 +240,24 @@ type PerformDeviceCreationResponse struct { Device *Device } +// PerformPusherCreationRequest is the request for PerformPusherCreation +type PerformPusherCreationRequest struct { + Localpart string + PushKey string + Kind string + AppID string + AppDisplayName string + DeviceDisplayName string + ProfileTag string + Language string + URL string + Format string +} + +// PerformPusherCreationResponse is the response for PerformPusherCreation +type PerformPusherCreationResponse struct { +} + // PerformAccountDeactivationRequest is the request for PerformAccountDeactivation type PerformAccountDeactivationRequest struct { Localpart string diff --git a/userapi/internal/api.go b/userapi/internal/api.go index da4b91a3..dc96a2e2 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -151,6 +151,16 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe return a.deviceListUpdate(req.UserID, deletedDeviceIDs) } +func (a *UserInternalAPI) PerformPusherCreation(ctx context.Context, req *api.PerformPusherCreationRequest, res *api.PerformPusherCreationResponse) error { + util.GetLogger(ctx).WithFields(logrus.Fields{ + "localpart": req.Localpart, + "pushkey": req.PushKey, + "display_name": req.AppDisplayName, + }).Info("PerformPusherCreation") + err := a.PusherDB.CreatePusher(ctx, req.PushKey, req.Kind, req.AppID, req.AppDisplayName, req.DeviceDisplayName, req.ProfileTag, req.Language, req.URL, req.Format, req.Localpart) + return err +} + func (a *UserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.PerformPusherDeletionRequest, res *api.PerformPusherDeletionResponse) error { util.GetLogger(ctx).WithField("user_id", req.UserID).WithField("pushkey", req.PushKey).Info("PerformPusherDeletion") local, domain, err := gomatrixserverlib.SplitID('@', req.UserID) diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go index b141e484..cd0c4ca2 100644 --- a/userapi/inthttp/client.go +++ b/userapi/inthttp/client.go @@ -29,6 +29,7 @@ const ( InputAccountDataPath = "/userapi/inputAccountData" PerformDeviceCreationPath = "/userapi/performDeviceCreation" + PerformPusherCreationPath = "/userapi/performPusherCreation" PerformAccountCreationPath = "/userapi/performAccountCreation" PerformPasswordUpdatePath = "/userapi/performPasswordUpdate" PerformDeviceDeletionPath = "/userapi/performDeviceDeletion" @@ -124,6 +125,18 @@ func (h *httpUserInternalAPI) PerformDeviceDeletion( return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } +func (h *httpUserInternalAPI) PerformPusherCreation( + ctx context.Context, + request *api.PerformPusherCreationRequest, + response *api.PerformPusherCreationResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPusherCreation") + defer span.Finish() + + apiURL := h.apiURL + PerformPusherCreationPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + func (h *httpUserInternalAPI) PerformPusherDeletion( ctx context.Context, request *api.PerformPusherDeletionRequest, diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go index 97cddf9d..563c6d69 100644 --- a/userapi/inthttp/server.go +++ b/userapi/inthttp/server.go @@ -52,6 +52,19 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(PerformPusherCreationPath, + httputil.MakeInternalAPI("performPusherCreation", func(req *http.Request) util.JSONResponse { + request := api.PerformPusherCreationRequest{} + response := api.PerformPusherCreationResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.PerformPusherCreation(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) internalAPIMux.Handle(PerformDeviceCreationPath, httputil.MakeInternalAPI("performDeviceCreation", func(req *http.Request) util.JSONResponse { request := api.PerformDeviceCreationRequest{} diff --git a/userapi/storage/pushers/interface.go b/userapi/storage/pushers/interface.go index f6559d5a..0a5b1856 100644 --- a/userapi/storage/pushers/interface.go +++ b/userapi/storage/pushers/interface.go @@ -21,6 +21,7 @@ import ( ) type Database interface { + CreatePusher(ctd context.Context, pushkey, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, url, format, localpart string) error GetPushersByLocalpart(ctx context.Context, localpart string) ([]api.Pusher, error) GetPusherByPushkey(ctx context.Context, pushkey, localpart string) (*api.Pusher, error) RemovePusher(ctx context.Context, pushkey, localpart string) error diff --git a/userapi/storage/pushers/postgres/pushers_table.go b/userapi/storage/pushers/postgres/pushers_table.go index b4d5e16f..59e388fd 100644 --- a/userapi/storage/pushers/postgres/pushers_table.go +++ b/userapi/storage/pushers/postgres/pushers_table.go @@ -58,15 +58,21 @@ CREATE TABLE IF NOT EXISTS pusher_pushers ( CREATE UNIQUE INDEX IF NOT EXISTS pusher_localpart_pushkey_idx ON pusher_pushers(localpart, pushkey); ` +const insertPusherSQL = "" + + "INSERT INTO pusher_pushers(localpart, pushkey, kind, app_id, app_display_name, device_display_name, profile_tag, lang, url, format) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)" + const selectPushersByLocalpartSQL = "" + "SELECT pushkey, kind, app_id, app_display_name, device_display_name, profile_tag, lang, url, format FROM pusher_pushers WHERE localpart = $1" const selectPusherByPushkeySQL = "" + "SELECT pushkey, kind, app_id, app_display_name, device_display_name, profile_tag, lang, url, format FROM pusher_pushers WHERE localpart = $1 AND pushkey = $2" + const deletePusherSQL = "" + "DELETE FROM pusher_pushers WHERE pushkey = $1 AND localpart = $2" + type pushersStatements struct { + insertPusherStmt *sql.Stmt selectPushersByLocalpartStmt *sql.Stmt selectPusherByPushkeyStmt *sql.Stmt deletePusherStmt *sql.Stmt @@ -79,6 +85,9 @@ func (s *pushersStatements) execSchema(db *sql.DB) error { } func (s *pushersStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { + if s.insertPusherStmt, err = db.Prepare(insertPusherSQL); err != nil { + return + } if s.selectPushersByLocalpartStmt, err = db.Prepare(selectPushersByLocalpartSQL); err != nil { return } @@ -92,6 +101,17 @@ func (s *pushersStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN return } +// insertPusher creates a new pusher. +// Returns an error if the user already has a pusher with the given pusher pushkey. +// Returns nil error success. +func (s *pushersStatements) insertPusher( + ctx context.Context, txn *sql.Tx, pushkey, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, url, format, localpart string, +) error { + stmt := sqlutil.TxStmt(txn, s.insertPusherStmt) + _, err := stmt.ExecContext(ctx, localpart, pushkey, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, url, format) + return err +} + // deletePusher removes a single pusher by pushkey and user localpart. func (s *pushersStatements) deletePusher( ctx context.Context, txn *sql.Tx, pushkey, localpart string, diff --git a/userapi/storage/pushers/postgres/storage.go b/userapi/storage/pushers/postgres/storage.go index e16a85a7..3e18bc40 100644 --- a/userapi/storage/pushers/postgres/storage.go +++ b/userapi/storage/pushers/postgres/storage.go @@ -55,6 +55,13 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver return &Database{db, d}, nil } +func (d *Database) CreatePusher( + ctx context.Context, pushkey, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, url, format, localpart string, +) error { + return d.pushers.insertPusher(ctx, nil, pushkey, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, url, format, localpart) +} + +// GetPushersByLocalpart returns the pushers matching the given localpart. func (d *Database) GetPushersByLocalpart( ctx context.Context, localpart string, ) ([]api.Pusher, error) { diff --git a/userapi/storage/pushers/sqlite3/pushers_table.go b/userapi/storage/pushers/sqlite3/pushers_table.go index 622f85ce..f7af45c7 100644 --- a/userapi/storage/pushers/sqlite3/pushers_table.go +++ b/userapi/storage/pushers/sqlite3/pushers_table.go @@ -42,6 +42,10 @@ CREATE TABLE IF NOT EXISTS pusher_pushers ( UNIQUE (localpart, pushkey) ); ` +const insertPusherSQL = "" + + "INSERT INTO pusher_pushers (localpart, pushkey, kind, app_id, app_display_name, device_display_name, profile_tag, lang, url, format)" + + " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9. $10)" + const selectPushersByLocalpartSQL = "" + "SELECT pushkey, kind, app_id, app_display_name, device_display_name, profile_tag, lang, url, format FROM pusher_pushers WHERE localpart = $1" @@ -54,6 +58,7 @@ const deletePusherSQL = "" + type pushersStatements struct { db *sql.DB writer sqlutil.Writer + insertPusherStmt *sql.Stmt selectPushersByLocalpartStmt *sql.Stmt selectPusherByPushkeyStmt *sql.Stmt deletePusherStmt *sql.Stmt @@ -68,6 +73,9 @@ func (s *pushersStatements) execSchema(db *sql.DB) error { func (s *pushersStatements) prepare(db *sql.DB, writer sqlutil.Writer, server gomatrixserverlib.ServerName) (err error) { s.db = db s.writer = writer + if s.insertPusherStmt, err = db.Prepare(insertPusherSQL); err != nil { + return + } if s.selectPushersByLocalpartStmt, err = db.Prepare(selectPushersByLocalpartSQL); err != nil { return } @@ -81,6 +89,17 @@ func (s *pushersStatements) prepare(db *sql.DB, writer sqlutil.Writer, server go return } +// insertPusher creates a new pusher. +// Returns an error if the user already has a pusher with the given pusher pushkey. +// Returns nil error success. +func (s *pushersStatements) insertPusher( + ctx context.Context, txn *sql.Tx, pushkey, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, url, format, localpart string, +) error { + stmt := sqlutil.TxStmt(txn, s.insertPusherStmt) + _, err := stmt.ExecContext(ctx, localpart, pushkey, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, url, format) + return err +} + func (s *pushersStatements) selectPushersByLocalpart( ctx context.Context, txn *sql.Tx, localpart string, ) ([]api.Pusher, error) { diff --git a/userapi/storage/pushers/sqlite3/storage.go b/userapi/storage/pushers/sqlite3/storage.go index e89f9548..c8665190 100644 --- a/userapi/storage/pushers/sqlite3/storage.go +++ b/userapi/storage/pushers/sqlite3/storage.go @@ -57,6 +57,12 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver return &Database{db, writer, d}, nil } +func (d *Database) CreatePusher( + ctx context.Context, pushkey, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, url, format, localpart string, +) error { + return d.pushers.insertPusher(ctx, nil, pushkey, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, url, format, localpart) +} + // GetPushersByLocalpart returns the pushers matching the given localpart. func (d *Database) GetPushersByLocalpart( ctx context.Context, localpart string,