Implement Push Notifications (#1842)

* Add Pushserver component with Pushers API

Co-authored-by: Tommie Gannert <tommie@gannert.se>
Co-authored-by: Dan Peleg <dan@globekeeper.com>

* Wire Pushserver component

Co-authored-by: Neil Alexander <neilalexander@users.noreply.github.com>

* Add PushGatewayClient.

The full event format is required for Sytest.

* Add a pushrules module.

* Change user API account creation to use the new pushrules module's defaults.

Introduces "scope" as required by client API, and some small field
tweaks to make some 61push Sytests pass.

* Add push rules query/put API in Pushserver.

This manipulates account data over User API, and fires sync messages
for changes. Those sync messages should, according to an existing TODO
in clientapi, be moved to userapi.

Forks clientapi/producers/syncapi.go to pushserver/ for later extension.

* Add clientapi routes for push rules to Pushserver.

A cleanup would be to move more of the name-splitting logic into
pushrules.go, to depollute routing.go.

* Output rooms.join.unread_notifications in /sync.

This is the read-side. Pushserver will be the write-side.

* Implement pushserver/storage for notifications.

* Use PushGatewayClient and the pushrules module in Pushserver's room consumer.

* Use one goroutine per user to avoid locking up the entire server for
  one bad push gateway.
* Split pushing by format.
* Send one device per push. Sytest does not support coalescing
  multiple devices into one push. Matches Synapse. Either we change
  Sytest, or remove the group-by-url-and-format logic.
* Write OutputNotificationData from push server. Sync API is already
  the consumer.

* Implement read receipt consumers in Pushserver.

Supports m.read and m.fully_read receipts.

* Add clientapi route for /unstable/notifications.

* Rename to UpsertPusher for clarity and handle pusher update

* Fix linter errors

* Ignore body.Close() error check

* Fix push server internal http wiring

* Add 40 newly passing 61push tests to whitelist

* Add next 12 newly passing 61push tests to whitelist

* Send notification data before notifying users in EDU server consumer

* NATS JetStream

* Goodbye sarama

* Fix `NewStreamTokenFromString`

* Consume on the correct topic for the roomserver

* Don't panic, NAK instead

* Move push notifications into the User API

* Don't set null values since that apparently causes Element upsetti

* Also set omitempty on conditions

* Fix bug so that we don't override the push rules unnecessarily

* Tweak defaults

* Update defaults

* More tweaks

* Move `/notifications` onto `r0`/`v3` mux

* User API will consume events and read/fully read markers from the sync API with stream positions, instead of consuming directly

Co-authored-by: Piotr Kozimor <p1996k@gmail.com>
Co-authored-by: Tommie Gannert <tommie@gannert.se>
Co-authored-by: Neil Alexander <neilalexander@users.noreply.github.com>
This commit is contained in:
Dan 2022-03-03 13:40:53 +02:00 committed by GitHub
parent 111f01ddc8
commit f05ce478f0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
92 changed files with 5840 additions and 194 deletions

View file

@ -21,6 +21,7 @@ import (
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal/pushrules"
)
// UserInternalAPI is the internal API for information about users and devices.
@ -28,6 +29,7 @@ type UserInternalAPI interface {
LoginTokenInternalAPI
InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error
PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error
PerformPasswordUpdate(ctx context.Context, req *PerformPasswordUpdateRequest, res *PerformPasswordUpdateResponse) error
PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error
@ -37,6 +39,10 @@ type UserInternalAPI interface {
PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error
PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error
PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse) error
PerformPusherSet(ctx context.Context, req *PerformPusherSetRequest, res *struct{}) error
PerformPusherDeletion(ctx context.Context, req *PerformPusherDeletionRequest, res *struct{}) error
PerformPushRulesPut(ctx context.Context, req *PerformPushRulesPutRequest, res *struct{}) error
QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse)
QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error
QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error
@ -45,6 +51,9 @@ type UserInternalAPI interface {
QueryDeviceInfos(ctx context.Context, req *QueryDeviceInfosRequest, res *QueryDeviceInfosResponse) error
QuerySearchProfiles(ctx context.Context, req *QuerySearchProfilesRequest, res *QuerySearchProfilesResponse) error
QueryOpenIDToken(ctx context.Context, req *QueryOpenIDTokenRequest, res *QueryOpenIDTokenResponse) error
QueryPushers(ctx context.Context, req *QueryPushersRequest, res *QueryPushersResponse) error
QueryPushRules(ctx context.Context, req *QueryPushRulesRequest, res *QueryPushRulesResponse) error
QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error
}
type PerformKeyBackupRequest struct {
@ -424,3 +433,77 @@ const (
// AccountTypeAppService indicates this is an appservice account
AccountTypeAppService AccountType = 4
)
type QueryPushersRequest struct {
Localpart string
}
type QueryPushersResponse struct {
Pushers []Pusher `json:"pushers"`
}
type PerformPusherSetRequest struct {
Pusher // Anonymous field because that's how clientapi unmarshals it.
Localpart string
Append bool `json:"append"`
}
type PerformPusherDeletionRequest struct {
Localpart string
SessionID int64
}
// Pusher represents a push notification subscriber
type Pusher struct {
SessionID int64 `json:"session_id,omitempty"`
PushKey string `json:"pushkey"`
PushKeyTS gomatrixserverlib.Timestamp `json:"pushkey_ts,omitempty"`
Kind PusherKind `json:"kind"`
AppID string `json:"app_id"`
AppDisplayName string `json:"app_display_name"`
DeviceDisplayName string `json:"device_display_name"`
ProfileTag string `json:"profile_tag"`
Language string `json:"lang"`
Data map[string]interface{} `json:"data"`
}
type PusherKind string
const (
EmailKind PusherKind = "email"
HTTPKind PusherKind = "http"
)
type PerformPushRulesPutRequest struct {
UserID string `json:"user_id"`
RuleSets *pushrules.AccountRuleSets `json:"rule_sets"`
}
type QueryPushRulesRequest struct {
UserID string `json:"user_id"`
}
type QueryPushRulesResponse struct {
RuleSets *pushrules.AccountRuleSets `json:"rule_sets"`
}
type QueryNotificationsRequest struct {
Localpart string `json:"localpart"` // Required.
From string `json:"from,omitempty"`
Limit int `json:"limit,omitempty"`
Only string `json:"only,omitempty"`
}
type QueryNotificationsResponse struct {
NextToken string `json:"next_token"`
Notifications []*Notification `json:"notifications"` // Required.
}
type Notification struct {
Actions []*pushrules.Action `json:"actions"` // Required.
Event gomatrixserverlib.ClientEvent `json:"event"` // Required.
ProfileTag string `json:"profile_tag"` // Required by Sytest, but actually optional.
Read bool `json:"read"` // Required.
RoomID string `json:"room_id"` // Required.
TS gomatrixserverlib.Timestamp `json:"ts"` // Required.
}

View file

@ -79,6 +79,21 @@ func (t *UserInternalAPITrace) PerformKeyBackup(ctx context.Context, req *Perfor
util.GetLogger(ctx).Infof("PerformKeyBackup req=%+v res=%+v", js(req), js(res))
return err
}
func (t *UserInternalAPITrace) PerformPusherSet(ctx context.Context, req *PerformPusherSetRequest, res *struct{}) error {
err := t.Impl.PerformPusherSet(ctx, req, res)
util.GetLogger(ctx).Infof("PerformPusherSet req=%+v res=%+v", js(req), js(res))
return err
}
func (t *UserInternalAPITrace) PerformPusherDeletion(ctx context.Context, req *PerformPusherDeletionRequest, res *struct{}) error {
err := t.Impl.PerformPusherDeletion(ctx, req, res)
util.GetLogger(ctx).Infof("PerformPusherDeletion req=%+v res=%+v", js(req), js(res))
return err
}
func (t *UserInternalAPITrace) PerformPushRulesPut(ctx context.Context, req *PerformPushRulesPutRequest, res *struct{}) error {
err := t.Impl.PerformPushRulesPut(ctx, req, res)
util.GetLogger(ctx).Infof("PerformPushRulesPut req=%+v res=%+v", js(req), js(res))
return err
}
func (t *UserInternalAPITrace) QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) {
t.Impl.QueryKeyBackup(ctx, req, res)
util.GetLogger(ctx).Infof("QueryKeyBackup req=%+v res=%+v", js(req), js(res))
@ -118,6 +133,21 @@ func (t *UserInternalAPITrace) QueryOpenIDToken(ctx context.Context, req *QueryO
util.GetLogger(ctx).Infof("QueryOpenIDToken req=%+v res=%+v", js(req), js(res))
return err
}
func (t *UserInternalAPITrace) QueryPushers(ctx context.Context, req *QueryPushersRequest, res *QueryPushersResponse) error {
err := t.Impl.QueryPushers(ctx, req, res)
util.GetLogger(ctx).Infof("QueryPushers req=%+v res=%+v", js(req), js(res))
return err
}
func (t *UserInternalAPITrace) QueryPushRules(ctx context.Context, req *QueryPushRulesRequest, res *QueryPushRulesResponse) error {
err := t.Impl.QueryPushRules(ctx, req, res)
util.GetLogger(ctx).Infof("QueryPushRules req=%+v res=%+v", js(req), js(res))
return err
}
func (t *UserInternalAPITrace) QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error {
err := t.Impl.QueryNotifications(ctx, req, res)
util.GetLogger(ctx).Infof("QueryNotifications req=%+v res=%+v", js(req), js(res))
return err
}
func js(thing interface{}) string {
b, err := json.Marshal(thing)

View file

@ -0,0 +1,136 @@
package consumers
import (
"context"
"encoding/json"
"github.com/matrix-org/dendrite/internal/pushgateway"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/types"
uapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/producers"
"github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/dendrite/userapi/util"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
)
type OutputReadUpdateConsumer struct {
ctx context.Context
cfg *config.UserAPI
jetstream nats.JetStreamContext
durable string
db storage.Database
pgClient pushgateway.Client
ServerName gomatrixserverlib.ServerName
topic string
userAPI uapi.UserInternalAPI
syncProducer *producers.SyncAPI
}
func NewOutputReadUpdateConsumer(
process *process.ProcessContext,
cfg *config.UserAPI,
js nats.JetStreamContext,
store storage.Database,
pgClient pushgateway.Client,
userAPI uapi.UserInternalAPI,
syncProducer *producers.SyncAPI,
) *OutputReadUpdateConsumer {
return &OutputReadUpdateConsumer{
ctx: process.Context(),
cfg: cfg,
jetstream: js,
db: store,
ServerName: cfg.Matrix.ServerName,
durable: cfg.Matrix.JetStream.Durable("UserAPISyncAPIReadUpdateConsumer"),
topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputReadUpdate),
pgClient: pgClient,
userAPI: userAPI,
syncProducer: syncProducer,
}
}
func (s *OutputReadUpdateConsumer) Start() error {
if err := jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
nats.DeliverAll(), nats.ManualAck(),
); err != nil {
return err
}
return nil
}
func (s *OutputReadUpdateConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
var read types.ReadUpdate
if err := json.Unmarshal(msg.Data, &read); err != nil {
log.WithError(err).Error("userapi clientapi consumer: message parse failure")
return true
}
if read.FullyRead == 0 && read.Read == 0 {
return true
}
userID := string(msg.Header.Get(jetstream.UserID))
roomID := string(msg.Header.Get(jetstream.RoomID))
localpart, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
log.WithError(err).Error("userapi clientapi consumer: SplitID failure")
return true
}
if domain != s.ServerName {
log.Error("userapi clientapi consumer: not a local user")
return true
}
log := log.WithFields(log.Fields{
"room_id": roomID,
"user_id": userID,
})
log.Tracef("Received read update from sync API: %#v", read)
if read.Read > 0 {
updated, err := s.db.SetNotificationsRead(ctx, localpart, roomID, int64(read.Read), true)
if err != nil {
log.WithError(err).Error("userapi EDU consumer")
return false
}
if updated {
if err = s.syncProducer.GetAndSendNotificationData(ctx, userID, roomID); err != nil {
log.WithError(err).Error("userapi EDU consumer: GetAndSendNotificationData failed")
return false
}
if err = util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, s.db); err != nil {
log.WithError(err).Error("userapi EDU consumer: NotifyUserCounts failed")
return false
}
}
}
if read.FullyRead > 0 {
deleted, err := s.db.DeleteNotificationsUpTo(ctx, localpart, roomID, int64(read.FullyRead))
if err != nil {
log.WithError(err).Errorf("userapi clientapi consumer: DeleteNotificationsUpTo failed")
return false
}
if deleted {
if err := util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, s.db); err != nil {
log.WithError(err).Error("userapi clientapi consumer: NotifyUserCounts failed")
return false
}
if err := s.syncProducer.GetAndSendNotificationData(ctx, userID, read.RoomID); err != nil {
log.WithError(err).Errorf("userapi clientapi consumer: GetAndSendNotificationData failed")
return false
}
}
}
return true
}

View file

@ -0,0 +1,588 @@
package consumers
import (
"context"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/internal/pushgateway"
"github.com/matrix-org/dendrite/internal/pushrules"
rsapi "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/producers"
"github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/dendrite/userapi/util"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
)
type OutputStreamEventConsumer struct {
ctx context.Context
cfg *config.UserAPI
userAPI api.UserInternalAPI
rsAPI rsapi.RoomserverInternalAPI
jetstream nats.JetStreamContext
durable string
db storage.Database
topic string
pgClient pushgateway.Client
syncProducer *producers.SyncAPI
}
func NewOutputStreamEventConsumer(
process *process.ProcessContext,
cfg *config.UserAPI,
js nats.JetStreamContext,
store storage.Database,
pgClient pushgateway.Client,
userAPI api.UserInternalAPI,
rsAPI rsapi.RoomserverInternalAPI,
syncProducer *producers.SyncAPI,
) *OutputStreamEventConsumer {
return &OutputStreamEventConsumer{
ctx: process.Context(),
cfg: cfg,
jetstream: js,
db: store,
durable: cfg.Matrix.JetStream.Durable("UserAPISyncAPIStreamEventConsumer"),
topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputStreamEvent),
pgClient: pgClient,
userAPI: userAPI,
rsAPI: rsAPI,
syncProducer: syncProducer,
}
}
func (s *OutputStreamEventConsumer) Start() error {
if err := jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
nats.DeliverAll(), nats.ManualAck(),
); err != nil {
return err
}
return nil
}
func (s *OutputStreamEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
var output types.StreamedEvent
output.Event = &gomatrixserverlib.HeaderedEvent{}
if err := json.Unmarshal(msg.Data, &output); err != nil {
log.WithError(err).Errorf("userapi consumer: message parse failure")
return true
}
if output.Event.Event == nil {
log.Errorf("userapi consumer: expected event")
return true
}
log.WithFields(log.Fields{
"event_id": output.Event.EventID(),
"event_type": output.Event.Type(),
"stream_pos": output.StreamPosition,
}).Tracef("Received message from sync API: %#v", output)
if err := s.processMessage(ctx, output.Event, int64(output.StreamPosition)); err != nil {
log.WithFields(log.Fields{
"event_id": output.Event.EventID(),
}).WithError(err).Errorf("userapi consumer: process room event failure")
}
return true
}
func (s *OutputStreamEventConsumer) processMessage(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, pos int64) error {
members, roomSize, err := s.localRoomMembers(ctx, event.RoomID())
if err != nil {
return fmt.Errorf("s.localRoomMembers: %w", err)
}
if event.Type() == gomatrixserverlib.MRoomMember {
cevent := gomatrixserverlib.HeaderedToClientEvent(event, gomatrixserverlib.FormatAll)
var member *localMembership
member, err = newLocalMembership(&cevent)
if err != nil {
return fmt.Errorf("newLocalMembership: %w", err)
}
if member.Membership == gomatrixserverlib.Invite && member.Domain == s.cfg.Matrix.ServerName {
// localRoomMembers only adds joined members. An invite
// should also be pushed to the target user.
members = append(members, member)
}
}
// TODO: run in parallel with localRoomMembers.
roomName, err := s.roomName(ctx, event)
if err != nil {
return fmt.Errorf("s.roomName: %w", err)
}
log.WithFields(log.Fields{
"event_id": event.EventID(),
"room_id": event.RoomID(),
"num_members": len(members),
"room_size": roomSize,
}).Tracef("Notifying members")
// Notification.UserIsTarget is a per-member field, so we
// cannot group all users in a single request.
//
// TODO: does it have to be set? It's not required, and
// removing it means we can send all notifications to
// e.g. Element's Push gateway in one go.
for _, mem := range members {
if err := s.notifyLocal(ctx, event, pos, mem, roomSize, roomName); err != nil {
log.WithFields(log.Fields{
"localpart": mem.Localpart,
}).WithError(err).Debugf("Unable to push to local user")
continue
}
}
return nil
}
type localMembership struct {
gomatrixserverlib.MemberContent
UserID string
Localpart string
Domain gomatrixserverlib.ServerName
}
func newLocalMembership(event *gomatrixserverlib.ClientEvent) (*localMembership, error) {
if event.StateKey == nil {
return nil, fmt.Errorf("missing state_key")
}
var member localMembership
if err := json.Unmarshal(event.Content, &member.MemberContent); err != nil {
return nil, err
}
localpart, domain, err := gomatrixserverlib.SplitID('@', *event.StateKey)
if err != nil {
return nil, err
}
member.UserID = *event.StateKey
member.Localpart = localpart
member.Domain = domain
return &member, nil
}
// localRoomMembers fetches the current local members of a room, and
// the total number of members.
func (s *OutputStreamEventConsumer) localRoomMembers(ctx context.Context, roomID string) ([]*localMembership, int, error) {
req := &rsapi.QueryMembershipsForRoomRequest{
RoomID: roomID,
JoinedOnly: true,
}
var res rsapi.QueryMembershipsForRoomResponse
// XXX: This could potentially race if the state for the event is not known yet
// e.g. the event came over federation but we do not have the full state persisted.
if err := s.rsAPI.QueryMembershipsForRoom(ctx, req, &res); err != nil {
return nil, 0, err
}
var members []*localMembership
var ntotal int
for _, event := range res.JoinEvents {
member, err := newLocalMembership(&event)
if err != nil {
log.WithError(err).Errorf("Parsing MemberContent")
continue
}
if member.Membership != gomatrixserverlib.Join {
continue
}
if member.Domain != s.cfg.Matrix.ServerName {
continue
}
ntotal++
members = append(members, member)
}
return members, ntotal, nil
}
// roomName returns the name in the event (if type==m.room.name), or
// looks it up in roomserver. If there is no name,
// m.room.canonical_alias is consulted. Returns an empty string if the
// room has no name.
func (s *OutputStreamEventConsumer) roomName(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) (string, error) {
if event.Type() == gomatrixserverlib.MRoomName {
name, err := unmarshalRoomName(event)
if err != nil {
return "", err
}
if name != "" {
return name, nil
}
}
req := &rsapi.QueryCurrentStateRequest{
RoomID: event.RoomID(),
StateTuples: []gomatrixserverlib.StateKeyTuple{roomNameTuple, canonicalAliasTuple},
}
var res rsapi.QueryCurrentStateResponse
if err := s.rsAPI.QueryCurrentState(ctx, req, &res); err != nil {
return "", nil
}
if eventS := res.StateEvents[roomNameTuple]; eventS != nil {
return unmarshalRoomName(eventS)
}
if event.Type() == gomatrixserverlib.MRoomCanonicalAlias {
alias, err := unmarshalCanonicalAlias(event)
if err != nil {
return "", err
}
if alias != "" {
return alias, nil
}
}
if event = res.StateEvents[canonicalAliasTuple]; event != nil {
return unmarshalCanonicalAlias(event)
}
return "", nil
}
var (
canonicalAliasTuple = gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomCanonicalAlias}
roomNameTuple = gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomName}
)
func unmarshalRoomName(event *gomatrixserverlib.HeaderedEvent) (string, error) {
var nc eventutil.NameContent
if err := json.Unmarshal(event.Content(), &nc); err != nil {
return "", fmt.Errorf("unmarshaling NameContent: %w", err)
}
return nc.Name, nil
}
func unmarshalCanonicalAlias(event *gomatrixserverlib.HeaderedEvent) (string, error) {
var cac eventutil.CanonicalAliasContent
if err := json.Unmarshal(event.Content(), &cac); err != nil {
return "", fmt.Errorf("unmarshaling CanonicalAliasContent: %w", err)
}
return cac.Alias, nil
}
// notifyLocal finds the right push actions for a local user, given an event.
func (s *OutputStreamEventConsumer) notifyLocal(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, pos int64, mem *localMembership, roomSize int, roomName string) error {
actions, err := s.evaluatePushRules(ctx, event, mem, roomSize)
if err != nil {
return err
}
a, tweaks, err := pushrules.ActionsToTweaks(actions)
if err != nil {
return err
}
// TODO: support coalescing.
if a != pushrules.NotifyAction && a != pushrules.CoalesceAction {
log.WithFields(log.Fields{
"event_id": event.EventID(),
"room_id": event.RoomID(),
"localpart": mem.Localpart,
}).Tracef("Push rule evaluation rejected the event")
return nil
}
devicesByURLAndFormat, profileTag, err := s.localPushDevices(ctx, mem.Localpart, tweaks)
if err != nil {
return err
}
n := &api.Notification{
Actions: actions,
// UNSPEC: the spec doesn't say this is a ClientEvent, but the
// fields seem to match. room_id should be missing, which
// matches the behaviour of FormatSync.
Event: gomatrixserverlib.HeaderedToClientEvent(event, gomatrixserverlib.FormatSync),
// TODO: this is per-device, but it's not part of the primary
// key. So inserting one notification per profile tag doesn't
// make sense. What is this supposed to be? Sytests require it
// to "work", but they only use a single device.
ProfileTag: profileTag,
RoomID: event.RoomID(),
TS: gomatrixserverlib.AsTimestamp(time.Now()),
}
if err = s.db.InsertNotification(ctx, mem.Localpart, event.EventID(), pos, tweaks, n); err != nil {
return err
}
if err = s.syncProducer.GetAndSendNotificationData(ctx, mem.UserID, event.RoomID()); err != nil {
return err
}
// We do this after InsertNotification. Thus, this should always return >=1.
userNumUnreadNotifs, err := s.db.GetNotificationCount(ctx, mem.Localpart, tables.AllNotifications)
if err != nil {
return err
}
log.WithFields(log.Fields{
"event_id": event.EventID(),
"room_id": event.RoomID(),
"localpart": mem.Localpart,
"num_urls": len(devicesByURLAndFormat),
"num_unread": userNumUnreadNotifs,
}).Tracef("Notifying single member")
// Push gateways are out of our control, and we cannot risk
// looking up the server on a misbehaving push gateway. Each user
// receives a goroutine now that all internal API calls have been
// made.
//
// TODO: think about bounding this to one per user, and what
// ordering guarantees we must provide.
go func() {
// This background processing cannot be tied to a request.
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
var rejected []*pushgateway.Device
for url, fmts := range devicesByURLAndFormat {
for format, devices := range fmts {
// TODO: support "email".
if !strings.HasPrefix(url, "http") {
continue
}
// UNSPEC: the specification suggests there can be
// more than one device per request. There is at least
// one Sytest that expects one HTTP request per
// device, rather than per URL. For now, we must
// notify each one separately.
for _, dev := range devices {
rej, err := s.notifyHTTP(ctx, event, url, format, []*pushgateway.Device{dev}, mem.Localpart, roomName, int(userNumUnreadNotifs))
if err != nil {
log.WithFields(log.Fields{
"event_id": event.EventID(),
"localpart": mem.Localpart,
}).WithError(err).Errorf("Unable to notify HTTP pusher")
continue
}
rejected = append(rejected, rej...)
}
}
}
if len(rejected) > 0 {
s.deleteRejectedPushers(ctx, rejected, mem.Localpart)
}
}()
return nil
}
// evaluatePushRules fetches and evaluates the push rules of a local
// user. Returns actions (including dont_notify).
func (s *OutputStreamEventConsumer) evaluatePushRules(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) {
if event.Sender() == mem.UserID {
// SPEC: Homeservers MUST NOT notify the Push Gateway for
// events that the user has sent themselves.
return nil, nil
}
var res api.QueryPushRulesResponse
if err := s.userAPI.QueryPushRules(ctx, &api.QueryPushRulesRequest{UserID: mem.UserID}, &res); err != nil {
return nil, err
}
ec := &ruleSetEvalContext{
ctx: ctx,
rsAPI: s.rsAPI,
mem: mem,
roomID: event.RoomID(),
roomSize: roomSize,
}
eval := pushrules.NewRuleSetEvaluator(ec, &res.RuleSets.Global)
rule, err := eval.MatchEvent(event.Event)
if err != nil {
return nil, err
}
if rule == nil {
// SPEC: If no rules match an event, the homeserver MUST NOT
// notify the Push Gateway for that event.
return nil, err
}
log.WithFields(log.Fields{
"event_id": event.EventID(),
"room_id": event.RoomID(),
"localpart": mem.Localpart,
"rule_id": rule.RuleID,
}).Tracef("Matched a push rule")
return rule.Actions, nil
}
type ruleSetEvalContext struct {
ctx context.Context
rsAPI rsapi.RoomserverInternalAPI
mem *localMembership
roomID string
roomSize int
}
func (rse *ruleSetEvalContext) UserDisplayName() string { return rse.mem.DisplayName }
func (rse *ruleSetEvalContext) RoomMemberCount() (int, error) { return rse.roomSize, nil }
func (rse *ruleSetEvalContext) HasPowerLevel(userID, levelKey string) (bool, error) {
req := &rsapi.QueryLatestEventsAndStateRequest{
RoomID: rse.roomID,
StateToFetch: []gomatrixserverlib.StateKeyTuple{
{EventType: gomatrixserverlib.MRoomPowerLevels},
},
}
var res rsapi.QueryLatestEventsAndStateResponse
if err := rse.rsAPI.QueryLatestEventsAndState(rse.ctx, req, &res); err != nil {
return false, err
}
for _, ev := range res.StateEvents {
if ev.Type() != gomatrixserverlib.MRoomPowerLevels {
continue
}
plc, err := gomatrixserverlib.NewPowerLevelContentFromEvent(ev.Event)
if err != nil {
return false, err
}
return plc.UserLevel(userID) >= plc.NotificationLevel(levelKey), nil
}
return true, nil
}
// localPushDevices pushes to the configured devices of a local
// user. The map keys are [url][format].
func (s *OutputStreamEventConsumer) localPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) {
pusherDevices, err := util.GetPushDevices(ctx, localpart, tweaks, s.db)
if err != nil {
return nil, "", err
}
var profileTag string
devicesByURL := make(map[string]map[string][]*pushgateway.Device, len(pusherDevices))
for _, pusherDevice := range pusherDevices {
if profileTag == "" {
profileTag = pusherDevice.Pusher.ProfileTag
}
url := pusherDevice.URL
if devicesByURL[url] == nil {
devicesByURL[url] = make(map[string][]*pushgateway.Device, 2)
}
devicesByURL[url][pusherDevice.Format] = append(devicesByURL[url][pusherDevice.Format], &pusherDevice.Device)
}
return devicesByURL, profileTag, nil
}
// notifyHTTP performs a notificatation to a Push Gateway.
func (s *OutputStreamEventConsumer) notifyHTTP(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, url, format string, devices []*pushgateway.Device, localpart, roomName string, userNumUnreadNotifs int) ([]*pushgateway.Device, error) {
logger := log.WithFields(log.Fields{
"event_id": event.EventID(),
"url": url,
"localpart": localpart,
"num_devices": len(devices),
})
var req pushgateway.NotifyRequest
switch format {
case "event_id_only":
req = pushgateway.NotifyRequest{
Notification: pushgateway.Notification{
Counts: &pushgateway.Counts{},
Devices: devices,
EventID: event.EventID(),
RoomID: event.RoomID(),
},
}
default:
req = pushgateway.NotifyRequest{
Notification: pushgateway.Notification{
Content: event.Content(),
Counts: &pushgateway.Counts{
Unread: userNumUnreadNotifs,
},
Devices: devices,
EventID: event.EventID(),
ID: event.EventID(),
RoomID: event.RoomID(),
RoomName: roomName,
Sender: event.Sender(),
Type: event.Type(),
},
}
if mem, err := event.Membership(); err == nil {
req.Notification.Membership = mem
}
if event.StateKey() != nil && *event.StateKey() == fmt.Sprintf("@%s:%s", localpart, s.cfg.Matrix.ServerName) {
req.Notification.UserIsTarget = true
}
}
logger.Debugf("Notifying push gateway %s", url)
var res pushgateway.NotifyResponse
if err := s.pgClient.Notify(ctx, url, &req, &res); err != nil {
logger.WithError(err).Errorf("Failed to notify push gateway %s", url)
return nil, err
}
logger.WithField("num_rejected", len(res.Rejected)).Tracef("Push gateway result")
if len(res.Rejected) == 0 {
return nil, nil
}
devMap := make(map[string]*pushgateway.Device, len(devices))
for _, d := range devices {
devMap[d.PushKey] = d
}
rejected := make([]*pushgateway.Device, 0, len(res.Rejected))
for _, pushKey := range res.Rejected {
d := devMap[pushKey]
if d != nil {
rejected = append(rejected, d)
}
}
return rejected, nil
}
// deleteRejectedPushers deletes the pushers associated with the given devices.
func (s *OutputStreamEventConsumer) deleteRejectedPushers(ctx context.Context, devices []*pushgateway.Device, localpart string) {
log.WithFields(log.Fields{
"localpart": localpart,
"app_id0": devices[0].AppID,
"num_devices": len(devices),
}).Warnf("Deleting pushers rejected by the HTTP push gateway")
for _, d := range devices {
if err := s.db.RemovePusher(ctx, d.AppID, d.PushKey, localpart); err != nil {
log.WithFields(log.Fields{
"localpart": localpart,
}).WithError(err).Errorf("Unable to delete rejected pusher")
}
}
}

View file

@ -20,6 +20,8 @@ import (
"encoding/json"
"errors"
"fmt"
"strconv"
"time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
@ -27,16 +29,22 @@ import (
"github.com/matrix-org/dendrite/appservice/types"
"github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/internal/pushrules"
"github.com/matrix-org/dendrite/internal/sqlutil"
keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/producers"
"github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/dendrite/userapi/storage/tables"
)
type UserInternalAPI struct {
DB storage.Database
ServerName gomatrixserverlib.ServerName
DB storage.Database
SyncProducer *producers.SyncAPI
DisableTLSValidation bool
ServerName gomatrixserverlib.ServerName
// AppServices is the list of all registered AS
AppServices []config.ApplicationService
KeyAPI keyapi.KeyInternalAPI
@ -595,3 +603,162 @@ func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyB
}
res.Keys = result
}
func (a *UserInternalAPI) QueryNotifications(ctx context.Context, req *api.QueryNotificationsRequest, res *api.QueryNotificationsResponse) error {
if req.Limit == 0 || req.Limit > 1000 {
req.Limit = 1000
}
var fromID int64
var err error
if req.From != "" {
fromID, err = strconv.ParseInt(req.From, 10, 64)
if err != nil {
return fmt.Errorf("QueryNotifications: parsing 'from': %w", err)
}
}
var filter tables.NotificationFilter = tables.AllNotifications
if req.Only == "highlight" {
filter = tables.HighlightNotifications
}
notifs, lastID, err := a.DB.GetNotifications(ctx, req.Localpart, fromID, req.Limit, filter)
if err != nil {
return err
}
if notifs == nil {
// This ensures empty is JSON-encoded as [] instead of null.
notifs = []*api.Notification{}
}
res.Notifications = notifs
if lastID >= 0 {
res.NextToken = strconv.FormatInt(lastID+1, 10)
}
return nil
}
func (a *UserInternalAPI) PerformPusherSet(ctx context.Context, req *api.PerformPusherSetRequest, res *struct{}) error {
util.GetLogger(ctx).WithFields(logrus.Fields{
"localpart": req.Localpart,
"pushkey": req.Pusher.PushKey,
"display_name": req.Pusher.AppDisplayName,
}).Info("PerformPusherCreation")
if !req.Append {
err := a.DB.RemovePushers(ctx, req.Pusher.AppID, req.Pusher.PushKey)
if err != nil {
return err
}
}
if req.Pusher.Kind == "" {
return a.DB.RemovePusher(ctx, req.Pusher.AppID, req.Pusher.PushKey, req.Localpart)
}
if req.Pusher.PushKeyTS == 0 {
req.Pusher.PushKeyTS = gomatrixserverlib.AsTimestamp(time.Now())
}
return a.DB.UpsertPusher(ctx, req.Pusher, req.Localpart)
}
func (a *UserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.PerformPusherDeletionRequest, res *struct{}) error {
pushers, err := a.DB.GetPushers(ctx, req.Localpart)
if err != nil {
return err
}
for i := range pushers {
logrus.Warnf("pusher session: %d, req session: %d", pushers[i].SessionID, req.SessionID)
if pushers[i].SessionID != req.SessionID {
err := a.DB.RemovePusher(ctx, pushers[i].AppID, pushers[i].PushKey, req.Localpart)
if err != nil {
return err
}
}
}
return nil
}
func (a *UserInternalAPI) QueryPushers(ctx context.Context, req *api.QueryPushersRequest, res *api.QueryPushersResponse) error {
var err error
res.Pushers, err = a.DB.GetPushers(ctx, req.Localpart)
return err
}
func (a *UserInternalAPI) PerformPushRulesPut(
ctx context.Context,
req *api.PerformPushRulesPutRequest,
_ *struct{},
) error {
bs, err := json.Marshal(&req.RuleSets)
if err != nil {
return err
}
userReq := api.InputAccountDataRequest{
UserID: req.UserID,
DataType: pushRulesAccountDataType,
AccountData: json.RawMessage(bs),
}
var userRes api.InputAccountDataResponse // empty
if err := a.InputAccountData(ctx, &userReq, &userRes); err != nil {
return err
}
if err := a.SyncProducer.SendAccountData(req.UserID, "" /* roomID */, pushRulesAccountDataType); err != nil {
util.GetLogger(ctx).WithError(err).Errorf("syncProducer.SendData failed")
}
return nil
}
func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPushRulesRequest, res *api.QueryPushRulesResponse) error {
userReq := api.QueryAccountDataRequest{
UserID: req.UserID,
DataType: pushRulesAccountDataType,
}
var userRes api.QueryAccountDataResponse
if err := a.QueryAccountData(ctx, &userReq, &userRes); err != nil {
return err
}
bs, ok := userRes.GlobalAccountData[pushRulesAccountDataType]
if ok {
// Legacy Dendrite users will have completely empty push rules, so we should
// detect that situation and set some defaults.
var rules struct {
G struct {
Content []json.RawMessage `json:"content"`
Override []json.RawMessage `json:"override"`
Room []json.RawMessage `json:"room"`
Sender []json.RawMessage `json:"sender"`
Underride []json.RawMessage `json:"underride"`
} `json:"global"`
}
if err := json.Unmarshal([]byte(bs), &rules); err == nil {
count := len(rules.G.Content) + len(rules.G.Override) +
len(rules.G.Room) + len(rules.G.Sender) + len(rules.G.Underride)
ok = count > 0
}
}
if !ok {
// If we didn't find any default push rules then we should just generate some
// fresh ones.
localpart, _, err := gomatrixserverlib.SplitID('@', req.UserID)
if err != nil {
return fmt.Errorf("failed to split user ID %q for push rules", req.UserID)
}
pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, a.ServerName)
prbs, err := json.Marshal(pushRuleSets)
if err != nil {
return fmt.Errorf("failed to marshal default push rules: %w", err)
}
if err := a.DB.SaveAccountData(ctx, localpart, "", pushRulesAccountDataType, json.RawMessage(prbs)); err != nil {
return fmt.Errorf("failed to save default push rules: %w", err)
}
res.RuleSets = pushRuleSets
return nil
}
var data pushrules.AccountRuleSets
if err := json.Unmarshal([]byte(bs), &data); err != nil {
util.GetLogger(ctx).WithError(err).Error("json.Unmarshal of push rules failed")
return err
}
res.RuleSets = &data
return nil
}
const pushRulesAccountDataType = "m.push_rules"

View file

@ -37,6 +37,9 @@ const (
PerformAccountDeactivationPath = "/userapi/performAccountDeactivation"
PerformOpenIDTokenCreationPath = "/userapi/performOpenIDTokenCreation"
PerformKeyBackupPath = "/userapi/performKeyBackup"
PerformPusherSetPath = "/pushserver/performPusherSet"
PerformPusherDeletionPath = "/pushserver/performPusherDeletion"
PerformPushRulesPutPath = "/pushserver/performPushRulesPut"
QueryKeyBackupPath = "/userapi/queryKeyBackup"
QueryProfilePath = "/userapi/queryProfile"
@ -46,6 +49,9 @@ const (
QueryDeviceInfosPath = "/userapi/queryDeviceInfos"
QuerySearchProfilesPath = "/userapi/querySearchProfiles"
QueryOpenIDTokenPath = "/userapi/queryOpenIDToken"
QueryPushersPath = "/pushserver/queryPushers"
QueryPushRulesPath = "/pushserver/queryPushRules"
QueryNotificationsPath = "/pushserver/queryNotifications"
)
// NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API.
@ -249,3 +255,58 @@ func (h *httpUserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.Query
res.Error = err.Error()
}
}
func (h *httpUserInternalAPI) QueryNotifications(ctx context.Context, req *api.QueryNotificationsRequest, res *api.QueryNotificationsResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryNotifications")
defer span.Finish()
return httputil.PostJSON(ctx, span, h.httpClient, h.apiURL+QueryNotificationsPath, req, res)
}
func (h *httpUserInternalAPI) PerformPusherSet(
ctx context.Context,
request *api.PerformPusherSetRequest,
response *struct{},
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPusherSet")
defer span.Finish()
apiURL := h.apiURL + PerformPusherSetPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
}
func (h *httpUserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.PerformPusherDeletionRequest, res *struct{}) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPusherDeletion")
defer span.Finish()
apiURL := h.apiURL + PerformPusherDeletionPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}
func (h *httpUserInternalAPI) QueryPushers(ctx context.Context, req *api.QueryPushersRequest, res *api.QueryPushersResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryPushers")
defer span.Finish()
apiURL := h.apiURL + QueryPushersPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}
func (h *httpUserInternalAPI) PerformPushRulesPut(
ctx context.Context,
request *api.PerformPushRulesPutRequest,
response *struct{},
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPushRulesPut")
defer span.Finish()
apiURL := h.apiURL + PerformPushRulesPutPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
}
func (h *httpUserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPushRulesRequest, res *api.QueryPushRulesResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryPushRules")
defer span.Finish()
apiURL := h.apiURL + QueryPushRulesPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}

View file

@ -265,4 +265,86 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(QueryNotificationsPath,
httputil.MakeInternalAPI("queryNotifications", func(req *http.Request) util.JSONResponse {
var request api.QueryNotificationsRequest
var response api.QueryNotificationsResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.QueryNotifications(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(PerformPusherSetPath,
httputil.MakeInternalAPI("performPusherSet", func(req *http.Request) util.JSONResponse {
request := api.PerformPusherSetRequest{}
response := struct{}{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformPusherSet(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(PerformPusherDeletionPath,
httputil.MakeInternalAPI("performPusherDeletion", func(req *http.Request) util.JSONResponse {
request := api.PerformPusherDeletionRequest{}
response := struct{}{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformPusherDeletion(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(QueryPushersPath,
httputil.MakeInternalAPI("queryPushers", func(req *http.Request) util.JSONResponse {
request := api.QueryPushersRequest{}
response := api.QueryPushersResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.QueryPushers(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(PerformPushRulesPutPath,
httputil.MakeInternalAPI("performPushRulesPut", func(req *http.Request) util.JSONResponse {
request := api.PerformPushRulesPutRequest{}
response := struct{}{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformPushRulesPut(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(QueryPushRulesPath,
httputil.MakeInternalAPI("queryPushRules", func(req *http.Request) util.JSONResponse {
request := api.QueryPushRulesRequest{}
response := api.QueryPushRulesResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.QueryPushRules(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
}

View file

@ -0,0 +1,104 @@
package producers
import (
"context"
"encoding/json"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
)
type JetStreamPublisher interface {
PublishMsg(*nats.Msg, ...nats.PubOpt) (*nats.PubAck, error)
}
// SyncAPI produces messages for the Sync API server to consume.
type SyncAPI struct {
db storage.Database
producer JetStreamPublisher
clientDataTopic string
notificationDataTopic string
}
func NewSyncAPI(db storage.Database, js JetStreamPublisher, clientDataTopic string, notificationDataTopic string) *SyncAPI {
return &SyncAPI{
db: db,
producer: js,
clientDataTopic: clientDataTopic,
notificationDataTopic: notificationDataTopic,
}
}
// SendAccountData sends account data to the Sync API server.
func (p *SyncAPI) SendAccountData(userID string, roomID string, dataType string) error {
m := &nats.Msg{
Subject: p.clientDataTopic,
Header: nats.Header{},
}
m.Header.Set(jetstream.UserID, userID)
var err error
m.Data, err = json.Marshal(eventutil.AccountData{
RoomID: roomID,
Type: dataType,
})
if err != nil {
return err
}
log.WithFields(log.Fields{
"user_id": userID,
"room_id": roomID,
"data_type": dataType,
}).Tracef("Producing to topic '%s'", p.clientDataTopic)
_, err = p.producer.PublishMsg(m)
return err
}
// GetAndSendNotificationData reads the database and sends data about unread
// notifications to the Sync API server.
func (p *SyncAPI) GetAndSendNotificationData(ctx context.Context, userID, roomID string) error {
localpart, _, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
return err
}
ntotal, nhighlight, err := p.db.GetRoomNotificationCounts(ctx, localpart, roomID)
if err != nil {
return err
}
return p.sendNotificationData(userID, &eventutil.NotificationData{
RoomID: roomID,
UnreadHighlightCount: int(nhighlight),
UnreadNotificationCount: int(ntotal),
})
}
// sendNotificationData sends data about unread notifications to the Sync API server.
func (p *SyncAPI) sendNotificationData(userID string, data *eventutil.NotificationData) error {
m := &nats.Msg{
Subject: p.notificationDataTopic,
Header: nats.Header{},
}
m.Header.Set(jetstream.UserID, userID)
var err error
m.Data, err = json.Marshal(data)
if err != nil {
return err
}
log.WithFields(log.Fields{
"user_id": userID,
"room_id": data.RoomID,
}).Tracef("Producing to topic '%s'", p.clientDataTopic)
_, err = p.producer.PublishMsg(m)
return err
}

View file

@ -21,6 +21,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
)
type Database interface {
@ -89,6 +90,18 @@ type Database interface {
// GetLoginTokenDataByToken returns the data associated with the given token.
// May return sql.ErrNoRows.
GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error)
InsertNotification(ctx context.Context, localpart, eventID string, pos int64, tweaks map[string]interface{}, n *api.Notification) error
DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos int64) (affected bool, err error)
SetNotificationsRead(ctx context.Context, localpart, roomID string, pos int64, b bool) (affected bool, err error)
GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error)
GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error)
GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error)
UpsertPusher(ctx context.Context, p api.Pusher, localpart string) error
GetPushers(ctx context.Context, localpart string) ([]api.Pusher, error)
RemovePusher(ctx context.Context, appid, pushkey, localpart string) error
RemovePushers(ctx context.Context, appid, pushkey string) error
}
// Err3PIDInUse is the error returned when trying to save an association involving

View file

@ -0,0 +1,219 @@
// Copyright 2021 Dan Peleg <dan@globekeeper.com>
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"encoding/json"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus"
)
type notificationsStatements struct {
insertStmt *sql.Stmt
deleteUpToStmt *sql.Stmt
updateReadStmt *sql.Stmt
selectStmt *sql.Stmt
selectCountStmt *sql.Stmt
selectRoomCountsStmt *sql.Stmt
}
const notificationSchema = `
CREATE TABLE IF NOT EXISTS userapi_notifications (
id BIGSERIAL PRIMARY KEY,
localpart TEXT NOT NULL,
room_id TEXT NOT NULL,
event_id TEXT NOT NULL,
stream_pos BIGINT NOT NULL,
ts_ms BIGINT NOT NULL,
highlight BOOLEAN NOT NULL,
notification_json TEXT NOT NULL,
read BOOLEAN NOT NULL DEFAULT FALSE
);
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, room_id, event_id);
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, room_id, id);
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, id);
`
const insertNotificationSQL = "" +
"INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7)"
const deleteNotificationsUpToSQL = "" +
"DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND stream_pos <= $3"
const updateNotificationReadSQL = "" +
"UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND stream_pos <= $4 AND read <> $1"
const selectNotificationSQL = "" +
"SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" +
"(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" +
") AND NOT read ORDER BY localpart, id LIMIT $4"
const selectNotificationCountSQL = "" +
"SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND (" +
"(($2 & 1) <> 0 AND highlight) OR (($2 & 2) <> 0 AND NOT highlight)" +
") AND NOT read"
const selectRoomNotificationCountsSQL = "" +
"SELECT COUNT(*), COUNT(*) FILTER (WHERE highlight) FROM userapi_notifications " +
"WHERE localpart = $1 AND room_id = $2 AND NOT read"
func NewPostgresNotificationTable(db *sql.DB) (tables.NotificationTable, error) {
s := &notificationsStatements{}
_, err := db.Exec(notificationSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.insertStmt, insertNotificationSQL},
{&s.deleteUpToStmt, deleteNotificationsUpToSQL},
{&s.updateReadStmt, updateNotificationReadSQL},
{&s.selectStmt, selectNotificationSQL},
{&s.selectCountStmt, selectNotificationCountSQL},
{&s.selectRoomCountsStmt, selectRoomNotificationCountsSQL},
}.Prepare(db)
}
// Insert inserts a notification into the database.
func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error {
roomID, tsMS := n.RoomID, n.TS
nn := *n
// Clears out fields that have their own columns to (1) shrink the
// data and (2) avoid difficult-to-debug inconsistency bugs.
nn.RoomID = ""
nn.TS, nn.Read = 0, false
bs, err := json.Marshal(nn)
if err != nil {
return err
}
_, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, pos, tsMS, highlight, string(bs))
return err
}
// DeleteUpTo deletes all previous notifications, up to and including the event.
func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error) {
res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos)
if err != nil {
return false, err
}
nrows, err := res.RowsAffected()
if err != nil {
return true, err
}
log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("DeleteUpTo: %d rows affected", nrows)
return nrows > 0, nil
}
// UpdateRead updates the "read" value for an event.
func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error) {
res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos)
if err != nil {
return false, err
}
nrows, err := res.RowsAffected()
if err != nil {
return true, err
}
log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("UpdateRead: %d rows affected", nrows)
return nrows > 0, nil
}
func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, fromID, uint32(filter), limit)
if err != nil {
return nil, 0, err
}
defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
var maxID int64 = -1
var notifs []*api.Notification
for rows.Next() {
var id int64
var roomID string
var ts gomatrixserverlib.Timestamp
var read bool
var jsonStr string
err = rows.Scan(
&id,
&roomID,
&ts,
&read,
&jsonStr)
if err != nil {
return nil, 0, err
}
var n api.Notification
err := json.Unmarshal([]byte(jsonStr), &n)
if err != nil {
return nil, 0, err
}
n.RoomID = roomID
n.TS = ts
n.Read = read
notifs = append(notifs, &n)
if maxID < id {
maxID = id
}
}
return notifs, maxID, rows.Err()
}
func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (int64, error) {
rows, err := sqlutil.TxStmt(txn, s.selectCountStmt).QueryContext(ctx, localpart, uint32(filter))
if err != nil {
return 0, err
}
defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
if rows.Next() {
var count int64
if err := rows.Scan(&count); err != nil {
return 0, err
}
return count, nil
}
return 0, rows.Err()
}
func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) {
rows, err := sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryContext(ctx, localpart, roomID)
if err != nil {
return 0, 0, err
}
defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
if rows.Next() {
var total, highlight int64
if err := rows.Scan(&total, &highlight); err != nil {
return 0, 0, err
}
return total, highlight, nil
}
return 0, 0, rows.Err()
}

View file

@ -0,0 +1,157 @@
// Copyright 2021 Dan Peleg <dan@globekeeper.com>
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"encoding/json"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
)
// See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers
const pushersSchema = `
CREATE TABLE IF NOT EXISTS userapi_pushers (
id BIGSERIAL PRIMARY KEY,
-- The Matrix user ID localpart for this pusher
localpart TEXT NOT NULL,
session_id BIGINT DEFAULT NULL,
profile_tag TEXT,
kind TEXT NOT NULL,
app_id TEXT NOT NULL,
app_display_name TEXT NOT NULL,
device_display_name TEXT NOT NULL,
pushkey TEXT NOT NULL,
pushkey_ts_ms BIGINT NOT NULL DEFAULT 0,
lang TEXT NOT NULL,
data TEXT NOT NULL
);
-- For faster deleting by app_id, pushkey pair.
CREATE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey);
-- For faster retrieving by localpart.
CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart);
-- Pushkey must be unique for a given user and app.
CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart);
`
const insertPusherSQL = "" +
"INSERT INTO userapi_pushers (localpart, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" +
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" +
"ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11"
const selectPushersSQL = "" +
"SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1"
const deletePusherSQL = "" +
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3"
const deletePushersByAppIdAndPushKeySQL = "" +
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2"
func NewPostgresPusherTable(db *sql.DB) (tables.PusherTable, error) {
s := &pushersStatements{}
_, err := db.Exec(pushersSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.insertPusherStmt, insertPusherSQL},
{&s.selectPushersStmt, selectPushersSQL},
{&s.deletePusherStmt, deletePusherSQL},
{&s.deletePushersByAppIdAndPushKeyStmt, deletePushersByAppIdAndPushKeySQL},
}.Prepare(db)
}
type pushersStatements struct {
insertPusherStmt *sql.Stmt
selectPushersStmt *sql.Stmt
deletePusherStmt *sql.Stmt
deletePushersByAppIdAndPushKeyStmt *sql.Stmt
}
// insertPusher creates a new pusher.
// Returns an error if the user already has a pusher with the given pusher pushkey.
// Returns nil error success.
func (s *pushersStatements) InsertPusher(
ctx context.Context, txn *sql.Tx, session_id int64,
pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string,
) error {
_, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
logrus.Debugf("Created pusher %d", session_id)
return err
}
func (s *pushersStatements) SelectPushers(
ctx context.Context, txn *sql.Tx, localpart string,
) ([]api.Pusher, error) {
pushers := []api.Pusher{}
rows, err := sqlutil.TxStmt(txn, s.selectPushersStmt).QueryContext(ctx, localpart)
if err != nil {
return pushers, err
}
defer internal.CloseAndLogIfError(ctx, rows, "SelectPushers: rows.close() failed")
for rows.Next() {
var pusher api.Pusher
var data []byte
err = rows.Scan(
&pusher.SessionID,
&pusher.PushKey,
&pusher.PushKeyTS,
&pusher.Kind,
&pusher.AppID,
&pusher.AppDisplayName,
&pusher.DeviceDisplayName,
&pusher.ProfileTag,
&pusher.Language,
&data)
if err != nil {
return pushers, err
}
err := json.Unmarshal(data, &pusher.Data)
if err != nil {
return pushers, err
}
pushers = append(pushers, pusher)
}
logrus.Debugf("Database returned %d pushers", len(pushers))
return pushers, rows.Err()
}
// deletePusher removes a single pusher by pushkey and user localpart.
func (s *pushersStatements) DeletePusher(
ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string,
) error {
_, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart)
return err
}
func (s *pushersStatements) DeletePushers(
ctx context.Context, txn *sql.Tx, appid, pushkey string,
) error {
_, err := sqlutil.TxStmt(txn, s.deletePushersByAppIdAndPushKeyStmt).ExecContext(ctx, appid, pushkey)
return err
}

View file

@ -85,6 +85,14 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err != nil {
return nil, fmt.Errorf("NewPostgresThreePIDTable: %w", err)
}
pusherTable, err := NewPostgresPusherTable(db)
if err != nil {
return nil, fmt.Errorf("NewPostgresPusherTable: %w", err)
}
notificationsTable, err := NewPostgresNotificationTable(db)
if err != nil {
return nil, fmt.Errorf("NewPostgresNotificationTable: %w", err)
}
return &shared.Database{
AccountDatas: accountDataTable,
Accounts: accountsTable,
@ -95,6 +103,8 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
OpenIDTokens: openIDTable,
Profiles: profilesTable,
ThreePIDs: threePIDTable,
Pushers: pusherTable,
Notifications: notificationsTable,
ServerName: serverName,
DB: db,
Writer: sqlutil.NewDummyWriter(),

View file

@ -29,6 +29,7 @@ import (
"golang.org/x/crypto/bcrypt"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal/pushrules"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
@ -47,6 +48,8 @@ type Database struct {
KeyBackupVersions tables.KeyBackupVersionTable
Devices tables.DevicesTable
LoginTokens tables.LoginTokenTable
Notifications tables.NotificationTable
Pushers tables.PusherTable
LoginTokenLifetime time.Duration
ServerName gomatrixserverlib.ServerName
BcryptCost int
@ -160,15 +163,12 @@ func (d *Database) createAccount(
if err = d.Profiles.InsertProfile(ctx, txn, localpart); err != nil {
return nil, err
}
if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{
"global": {
"content": [],
"override": [],
"room": [],
"sender": [],
"underride": []
}
}`)); err != nil {
pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, d.ServerName)
prbs, err := json.Marshal(pushRuleSets)
if err != nil {
return nil, err
}
if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(prbs)); err != nil {
return nil, err
}
return account, nil
@ -670,3 +670,94 @@ func (d *Database) RemoveLoginToken(ctx context.Context, token string) error {
func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
return d.LoginTokens.SelectLoginToken(ctx, token)
}
func (d *Database) InsertNotification(ctx context.Context, localpart, eventID string, pos int64, tweaks map[string]interface{}, n *api.Notification) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Notifications.Insert(ctx, txn, localpart, eventID, pos, pushrules.BoolTweakOr(tweaks, pushrules.HighlightTweak, false), n)
})
}
func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos int64) (affected bool, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
affected, err = d.Notifications.DeleteUpTo(ctx, txn, localpart, roomID, pos)
return err
})
return
}
func (d *Database) SetNotificationsRead(ctx context.Context, localpart, roomID string, pos int64, b bool) (affected bool, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
affected, err = d.Notifications.UpdateRead(ctx, txn, localpart, roomID, pos, b)
return err
})
return
}
func (d *Database) GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
return d.Notifications.Select(ctx, nil, localpart, fromID, limit, filter)
}
func (d *Database) GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error) {
return d.Notifications.SelectCount(ctx, nil, localpart, filter)
}
func (d *Database) GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error) {
return d.Notifications.SelectRoomCounts(ctx, nil, localpart, roomID)
}
func (d *Database) UpsertPusher(
ctx context.Context, p api.Pusher, localpart string,
) error {
data, err := json.Marshal(p.Data)
if err != nil {
return err
}
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Pushers.InsertPusher(
ctx, txn,
p.SessionID,
p.PushKey,
p.PushKeyTS,
p.Kind,
p.AppID,
p.AppDisplayName,
p.DeviceDisplayName,
p.ProfileTag,
p.Language,
string(data),
localpart)
})
}
// GetPushers returns the pushers matching the given localpart.
func (d *Database) GetPushers(
ctx context.Context, localpart string,
) ([]api.Pusher, error) {
return d.Pushers.SelectPushers(ctx, nil, localpart)
}
// RemovePusher deletes one pusher
// Invoked when `append` is true and `kind` is null in
// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-pushers-set
func (d *Database) RemovePusher(
ctx context.Context, appid, pushkey, localpart string,
) error {
return d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
err := d.Pushers.DeletePusher(ctx, txn, appid, pushkey, localpart)
if err == sql.ErrNoRows {
return nil
}
return err
})
}
// RemovePushers deletes all pushers that match given App Id and Push Key pair.
// Invoked when `append` parameter is false in
// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-pushers-set
func (d *Database) RemovePushers(
ctx context.Context, appid, pushkey string,
) error {
return d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
return d.Pushers.DeletePushers(ctx, txn, appid, pushkey)
})
}

View file

@ -0,0 +1,219 @@
// Copyright 2021 Dan Peleg <dan@globekeeper.com>
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"encoding/json"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus"
)
type notificationsStatements struct {
insertStmt *sql.Stmt
deleteUpToStmt *sql.Stmt
updateReadStmt *sql.Stmt
selectStmt *sql.Stmt
selectCountStmt *sql.Stmt
selectRoomCountsStmt *sql.Stmt
}
const notificationSchema = `
CREATE TABLE IF NOT EXISTS userapi_notifications (
id INTEGER PRIMARY KEY AUTOINCREMENT,
localpart TEXT NOT NULL,
room_id TEXT NOT NULL,
event_id TEXT NOT NULL,
stream_pos BIGINT NOT NULL,
ts_ms BIGINT NOT NULL,
highlight BOOLEAN NOT NULL,
notification_json TEXT NOT NULL,
read BOOLEAN NOT NULL DEFAULT FALSE
);
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, room_id, event_id);
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, room_id, id);
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, id);
`
const insertNotificationSQL = "" +
"INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7)"
const deleteNotificationsUpToSQL = "" +
"DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND stream_pos <= $3"
const updateNotificationReadSQL = "" +
"UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND stream_pos <= $4 AND read <> $1"
const selectNotificationSQL = "" +
"SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" +
"(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" +
") AND NOT read ORDER BY localpart, id LIMIT $4"
const selectNotificationCountSQL = "" +
"SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND (" +
"(($2 & 1) <> 0 AND highlight) OR (($2 & 2) <> 0 AND NOT highlight)" +
") AND NOT read"
const selectRoomNotificationCountsSQL = "" +
"SELECT COUNT(*), COUNT(*) FILTER (WHERE highlight) FROM userapi_notifications " +
"WHERE localpart = $1 AND room_id = $2 AND NOT read"
func NewSQLiteNotificationTable(db *sql.DB) (tables.NotificationTable, error) {
s := &notificationsStatements{}
_, err := db.Exec(notificationSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.insertStmt, insertNotificationSQL},
{&s.deleteUpToStmt, deleteNotificationsUpToSQL},
{&s.updateReadStmt, updateNotificationReadSQL},
{&s.selectStmt, selectNotificationSQL},
{&s.selectCountStmt, selectNotificationCountSQL},
{&s.selectRoomCountsStmt, selectRoomNotificationCountsSQL},
}.Prepare(db)
}
// Insert inserts a notification into the database.
func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error {
roomID, tsMS := n.RoomID, n.TS
nn := *n
// Clears out fields that have their own columns to (1) shrink the
// data and (2) avoid difficult-to-debug inconsistency bugs.
nn.RoomID = ""
nn.TS, nn.Read = 0, false
bs, err := json.Marshal(nn)
if err != nil {
return err
}
_, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, pos, tsMS, highlight, string(bs))
return err
}
// DeleteUpTo deletes all previous notifications, up to and including the event.
func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error) {
res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos)
if err != nil {
return false, err
}
nrows, err := res.RowsAffected()
if err != nil {
return true, err
}
log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("DeleteUpTo: %d rows affected", nrows)
return nrows > 0, nil
}
// UpdateRead updates the "read" value for an event.
func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error) {
res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos)
if err != nil {
return false, err
}
nrows, err := res.RowsAffected()
if err != nil {
return true, err
}
log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("UpdateRead: %d rows affected", nrows)
return nrows > 0, nil
}
func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, fromID, uint32(filter), limit)
if err != nil {
return nil, 0, err
}
defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
var maxID int64 = -1
var notifs []*api.Notification
for rows.Next() {
var id int64
var roomID string
var ts gomatrixserverlib.Timestamp
var read bool
var jsonStr string
err = rows.Scan(
&id,
&roomID,
&ts,
&read,
&jsonStr)
if err != nil {
return nil, 0, err
}
var n api.Notification
err := json.Unmarshal([]byte(jsonStr), &n)
if err != nil {
return nil, 0, err
}
n.RoomID = roomID
n.TS = ts
n.Read = read
notifs = append(notifs, &n)
if maxID < id {
maxID = id
}
}
return notifs, maxID, rows.Err()
}
func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (int64, error) {
rows, err := sqlutil.TxStmt(txn, s.selectCountStmt).QueryContext(ctx, localpart, uint32(filter))
if err != nil {
return 0, err
}
defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
if rows.Next() {
var count int64
if err := rows.Scan(&count); err != nil {
return 0, err
}
return count, nil
}
return 0, rows.Err()
}
func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) {
rows, err := sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryContext(ctx, localpart, roomID)
if err != nil {
return 0, 0, err
}
defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
if rows.Next() {
var total, highlight int64
if err := rows.Scan(&total, &highlight); err != nil {
return 0, 0, err
}
return total, highlight, nil
}
return 0, 0, rows.Err()
}

View file

@ -0,0 +1,157 @@
// Copyright 2021 Dan Peleg <dan@globekeeper.com>
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"encoding/json"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
)
// See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers
const pushersSchema = `
CREATE TABLE IF NOT EXISTS userapi_pushers (
id INTEGER PRIMARY KEY AUTOINCREMENT,
-- The Matrix user ID localpart for this pusher
localpart TEXT NOT NULL,
session_id BIGINT DEFAULT NULL,
profile_tag TEXT,
kind TEXT NOT NULL,
app_id TEXT NOT NULL,
app_display_name TEXT NOT NULL,
device_display_name TEXT NOT NULL,
pushkey TEXT NOT NULL,
pushkey_ts_ms BIGINT NOT NULL DEFAULT 0,
lang TEXT NOT NULL,
data TEXT NOT NULL
);
-- For faster deleting by app_id, pushkey pair.
CREATE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey);
-- For faster retrieving by localpart.
CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart);
-- Pushkey must be unique for a given user and app.
CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart);
`
const insertPusherSQL = "" +
"INSERT INTO userapi_pushers (localpart, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" +
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" +
"ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11"
const selectPushersSQL = "" +
"SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1"
const deletePusherSQL = "" +
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3"
const deletePushersByAppIdAndPushKeySQL = "" +
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2"
func NewSQLitePusherTable(db *sql.DB) (tables.PusherTable, error) {
s := &pushersStatements{}
_, err := db.Exec(pushersSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.insertPusherStmt, insertPusherSQL},
{&s.selectPushersStmt, selectPushersSQL},
{&s.deletePusherStmt, deletePusherSQL},
{&s.deletePushersByAppIdAndPushKeyStmt, deletePushersByAppIdAndPushKeySQL},
}.Prepare(db)
}
type pushersStatements struct {
insertPusherStmt *sql.Stmt
selectPushersStmt *sql.Stmt
deletePusherStmt *sql.Stmt
deletePushersByAppIdAndPushKeyStmt *sql.Stmt
}
// insertPusher creates a new pusher.
// Returns an error if the user already has a pusher with the given pusher pushkey.
// Returns nil error success.
func (s *pushersStatements) InsertPusher(
ctx context.Context, txn *sql.Tx, session_id int64,
pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string,
) error {
_, err := s.insertPusherStmt.ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
logrus.Debugf("Created pusher %d", session_id)
return err
}
func (s *pushersStatements) SelectPushers(
ctx context.Context, txn *sql.Tx, localpart string,
) ([]api.Pusher, error) {
pushers := []api.Pusher{}
rows, err := s.selectPushersStmt.QueryContext(ctx, localpart)
if err != nil {
return pushers, err
}
defer internal.CloseAndLogIfError(ctx, rows, "SelectPushers: rows.close() failed")
for rows.Next() {
var pusher api.Pusher
var data []byte
err = rows.Scan(
&pusher.SessionID,
&pusher.PushKey,
&pusher.PushKeyTS,
&pusher.Kind,
&pusher.AppID,
&pusher.AppDisplayName,
&pusher.DeviceDisplayName,
&pusher.ProfileTag,
&pusher.Language,
&data)
if err != nil {
return pushers, err
}
err := json.Unmarshal(data, &pusher.Data)
if err != nil {
return pushers, err
}
pushers = append(pushers, pusher)
}
logrus.Debugf("Database returned %d pushers", len(pushers))
return pushers, rows.Err()
}
// deletePusher removes a single pusher by pushkey and user localpart.
func (s *pushersStatements) DeletePusher(
ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string,
) error {
_, err := s.deletePusherStmt.ExecContext(ctx, appid, pushkey, localpart)
return err
}
func (s *pushersStatements) DeletePushers(
ctx context.Context, txn *sql.Tx, appid, pushkey string,
) error {
_, err := s.deletePushersByAppIdAndPushKeyStmt.ExecContext(ctx, appid, pushkey)
return err
}

View file

@ -86,6 +86,14 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err != nil {
return nil, fmt.Errorf("NewSQLiteThreePIDTable: %w", err)
}
pusherTable, err := NewSQLitePusherTable(db)
if err != nil {
return nil, fmt.Errorf("NewPostgresPusherTable: %w", err)
}
notificationsTable, err := NewSQLiteNotificationTable(db)
if err != nil {
return nil, fmt.Errorf("NewPostgresNotificationTable: %w", err)
}
return &shared.Database{
AccountDatas: accountDataTable,
Accounts: accountsTable,
@ -96,6 +104,8 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
OpenIDTokens: openIDTable,
Profiles: profilesTable,
ThreePIDs: threePIDTable,
Pushers: pusherTable,
Notifications: notificationsTable,
ServerName: serverName,
DB: db,
Writer: sqlutil.NewExclusiveWriter(),

View file

@ -21,6 +21,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
)
type AccountDataTable interface {
@ -93,3 +94,42 @@ type ThreePIDTable interface {
InsertThreePID(ctx context.Context, txn *sql.Tx, threepid, medium, localpart string) (err error)
DeleteThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error)
}
type PusherTable interface {
InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string) error
SelectPushers(ctx context.Context, txn *sql.Tx, localpart string) ([]api.Pusher, error)
DeletePusher(ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string) error
DeletePushers(ctx context.Context, txn *sql.Tx, appid, pushkey string) error
}
type NotificationTable interface {
Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error
DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error)
UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error)
Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter NotificationFilter) ([]*api.Notification, int64, error)
SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter NotificationFilter) (int64, error)
SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error)
}
type NotificationFilter uint32
const (
// HighlightNotifications returns notifications that had a
// "highlight" tweak assigned to them from evaluating push rules.
HighlightNotifications NotificationFilter = 1 << iota
// NonHighlightNotifications returns notifications that don't
// match HighlightNotifications.
NonHighlightNotifications
// NoNotifications is a filter to exclude all types of
// notifications. It's useful as a zero value, but isn't likely to
// be used in a call to Notifications.Select*.
NoNotifications NotificationFilter = 0
// AllNotifications is a filter to include all types of
// notifications in Notifications.Select*. Note that PostgreSQL
// balks if this doesn't fit in INTEGER, even though we use
// uint32.
AllNotifications NotificationFilter = (1 << 31) - 1
)

View file

@ -18,11 +18,17 @@ import (
"time"
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/internal/pushgateway"
keyapi "github.com/matrix-org/dendrite/keyserver/api"
rsapi "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/consumers"
"github.com/matrix-org/dendrite/userapi/internal"
"github.com/matrix-org/dendrite/userapi/inthttp"
"github.com/matrix-org/dendrite/userapi/producers"
"github.com/matrix-org/dendrite/userapi/storage"
"github.com/sirupsen/logrus"
)
@ -36,26 +42,49 @@ func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) {
// NewInternalAPI returns a concerete implementation of the internal API. Callers
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
func NewInternalAPI(
accountDB storage.Database, cfg *config.UserAPI, appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI,
base *base.BaseDendrite, db storage.Database, cfg *config.UserAPI,
appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI,
rsAPI rsapi.RoomserverInternalAPI, pgClient pushgateway.Client,
) api.UserInternalAPI {
db, err := storage.NewDatabase(&cfg.AccountDatabase, cfg.Matrix.ServerName, cfg.BCryptCost, int64(api.DefaultLoginTokenLifetime*time.Millisecond), api.DefaultLoginTokenLifetime)
if err != nil {
logrus.WithError(err).Panicf("failed to connect to device db")
}
return newInternalAPI(db, cfg, appServices, keyAPI)
}
js := jetstream.Prepare(&cfg.Matrix.JetStream)
func newInternalAPI(
db storage.Database,
cfg *config.UserAPI,
appServices []config.ApplicationService,
keyAPI keyapi.KeyInternalAPI,
) api.UserInternalAPI {
return &internal.UserInternalAPI{
DB: db,
ServerName: cfg.Matrix.ServerName,
AppServices: appServices,
KeyAPI: keyAPI,
syncProducer := producers.NewSyncAPI(
db, js,
// TODO: user API should handle syncs for account data. Right now,
// it's handled by clientapi, and hence uses its topic. When user
// API handles it for all account data, we can remove it from
// here.
cfg.Matrix.JetStream.TopicFor(jetstream.OutputClientData),
cfg.Matrix.JetStream.TopicFor(jetstream.OutputNotificationData),
)
userAPI := &internal.UserInternalAPI{
DB: db,
SyncProducer: syncProducer,
ServerName: cfg.Matrix.ServerName,
AppServices: appServices,
KeyAPI: keyAPI,
DisableTLSValidation: cfg.PushGatewayDisableTLSValidation,
}
readConsumer := consumers.NewOutputReadUpdateConsumer(
base.ProcessContext, cfg, js, db, pgClient, userAPI, syncProducer,
)
if err := readConsumer.Start(); err != nil {
logrus.WithError(err).Panic("failed to start user API read update consumer")
}
eventConsumer := consumers.NewOutputStreamEventConsumer(
base.ProcessContext, cfg, js, db, pgClient, userAPI, rsAPI, syncProducer,
)
if err := eventConsumer.Start(); err != nil {
logrus.WithError(err).Panic("failed to start user API streamed event consumer")
}
return userAPI
}

View file

@ -30,6 +30,7 @@ import (
"github.com/matrix-org/dendrite/internal/test"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/internal"
"github.com/matrix-org/dendrite/userapi/inthttp"
"github.com/matrix-org/dendrite/userapi/storage"
)
@ -62,7 +63,10 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts) (api.UserInternalAPI, s
},
}
return newInternalAPI(accountDB, cfg, nil, nil), accountDB
return &internal.UserInternalAPI{
DB: accountDB,
ServerName: cfg.Matrix.ServerName,
}, accountDB
}
func TestQueryProfile(t *testing.T) {

100
userapi/util/devices.go Normal file
View file

@ -0,0 +1,100 @@
package util
import (
"context"
"github.com/matrix-org/dendrite/internal/pushgateway"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage"
log "github.com/sirupsen/logrus"
)
type PusherDevice struct {
Device pushgateway.Device
Pusher *api.Pusher
URL string
Format string
}
// GetPushDevices pushes to the configured devices of a local user.
func GetPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}, db storage.Database) ([]*PusherDevice, error) {
pushers, err := db.GetPushers(ctx, localpart)
if err != nil {
return nil, err
}
devices := make([]*PusherDevice, 0, len(pushers))
for _, pusher := range pushers {
var url, format string
data := pusher.Data
switch pusher.Kind {
case api.EmailKind:
url = "mailto:"
case api.HTTPKind:
// TODO: The spec says only event_id_only is supported,
// but Sytests assume "" means "full notification".
fmtIface := pusher.Data["format"]
var ok bool
format, ok = fmtIface.(string)
if ok && format != "event_id_only" {
log.WithFields(log.Fields{
"localpart": localpart,
"app_id": pusher.AppID,
}).Errorf("Only data.format event_id_only or empty is supported")
continue
}
urlIface := pusher.Data["url"]
url, ok = urlIface.(string)
if !ok {
log.WithFields(log.Fields{
"localpart": localpart,
"app_id": pusher.AppID,
}).Errorf("No data.url configured for HTTP Pusher")
continue
}
data = mapWithout(data, "url")
default:
log.WithFields(log.Fields{
"localpart": localpart,
"app_id": pusher.AppID,
"kind": pusher.Kind,
}).Errorf("Unhandled pusher kind")
continue
}
devices = append(devices, &PusherDevice{
Device: pushgateway.Device{
AppID: pusher.AppID,
Data: data,
PushKey: pusher.PushKey,
PushKeyTS: pusher.PushKeyTS,
Tweaks: tweaks,
},
Pusher: &pusher,
URL: url,
Format: format,
})
}
return devices, nil
}
// mapWithout returns a shallow copy of the map, without the given
// key. Returns nil if the resulting map is empty.
func mapWithout(m map[string]interface{}, key string) map[string]interface{} {
ret := make(map[string]interface{}, len(m))
for k, v := range m {
// The specification says we do not send "url".
if k == key {
continue
}
ret[k] = v
}
if len(ret) == 0 {
return nil
}
return ret
}

76
userapi/util/notify.go Normal file
View file

@ -0,0 +1,76 @@
package util
import (
"context"
"strings"
"time"
"github.com/matrix-org/dendrite/internal/pushgateway"
"github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/dendrite/userapi/storage/tables"
log "github.com/sirupsen/logrus"
)
// NotifyUserCountsAsync sends notifications to a local user's
// notification destinations. Database lookups run synchronously, but
// a single goroutine is started when talking to the Push
// gateways. There is no way to know when the background goroutine has
// finished.
func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, db storage.Database) error {
pusherDevices, err := GetPushDevices(ctx, localpart, nil, db)
if err != nil {
return err
}
if len(pusherDevices) == 0 {
return nil
}
userNumUnreadNotifs, err := db.GetNotificationCount(ctx, localpart, tables.AllNotifications)
if err != nil {
return err
}
log.WithFields(log.Fields{
"localpart": localpart,
"app_id0": pusherDevices[0].Device.AppID,
"pushkey": pusherDevices[0].Device.PushKey,
}).Tracef("Notifying HTTP push gateway about notification counts")
// TODO: think about bounding this to one per user, and what
// ordering guarantees we must provide.
go func() {
// This background processing cannot be tied to a request.
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
// TODO: we could batch all devices with the same URL, but
// Sytest requires consumers/roomserver.go to do it
// one-by-one, so we do the same here.
for _, pusherDevice := range pusherDevices {
// TODO: support "email".
if !strings.HasPrefix(pusherDevice.URL, "http") {
continue
}
req := pushgateway.NotifyRequest{
Notification: pushgateway.Notification{
Counts: &pushgateway.Counts{
Unread: int(userNumUnreadNotifs),
},
Devices: []*pushgateway.Device{&pusherDevice.Device},
},
}
if err := pgClient.Notify(ctx, pusherDevice.URL, &req, &pushgateway.NotifyResponse{}); err != nil {
log.WithFields(log.Fields{
"localpart": localpart,
"app_id0": pusherDevice.Device.AppID,
"pushkey": pusherDevice.Device.PushKey,
}).WithError(err).Error("HTTP push gateway request failed")
return
}
}
}()
return nil
}