Implement Push Notifications (#1842)

* Add Pushserver component with Pushers API

Co-authored-by: Tommie Gannert <tommie@gannert.se>
Co-authored-by: Dan Peleg <dan@globekeeper.com>

* Wire Pushserver component

Co-authored-by: Neil Alexander <neilalexander@users.noreply.github.com>

* Add PushGatewayClient.

The full event format is required for Sytest.

* Add a pushrules module.

* Change user API account creation to use the new pushrules module's defaults.

Introduces "scope" as required by client API, and some small field
tweaks to make some 61push Sytests pass.

* Add push rules query/put API in Pushserver.

This manipulates account data over User API, and fires sync messages
for changes. Those sync messages should, according to an existing TODO
in clientapi, be moved to userapi.

Forks clientapi/producers/syncapi.go to pushserver/ for later extension.

* Add clientapi routes for push rules to Pushserver.

A cleanup would be to move more of the name-splitting logic into
pushrules.go, to depollute routing.go.

* Output rooms.join.unread_notifications in /sync.

This is the read-side. Pushserver will be the write-side.

* Implement pushserver/storage for notifications.

* Use PushGatewayClient and the pushrules module in Pushserver's room consumer.

* Use one goroutine per user to avoid locking up the entire server for
  one bad push gateway.
* Split pushing by format.
* Send one device per push. Sytest does not support coalescing
  multiple devices into one push. Matches Synapse. Either we change
  Sytest, or remove the group-by-url-and-format logic.
* Write OutputNotificationData from push server. Sync API is already
  the consumer.

* Implement read receipt consumers in Pushserver.

Supports m.read and m.fully_read receipts.

* Add clientapi route for /unstable/notifications.

* Rename to UpsertPusher for clarity and handle pusher update

* Fix linter errors

* Ignore body.Close() error check

* Fix push server internal http wiring

* Add 40 newly passing 61push tests to whitelist

* Add next 12 newly passing 61push tests to whitelist

* Send notification data before notifying users in EDU server consumer

* NATS JetStream

* Goodbye sarama

* Fix `NewStreamTokenFromString`

* Consume on the correct topic for the roomserver

* Don't panic, NAK instead

* Move push notifications into the User API

* Don't set null values since that apparently causes Element upsetti

* Also set omitempty on conditions

* Fix bug so that we don't override the push rules unnecessarily

* Tweak defaults

* Update defaults

* More tweaks

* Move `/notifications` onto `r0`/`v3` mux

* User API will consume events and read/fully read markers from the sync API with stream positions, instead of consuming directly

Co-authored-by: Piotr Kozimor <p1996k@gmail.com>
Co-authored-by: Tommie Gannert <tommie@gannert.se>
Co-authored-by: Neil Alexander <neilalexander@users.noreply.github.com>
This commit is contained in:
Dan 2022-03-03 13:40:53 +02:00 committed by GitHub
parent 111f01ddc8
commit f05ce478f0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
92 changed files with 5840 additions and 194 deletions

4
.gitignore vendored
View file

@ -62,5 +62,7 @@ cmd/dendrite-demo-yggdrasil/embed/fs*.go
# Test dependencies # Test dependencies
test/wasm/node_modules test/wasm/node_modules
media_store/ # Ignore complement folder when running locally
complement/
media_store/

View file

@ -318,6 +318,17 @@ user_api:
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
# Configuration for the Push Server API.
push_server:
internal_api:
listen: http://localhost:7782
connect: http://localhost:7782
database:
connection_string: postgresql://dendrite:itsasecret@postgres/dendrite_pushserver?sslmode=disable
max_open_conns: 10
max_idle_conns: 2
conn_max_lifetime: -1
# Configuration for Opentracing. # Configuration for Opentracing.
# See https://github.com/matrix-org/dendrite/tree/master/docs/tracing for information on # See https://github.com/matrix-org/dendrite/tree/master/docs/tracing for information on
# how this works and how to set it up. # how this works and how to set it up.

View file

@ -312,7 +312,7 @@ func (m *DendriteMonolith) Start() {
) )
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI) keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI)
m.userAPI = userapi.NewInternalAPI(accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI) m.userAPI = userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, base.PushGatewayHTTPClient())
keyAPI.SetUserAPI(m.userAPI) keyAPI.SetUserAPI(m.userAPI)
eduInputAPI := eduserver.NewInternalAPI( eduInputAPI := eduserver.NewInternalAPI(

View file

@ -116,7 +116,7 @@ func (m *DendriteMonolith) Start() {
) )
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation) keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation)
userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI) userAPI := userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, base.PushGatewayHTTPClient())
keyAPI.SetUserAPI(userAPI) keyAPI.SetUserAPI(userAPI)
eduInputAPI := eduserver.NewInternalAPI( eduInputAPI := eduserver.NewInternalAPI(

View file

@ -59,6 +59,7 @@ func AddPublicRoutes(
routing.Setup( routing.Setup(
router, synapseAdminRouter, cfg, eduInputAPI, rsAPI, asAPI, router, synapseAdminRouter, cfg, eduInputAPI, rsAPI, asAPI,
accountsDB, userAPI, federation, accountsDB, userAPI, federation,
syncProducer, transactionsCache, fsAPI, keyAPI, extRoomsProvider, mscCfg, syncProducer, transactionsCache, fsAPI, keyAPI,
extRoomsProvider, mscCfg,
) )
} }

View file

@ -30,7 +30,7 @@ type SyncAPIProducer struct {
} }
// SendData sends account data to the sync API server // SendData sends account data to the sync API server
func (p *SyncAPIProducer) SendData(userID string, roomID string, dataType string) error { func (p *SyncAPIProducer) SendData(userID string, roomID string, dataType string, readMarker *eventutil.ReadMarkerJSON) error {
m := &nats.Msg{ m := &nats.Msg{
Subject: p.Topic, Subject: p.Topic,
Header: nats.Header{}, Header: nats.Header{},
@ -40,6 +40,7 @@ func (p *SyncAPIProducer) SendData(userID string, roomID string, dataType string
data := eventutil.AccountData{ data := eventutil.AccountData{
RoomID: roomID, RoomID: roomID,
Type: dataType, Type: dataType,
ReadMarker: readMarker,
} }
var err error var err error
m.Data, err = json.Marshal(data) m.Data, err = json.Marshal(data)

View file

@ -24,6 +24,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/clientapi/producers" "github.com/matrix-org/dendrite/clientapi/producers"
eduserverAPI "github.com/matrix-org/dendrite/eduserver/api" eduserverAPI "github.com/matrix-org/dendrite/eduserver/api"
"github.com/matrix-org/dendrite/internal/eventutil"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
@ -127,7 +128,7 @@ func SaveAccountData(
} }
// TODO: user API should do this since it's account data // TODO: user API should do this since it's account data
if err := syncProducer.SendData(userID, roomID, dataType); err != nil { if err := syncProducer.SendData(userID, roomID, dataType, nil); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("syncProducer.SendData failed") util.GetLogger(req.Context()).WithError(err).Error("syncProducer.SendData failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
@ -138,11 +139,6 @@ func SaveAccountData(
} }
} }
type readMarkerJSON struct {
FullyRead string `json:"m.fully_read"`
Read string `json:"m.read"`
}
type fullyReadEvent struct { type fullyReadEvent struct {
EventID string `json:"event_id"` EventID string `json:"event_id"`
} }
@ -159,7 +155,7 @@ func SaveReadMarker(
return *resErr return *resErr
} }
var r readMarkerJSON var r eventutil.ReadMarkerJSON
resErr = httputil.UnmarshalJSONRequest(req, &r) resErr = httputil.UnmarshalJSONRequest(req, &r)
if resErr != nil { if resErr != nil {
return *resErr return *resErr
@ -189,7 +185,7 @@ func SaveReadMarker(
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
if err := syncProducer.SendData(device.UserID, roomID, "m.fully_read"); err != nil { if err := syncProducer.SendData(device.UserID, roomID, "m.fully_read", &r); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("syncProducer.SendData failed") util.GetLogger(req.Context()).WithError(err).Error("syncProducer.SendData failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }

View file

@ -0,0 +1,63 @@
// Copyright 2021 Dan Peleg <dan@globekeeper.com>
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package routing
import (
"net/http"
"strconv"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
// GetNotifications handles /_matrix/client/r0/notifications
func GetNotifications(
req *http.Request, device *userapi.Device,
userAPI userapi.UserInternalAPI,
) util.JSONResponse {
var limit int64
if limitStr := req.URL.Query().Get("limit"); limitStr != "" {
var err error
limit, err = strconv.ParseInt(limitStr, 10, 64)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("ParseInt(limit) failed")
return jsonerror.InternalServerError()
}
}
var queryRes userapi.QueryNotificationsResponse
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("SplitID failed")
return jsonerror.InternalServerError()
}
err = userAPI.QueryNotifications(req.Context(), &userapi.QueryNotificationsRequest{
Localpart: localpart,
From: req.URL.Query().Get("from"),
Limit: int(limit),
Only: req.URL.Query().Get("only"),
}, &queryRes)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("QueryNotifications failed")
return jsonerror.InternalServerError()
}
util.GetLogger(req.Context()).WithField("from", req.URL.Query().Get("from")).WithField("limit", limit).WithField("only", req.URL.Query().Get("only")).WithField("next", queryRes.NextToken).Infof("QueryNotifications: len %d", len(queryRes.Notifications))
return util.JSONResponse{
Code: http.StatusOK,
JSON: queryRes,
}
}

View file

@ -12,6 +12,7 @@ import (
userdb "github.com/matrix-org/dendrite/userapi/storage" userdb "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/sirupsen/logrus"
) )
type newPasswordRequest struct { type newPasswordRequest struct {
@ -37,6 +38,11 @@ func Password(
var r newPasswordRequest var r newPasswordRequest
r.LogoutDevices = true r.LogoutDevices = true
logrus.WithFields(logrus.Fields{
"sessionId": device.SessionID,
"userId": device.UserID,
}).Debug("Changing password")
// Unmarshal the request. // Unmarshal the request.
resErr := httputil.UnmarshalJSONRequest(req, &r) resErr := httputil.UnmarshalJSONRequest(req, &r)
if resErr != nil { if resErr != nil {
@ -116,6 +122,15 @@ func Password(
util.GetLogger(req.Context()).WithError(err).Error("PerformDeviceDeletion failed") util.GetLogger(req.Context()).WithError(err).Error("PerformDeviceDeletion failed")
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
pushersReq := &api.PerformPusherDeletionRequest{
Localpart: localpart,
SessionID: device.SessionID,
}
if err := userAPI.PerformPusherDeletion(req.Context(), pushersReq, &struct{}{}); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("PerformPusherDeletion failed")
return jsonerror.InternalServerError()
}
} }
// Return a success code. // Return a success code.

114
clientapi/routing/pusher.go Normal file
View file

@ -0,0 +1,114 @@
// Copyright 2021 Dan Peleg <dan@globekeeper.com>
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package routing
import (
"net/http"
"net/url"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
// GetPushers handles /_matrix/client/r0/pushers
func GetPushers(
req *http.Request, device *userapi.Device,
userAPI userapi.UserInternalAPI,
) util.JSONResponse {
var queryRes userapi.QueryPushersResponse
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("SplitID failed")
return jsonerror.InternalServerError()
}
err = userAPI.QueryPushers(req.Context(), &userapi.QueryPushersRequest{
Localpart: localpart,
}, &queryRes)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("QueryPushers failed")
return jsonerror.InternalServerError()
}
for i := range queryRes.Pushers {
queryRes.Pushers[i].SessionID = 0
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: queryRes,
}
}
// SetPusher handles /_matrix/client/r0/pushers/set
// This endpoint allows the creation, modification and deletion of pushers for this user ID.
// The behaviour of this endpoint varies depending on the values in the JSON body.
func SetPusher(
req *http.Request, device *userapi.Device,
userAPI userapi.UserInternalAPI,
) util.JSONResponse {
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("SplitID failed")
return jsonerror.InternalServerError()
}
body := userapi.PerformPusherSetRequest{}
if resErr := httputil.UnmarshalJSONRequest(req, &body); resErr != nil {
return *resErr
}
if len(body.AppID) > 64 {
return invalidParam("length of app_id must be no more than 64 characters")
}
if len(body.PushKey) > 512 {
return invalidParam("length of pushkey must be no more than 512 bytes")
}
uInt := body.Data["url"]
if uInt != nil {
u, ok := uInt.(string)
if !ok {
return invalidParam("url must be string")
}
if u != "" {
var pushUrl *url.URL
pushUrl, err = url.Parse(u)
if err != nil {
return invalidParam("malformed url passed")
}
if pushUrl.Scheme != "https" {
return invalidParam("only https scheme is allowed")
}
}
}
body.Localpart = localpart
body.SessionID = device.SessionID
err = userAPI.PerformPusherSet(req.Context(), &body, &struct{}{})
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("PerformPusherSet failed")
return jsonerror.InternalServerError()
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
}
}
func invalidParam(msg string) util.JSONResponse {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidParam(msg),
}
}

View file

@ -0,0 +1,386 @@
package routing
import (
"context"
"encoding/json"
"io"
"net/http"
"reflect"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/pushrules"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/util"
)
func errorResponse(ctx context.Context, err error, msg string, args ...interface{}) util.JSONResponse {
if eerr, ok := err.(*jsonerror.MatrixError); ok {
var status int
switch eerr.ErrCode {
case "M_INVALID_ARGUMENT_VALUE":
status = http.StatusBadRequest
case "M_NOT_FOUND":
status = http.StatusNotFound
default:
status = http.StatusInternalServerError
}
return util.MatrixErrorResponse(status, eerr.ErrCode, eerr.Err)
}
util.GetLogger(ctx).WithError(err).Errorf(msg, args...)
return jsonerror.InternalServerError()
}
func GetAllPushRules(ctx context.Context, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse {
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
if err != nil {
return errorResponse(ctx, err, "queryPushRulesJSON failed")
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: ruleSets,
}
}
func GetPushRulesByScope(ctx context.Context, scope string, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse {
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
if err != nil {
return errorResponse(ctx, err, "queryPushRulesJSON failed")
}
ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope))
if ruleSet == nil {
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed")
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: ruleSet,
}
}
func GetPushRulesByKind(ctx context.Context, scope, kind string, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse {
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
if err != nil {
return errorResponse(ctx, err, "queryPushRules failed")
}
ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope))
if ruleSet == nil {
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed")
}
rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind))
if rulesPtr == nil {
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed")
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: *rulesPtr,
}
}
func GetPushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse {
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
if err != nil {
return errorResponse(ctx, err, "queryPushRules failed")
}
ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope))
if ruleSet == nil {
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed")
}
rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind))
if rulesPtr == nil {
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed")
}
i := pushRuleIndexByID(*rulesPtr, ruleID)
if i < 0 {
return errorResponse(ctx, jsonerror.NotFound("push rule ID not found"), "pushRuleIndexByID failed")
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: (*rulesPtr)[i],
}
}
func PutPushRuleByRuleID(ctx context.Context, scope, kind, ruleID, afterRuleID, beforeRuleID string, body io.Reader, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse {
var newRule pushrules.Rule
if err := json.NewDecoder(body).Decode(&newRule); err != nil {
return errorResponse(ctx, err, "JSON Decode failed")
}
newRule.RuleID = ruleID
errs := pushrules.ValidateRule(pushrules.Kind(kind), &newRule)
if len(errs) > 0 {
return errorResponse(ctx, jsonerror.InvalidArgumentValue(errs[0].Error()), "rule sanity check failed: %v", errs)
}
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
if err != nil {
return errorResponse(ctx, err, "queryPushRules failed")
}
ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope))
if ruleSet == nil {
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed")
}
rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind))
if rulesPtr == nil {
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed")
}
i := pushRuleIndexByID(*rulesPtr, ruleID)
if i >= 0 && afterRuleID == "" && beforeRuleID == "" {
// Modify rule at the same index.
// TODO: The spec does not say what to do in this case, but
// this feels reasonable.
*((*rulesPtr)[i]) = newRule
util.GetLogger(ctx).Infof("Modified existing push rule at %d", i)
} else {
if i >= 0 {
// Delete old rule.
*rulesPtr = append((*rulesPtr)[:i], (*rulesPtr)[i+1:]...)
util.GetLogger(ctx).Infof("Deleted old push rule at %d", i)
} else {
// SPEC: When creating push rules, they MUST be enabled by default.
//
// TODO: it's unclear if we must reject disabled rules, or force
// the value to true. Sytests fail if we don't force it.
newRule.Enabled = true
}
// Add new rule.
i, err := findPushRuleInsertionIndex(*rulesPtr, afterRuleID, beforeRuleID)
if err != nil {
return errorResponse(ctx, err, "findPushRuleInsertionIndex failed")
}
*rulesPtr = append((*rulesPtr)[:i], append([]*pushrules.Rule{&newRule}, (*rulesPtr)[i:]...)...)
util.GetLogger(ctx).WithField("after", afterRuleID).WithField("before", beforeRuleID).Infof("Added new push rule at %d", i)
}
if err := putPushRules(ctx, device.UserID, ruleSets, userAPI); err != nil {
return errorResponse(ctx, err, "putPushRules failed")
}
return util.JSONResponse{Code: http.StatusOK, JSON: struct{}{}}
}
func DeletePushRuleByRuleID(ctx context.Context, scope, kind, ruleID string, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse {
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
if err != nil {
return errorResponse(ctx, err, "queryPushRules failed")
}
ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope))
if ruleSet == nil {
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed")
}
rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind))
if rulesPtr == nil {
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed")
}
i := pushRuleIndexByID(*rulesPtr, ruleID)
if i < 0 {
return errorResponse(ctx, jsonerror.NotFound("push rule ID not found"), "pushRuleIndexByID failed")
}
*rulesPtr = append((*rulesPtr)[:i], (*rulesPtr)[i+1:]...)
if err := putPushRules(ctx, device.UserID, ruleSets, userAPI); err != nil {
return errorResponse(ctx, err, "putPushRules failed")
}
return util.JSONResponse{Code: http.StatusOK, JSON: struct{}{}}
}
func GetPushRuleAttrByRuleID(ctx context.Context, scope, kind, ruleID, attr string, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse {
attrGet, err := pushRuleAttrGetter(attr)
if err != nil {
return errorResponse(ctx, err, "pushRuleAttrGetter failed")
}
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
if err != nil {
return errorResponse(ctx, err, "queryPushRules failed")
}
ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope))
if ruleSet == nil {
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed")
}
rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind))
if rulesPtr == nil {
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed")
}
i := pushRuleIndexByID(*rulesPtr, ruleID)
if i < 0 {
return errorResponse(ctx, jsonerror.NotFound("push rule ID not found"), "pushRuleIndexByID failed")
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: map[string]interface{}{
attr: attrGet((*rulesPtr)[i]),
},
}
}
func PutPushRuleAttrByRuleID(ctx context.Context, scope, kind, ruleID, attr string, body io.Reader, device *userapi.Device, userAPI userapi.UserInternalAPI) util.JSONResponse {
var newPartialRule pushrules.Rule
if err := json.NewDecoder(body).Decode(&newPartialRule); err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON(err.Error()),
}
}
if newPartialRule.Actions == nil {
// This ensures json.Marshal encodes the empty list as [] rather than null.
newPartialRule.Actions = []*pushrules.Action{}
}
attrGet, err := pushRuleAttrGetter(attr)
if err != nil {
return errorResponse(ctx, err, "pushRuleAttrGetter failed")
}
attrSet, err := pushRuleAttrSetter(attr)
if err != nil {
return errorResponse(ctx, err, "pushRuleAttrSetter failed")
}
ruleSets, err := queryPushRules(ctx, device.UserID, userAPI)
if err != nil {
return errorResponse(ctx, err, "queryPushRules failed")
}
ruleSet := pushRuleSetByScope(ruleSets, pushrules.Scope(scope))
if ruleSet == nil {
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rule set"), "pushRuleSetByScope failed")
}
rulesPtr := pushRuleSetKindPointer(ruleSet, pushrules.Kind(kind))
if rulesPtr == nil {
return errorResponse(ctx, jsonerror.InvalidArgumentValue("invalid push rules kind"), "pushRuleSetKindPointer failed")
}
i := pushRuleIndexByID(*rulesPtr, ruleID)
if i < 0 {
return errorResponse(ctx, jsonerror.NotFound("push rule ID not found"), "pushRuleIndexByID failed")
}
if !reflect.DeepEqual(attrGet((*rulesPtr)[i]), attrGet(&newPartialRule)) {
attrSet((*rulesPtr)[i], &newPartialRule)
if err := putPushRules(ctx, device.UserID, ruleSets, userAPI); err != nil {
return errorResponse(ctx, err, "putPushRules failed")
}
}
return util.JSONResponse{Code: http.StatusOK, JSON: struct{}{}}
}
func queryPushRules(ctx context.Context, userID string, userAPI userapi.UserInternalAPI) (*pushrules.AccountRuleSets, error) {
var res userapi.QueryPushRulesResponse
if err := userAPI.QueryPushRules(ctx, &userapi.QueryPushRulesRequest{UserID: userID}, &res); err != nil {
util.GetLogger(ctx).WithError(err).Error("userAPI.QueryPushRules failed")
return nil, err
}
return res.RuleSets, nil
}
func putPushRules(ctx context.Context, userID string, ruleSets *pushrules.AccountRuleSets, userAPI userapi.UserInternalAPI) error {
req := userapi.PerformPushRulesPutRequest{
UserID: userID,
RuleSets: ruleSets,
}
var res struct{}
if err := userAPI.PerformPushRulesPut(ctx, &req, &res); err != nil {
util.GetLogger(ctx).WithError(err).Error("userAPI.PerformPushRulesPut failed")
return err
}
return nil
}
func pushRuleSetByScope(ruleSets *pushrules.AccountRuleSets, scope pushrules.Scope) *pushrules.RuleSet {
switch scope {
case pushrules.GlobalScope:
return &ruleSets.Global
default:
return nil
}
}
func pushRuleSetKindPointer(ruleSet *pushrules.RuleSet, kind pushrules.Kind) *[]*pushrules.Rule {
switch kind {
case pushrules.OverrideKind:
return &ruleSet.Override
case pushrules.ContentKind:
return &ruleSet.Content
case pushrules.RoomKind:
return &ruleSet.Room
case pushrules.SenderKind:
return &ruleSet.Sender
case pushrules.UnderrideKind:
return &ruleSet.Underride
default:
return nil
}
}
func pushRuleIndexByID(rules []*pushrules.Rule, id string) int {
for i, rule := range rules {
if rule.RuleID == id {
return i
}
}
return -1
}
func pushRuleAttrGetter(attr string) (func(*pushrules.Rule) interface{}, error) {
switch attr {
case "actions":
return func(rule *pushrules.Rule) interface{} { return rule.Actions }, nil
case "enabled":
return func(rule *pushrules.Rule) interface{} { return rule.Enabled }, nil
default:
return nil, jsonerror.InvalidArgumentValue("invalid push rule attribute")
}
}
func pushRuleAttrSetter(attr string) (func(dest, src *pushrules.Rule), error) {
switch attr {
case "actions":
return func(dest, src *pushrules.Rule) { dest.Actions = src.Actions }, nil
case "enabled":
return func(dest, src *pushrules.Rule) { dest.Enabled = src.Enabled }, nil
default:
return nil, jsonerror.InvalidArgumentValue("invalid push rule attribute")
}
}
func findPushRuleInsertionIndex(rules []*pushrules.Rule, afterID, beforeID string) (int, error) {
var i int
if afterID != "" {
for ; i < len(rules); i++ {
if rules[i].RuleID == afterID {
break
}
}
if i == len(rules) {
return 0, jsonerror.NotFound("after: rule ID not found")
}
if rules[i].Default {
return 0, jsonerror.NotFound("after: rule ID must not be a default rule")
}
// We stopped on the "after" match to differentiate
// not-found from is-last-entry. Now we move to the earliest
// insertion point.
i++
}
if beforeID != "" {
for ; i < len(rules); i++ {
if rules[i].RuleID == beforeID {
break
}
}
if i == len(rules) {
return 0, jsonerror.NotFound("before: rule ID not found")
}
if rules[i].Default {
return 0, jsonerror.NotFound("before: rule ID must not be a default rule")
}
}
// UNSPEC: The spec does not say what to do if no after/before is
// given. Sytest fails if it doesn't go first.
return i, nil
}

View file

@ -98,7 +98,7 @@ func PutTag(
return jsonerror.InternalServerError() return jsonerror.InternalServerError()
} }
if err = syncProducer.SendData(userID, roomID, "m.tag"); err != nil { if err = syncProducer.SendData(userID, roomID, "m.tag", nil); err != nil {
logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi") logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi")
} }
@ -151,7 +151,7 @@ func DeleteTag(
} }
// TODO: user API should do this since it's account data // TODO: user API should do this since it's account data
if err := syncProducer.SendData(userID, roomID, "m.tag"); err != nil { if err := syncProducer.SendData(userID, roomID, "m.tag", nil); err != nil {
logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi") logrus.WithError(err).Error("Failed to send m.tag account data update to syncapi")
} }

View file

@ -16,7 +16,6 @@ package routing
import ( import (
"context" "context"
"encoding/json"
"net/http" "net/http"
"strings" "strings"
@ -561,25 +560,142 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
v3mux.Handle("/pushrules/", // Push rules
httputil.MakeExternalAPI("push_rules", func(req *http.Request) util.JSONResponse {
// TODO: Implement push rules API v3mux.Handle("/pushrules",
res := json.RawMessage(`{ httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
"global": {
"content": [],
"override": [],
"room": [],
"sender": [],
"underride": []
}
}`)
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusBadRequest,
JSON: &res, JSON: jsonerror.InvalidArgumentValue("missing trailing slash"),
} }
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/pushrules/",
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return GetAllPushRules(req.Context(), device, userAPI)
}),
).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/pushrules/",
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidArgumentValue("scope, kind and rule ID must be specified"),
}
}),
).Methods(http.MethodPut)
v3mux.Handle("/pushrules/{scope}/",
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return GetPushRulesByScope(req.Context(), vars["scope"], device, userAPI)
}),
).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/pushrules/{scope}",
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidArgumentValue("missing trailing slash after scope"),
}
}),
).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/pushrules/{scope:[^/]+/?}",
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidArgumentValue("kind and rule ID must be specified"),
}
}),
).Methods(http.MethodPut)
v3mux.Handle("/pushrules/{scope}/{kind}/",
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return GetPushRulesByKind(req.Context(), vars["scope"], vars["kind"], device, userAPI)
}),
).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/pushrules/{scope}/{kind}",
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidArgumentValue("missing trailing slash after kind"),
}
}),
).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/pushrules/{scope}/{kind:[^/]+/?}",
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidArgumentValue("rule ID must be specified"),
}
}),
).Methods(http.MethodPut)
v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}",
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return GetPushRuleByRuleID(req.Context(), vars["scope"], vars["kind"], vars["ruleID"], device, userAPI)
}),
).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}",
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil {
return *r
}
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
query := req.URL.Query()
return PutPushRuleByRuleID(req.Context(), vars["scope"], vars["kind"], vars["ruleID"], query.Get("after"), query.Get("before"), req.Body, device, userAPI)
}),
).Methods(http.MethodPut)
v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}",
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return DeletePushRuleByRuleID(req.Context(), vars["scope"], vars["kind"], vars["ruleID"], device, userAPI)
}),
).Methods(http.MethodDelete)
v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}/{attr}",
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return GetPushRuleAttrByRuleID(req.Context(), vars["scope"], vars["kind"], vars["ruleID"], vars["attr"], device, userAPI)
}),
).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/pushrules/{scope}/{kind}/{ruleID}/{attr}",
httputil.MakeAuthAPI("push_rules", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
return PutPushRuleAttrByRuleID(req.Context(), vars["scope"], vars["kind"], vars["ruleID"], vars["attr"], req.Body, device, userAPI)
}),
).Methods(http.MethodPut)
// Element user settings // Element user settings
v3mux.Handle("/profile/{userID}", v3mux.Handle("/profile/{userID}",
@ -885,6 +1001,27 @@ func Setup(
}), }),
).Methods(http.MethodPost, http.MethodOptions) ).Methods(http.MethodPost, http.MethodOptions)
v3mux.Handle("/notifications",
httputil.MakeAuthAPI("get_notifications", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return GetNotifications(req, device, userAPI)
}),
).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/pushers",
httputil.MakeAuthAPI("get_pushers", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return GetPushers(req, device, userAPI)
}),
).Methods(http.MethodGet, http.MethodOptions)
v3mux.Handle("/pushers/set",
httputil.MakeAuthAPI("set_pushers", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil {
return *r
}
return SetPusher(req, device, userAPI)
}),
).Methods(http.MethodPost, http.MethodOptions)
// Stub implementations for sytest // Stub implementations for sytest
v3mux.Handle("/events", v3mux.Handle("/events",
httputil.MakeExternalAPI("events", func(req *http.Request) util.JSONResponse { httputil.MakeExternalAPI("events", func(req *http.Request) util.JSONResponse {

View file

@ -144,12 +144,14 @@ func main() {
accountDB := base.Base.CreateAccountsDB() accountDB := base.Base.CreateAccountsDB()
federation := createFederationClient(base) federation := createFederationClient(base)
keyAPI := keyserver.NewInternalAPI(&base.Base, &base.Base.Cfg.KeyServer, federation) keyAPI := keyserver.NewInternalAPI(&base.Base, &base.Base.Cfg.KeyServer, federation)
userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, nil, keyAPI)
keyAPI.SetUserAPI(userAPI)
rsAPI := roomserver.NewInternalAPI( rsAPI := roomserver.NewInternalAPI(
&base.Base, &base.Base,
) )
userAPI := userapi.NewInternalAPI(&base.Base, accountDB, &cfg.UserAPI, nil, keyAPI, rsAPI, base.Base.PushGatewayHTTPClient())
keyAPI.SetUserAPI(userAPI)
eduInputAPI := eduserver.NewInternalAPI( eduInputAPI := eduserver.NewInternalAPI(
&base.Base, cache.New(), userAPI, &base.Base, cache.New(), userAPI,
) )

View file

@ -187,7 +187,7 @@ func main() {
) )
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI) keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI)
userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, nil, keyAPI) userAPI := userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient())
keyAPI.SetUserAPI(userAPI) keyAPI.SetUserAPI(userAPI)
eduInputAPI := eduserver.NewInternalAPI( eduInputAPI := eduserver.NewInternalAPI(

View file

@ -111,14 +111,15 @@ func main() {
keyRing := serverKeyAPI.KeyRing() keyRing := serverKeyAPI.KeyRing()
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation) keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation)
userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, nil, keyAPI)
keyAPI.SetUserAPI(userAPI)
rsComponent := roomserver.NewInternalAPI( rsComponent := roomserver.NewInternalAPI(
base, base,
) )
rsAPI := rsComponent rsAPI := rsComponent
userAPI := userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient())
keyAPI.SetUserAPI(userAPI)
eduInputAPI := eduserver.NewInternalAPI( eduInputAPI := eduserver.NewInternalAPI(
base, cache.New(), userAPI, base, cache.New(), userAPI,
) )

View file

@ -106,7 +106,8 @@ func main() {
keyAPI = base.KeyServerHTTPClient() keyAPI = base.KeyServerHTTPClient()
} }
userImpl := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI) pgClient := base.PushGatewayHTTPClient()
userImpl := userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, pgClient)
userAPI := userImpl userAPI := userImpl
if base.UseHTTPAPIs { if base.UseHTTPAPIs {
userapi.AddInternalRoutes(base.InternalAPIMux, userAPI) userapi.AddInternalRoutes(base.InternalAPIMux, userAPI)

View file

@ -23,7 +23,11 @@ import (
func UserAPI(base *basepkg.BaseDendrite, cfg *config.Dendrite) { func UserAPI(base *basepkg.BaseDendrite, cfg *config.Dendrite) {
accountDB := base.CreateAccountsDB() accountDB := base.CreateAccountsDB()
userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices, base.KeyServerHTTPClient()) userAPI := userapi.NewInternalAPI(
base, accountDB, &cfg.UserAPI, cfg.Derived.ApplicationServices,
base.KeyServerHTTPClient(), base.RoomserverHTTPClient(),
base.PushGatewayHTTPClient(),
)
userapi.AddInternalRoutes(base.InternalAPIMux, userAPI) userapi.AddInternalRoutes(base.InternalAPIMux, userAPI)

View file

@ -184,13 +184,15 @@ func startup() {
accountDB := base.CreateAccountsDB() accountDB := base.CreateAccountsDB()
federation := conn.CreateFederationClient(base, pSessions) federation := conn.CreateFederationClient(base, pSessions)
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation) keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation)
userAPI := userapi.NewInternalAPI(accountDB, &cfg.UserAPI, nil, keyAPI)
keyAPI.SetUserAPI(userAPI)
serverKeyAPI := &signing.YggdrasilKeys{} serverKeyAPI := &signing.YggdrasilKeys{}
keyRing := serverKeyAPI.KeyRing() keyRing := serverKeyAPI.KeyRing()
rsAPI := roomserver.NewInternalAPI(base) rsAPI := roomserver.NewInternalAPI(base)
userAPI := userapi.NewInternalAPI(base, accountDB, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient())
keyAPI.SetUserAPI(userAPI)
eduInputAPI := eduserver.NewInternalAPI(base, cache.New(), userAPI) eduInputAPI := eduserver.NewInternalAPI(base, cache.New(), userAPI)
asQuery := appservice.NewInternalAPI( asQuery := appservice.NewInternalAPI(
base, userAPI, rsAPI, base, userAPI, rsAPI,

View file

@ -212,6 +212,8 @@ func main() {
rsAPI.SetFederationAPI(fedSenderAPI, keyRing) rsAPI.SetFederationAPI(fedSenderAPI, keyRing)
p2pPublicRoomProvider := NewLibP2PPublicRoomsProvider(node, fedSenderAPI, federation) p2pPublicRoomProvider := NewLibP2PPublicRoomsProvider(node, fedSenderAPI, federation)
psAPI := pushserver.NewInternalAPI(base)
monolith := setup.Monolith{ monolith := setup.Monolith{
Config: base.Cfg, Config: base.Cfg,
AccountDB: accountDB, AccountDB: accountDB,
@ -225,6 +227,7 @@ func main() {
RoomserverAPI: rsAPI, RoomserverAPI: rsAPI,
UserAPI: userAPI, UserAPI: userAPI,
KeyAPI: keyAPI, KeyAPI: keyAPI,
PushserverAPI: psAPI,
//ServerKeyAPI: serverKeyAPI, //ServerKeyAPI: serverKeyAPI,
ExtPublicRoomsProvider: p2pPublicRoomProvider, ExtPublicRoomsProvider: p2pPublicRoomProvider,
} }

1
go.mod
View file

@ -18,6 +18,7 @@ require (
github.com/frankban/quicktest v1.14.0 // indirect github.com/frankban/quicktest v1.14.0 // indirect
github.com/getsentry/sentry-go v0.12.0 github.com/getsentry/sentry-go v0.12.0
github.com/gologme/log v1.3.0 github.com/gologme/log v1.3.0
github.com/google/go-cmp v0.5.6
github.com/google/uuid v1.2.0 github.com/google/uuid v1.2.0
github.com/gorilla/mux v1.8.0 github.com/gorilla/mux v1.8.0
github.com/gorilla/websocket v1.4.2 github.com/gorilla/websocket v1.4.2

View file

@ -28,6 +28,28 @@ var ErrProfileNoExists = errors.New("no known profile for given user ID")
type AccountData struct { type AccountData struct {
RoomID string `json:"room_id"` RoomID string `json:"room_id"`
Type string `json:"type"` Type string `json:"type"`
ReadMarker *ReadMarkerJSON `json:"read_marker,omitempty"` // optional
}
type ReadMarkerJSON struct {
FullyRead string `json:"m.fully_read"`
Read string `json:"m.read"`
}
// NotificationData contains statistics about notifications, sent from
// the Push Server to the Sync API server.
type NotificationData struct {
// RoomID identifies the scope of the statistics, together with
// MXID (which is encoded in the Kafka key).
RoomID string `json:"room_id"`
// HighlightCount is the number of unread notifications with the
// highlight tweak.
UnreadHighlightCount int `json:"unread_highlight_count"`
// UnreadNotificationCount is the total number of unread
// notifications.
UnreadNotificationCount int `json:"unread_notification_count"`
} }
// ProfileResponse is a struct containing all known user profile data // ProfileResponse is a struct containing all known user profile data

View file

@ -0,0 +1,66 @@
package pushgateway
import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"fmt"
"net/http"
"time"
"github.com/opentracing/opentracing-go"
)
type httpClient struct {
hc *http.Client
}
// NewHTTPClient creates a new Push Gateway client.
func NewHTTPClient(disableTLSValidation bool) Client {
hc := &http.Client{
Timeout: 30 * time.Second,
Transport: &http.Transport{
DisableKeepAlives: true,
TLSClientConfig: &tls.Config{
InsecureSkipVerify: disableTLSValidation,
},
},
}
return &httpClient{hc: hc}
}
func (h *httpClient) Notify(ctx context.Context, url string, req *NotifyRequest, resp *NotifyResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "Notify")
defer span.Finish()
body, err := json.Marshal(req)
if err != nil {
return err
}
hreq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if err != nil {
return err
}
hreq.Header.Set("Content-Type", "application/json")
hresp, err := h.hc.Do(hreq)
if err != nil {
return err
}
//nolint:errcheck
defer hresp.Body.Close()
if hresp.StatusCode == http.StatusOK {
return json.NewDecoder(hresp.Body).Decode(resp)
}
var errorBody struct {
Message string `json:"message"`
}
if err := json.NewDecoder(hresp.Body).Decode(&errorBody); err == nil {
return fmt.Errorf("push gateway: %d from %s: %s", hresp.StatusCode, url, errorBody.Message)
}
return fmt.Errorf("push gateway: %d from %s", hresp.StatusCode, url)
}

View file

@ -0,0 +1,62 @@
package pushgateway
import (
"context"
"encoding/json"
"github.com/matrix-org/gomatrixserverlib"
)
// A Client is how interactions with a Push Gateway is done.
type Client interface {
// Notify sends a notification to the gateway at the given URL.
Notify(ctx context.Context, url string, req *NotifyRequest, resp *NotifyResponse) error
}
type NotifyRequest struct {
Notification Notification `json:"notification"` // Required
}
type NotifyResponse struct {
// Rejected is the list of device push keys that were rejected
// during the push. The caller should remove the push keys so they
// are not used again.
Rejected []string `json:"rejected"` // Required
}
type Notification struct {
Content json.RawMessage `json:"content,omitempty"`
Counts *Counts `json:"counts,omitempty"`
Devices []*Device `json:"devices"` // Required
EventID string `json:"event_id,omitempty"`
ID string `json:"id,omitempty"` // Deprecated name for EventID.
Membership string `json:"membership,omitempty"` // UNSPEC: required for Sytest.
Prio Prio `json:"prio,omitempty"`
RoomAlias string `json:"room_alias,omitempty"`
RoomID string `json:"room_id,omitempty"`
RoomName string `json:"room_name,omitempty"`
Sender string `json:"sender,omitempty"`
SenderDisplayName string `json:"sender_display_name,omitempty"`
Type string `json:"type,omitempty"`
UserIsTarget bool `json:"user_is_target,omitempty"`
}
type Counts struct {
MissedCalls int `json:"missed_calls,omitempty"`
Unread int `json:"unread"` // TODO: UNSPEC: the spec says zero must be omitted, but Sytest 61push/01message-pushed.pl requires it.
}
type Device struct {
AppID string `json:"app_id"` // Required
Data map[string]interface{} `json:"data"` // Required. UNSPEC: Sytests require this to allow unknown keys.
PushKey string `json:"pushkey"` // Required
PushKeyTS gomatrixserverlib.Timestamp `json:"pushkey_ts,omitempty"`
Tweaks map[string]interface{} `json:"tweaks,omitempty"`
}
type Prio string
const (
HighPrio Prio = "high"
LowPrio Prio = "low"
)

View file

@ -0,0 +1,102 @@
package pushrules
import (
"bytes"
"encoding/json"
"fmt"
)
// An Action is (part of) an outcome of a rule. There are
// (unofficially) terminal actions, and modifier actions.
type Action struct {
// Kind is the type of action. Has custom encoding in JSON.
Kind ActionKind `json:"-"`
// Tweak is the property to tweak. Has custom encoding in JSON.
Tweak TweakKey `json:"-"`
// Value is some value interpreted according to Kind and Tweak.
Value interface{} `json:"value"`
}
func (a *Action) MarshalJSON() ([]byte, error) {
if a.Tweak == UnknownTweak && a.Value == nil {
return json.Marshal(a.Kind)
}
if a.Kind != SetTweakAction {
return nil, fmt.Errorf("only set_tweak actions may have a value, but got kind %q", a.Kind)
}
m := map[string]interface{}{
string(a.Kind): a.Tweak,
}
if a.Value != nil {
m["value"] = a.Value
}
return json.Marshal(m)
}
func (a *Action) UnmarshalJSON(bs []byte) error {
if bytes.HasPrefix(bs, []byte("\"")) {
return json.Unmarshal(bs, &a.Kind)
}
var raw struct {
SetTweak TweakKey `json:"set_tweak"`
Value interface{} `json:"value"`
}
if err := json.Unmarshal(bs, &raw); err != nil {
return err
}
if raw.SetTweak == UnknownTweak {
return fmt.Errorf("got unknown action JSON: %s", string(bs))
}
a.Kind = SetTweakAction
a.Tweak = raw.SetTweak
if raw.Value != nil {
a.Value = raw.Value
}
return nil
}
// ActionKind is the primary discriminator for actions.
type ActionKind string
const (
UnknownAction ActionKind = ""
// NotifyAction indicates the clients should show a notification.
NotifyAction ActionKind = "notify"
// DontNotifyAction indicates the clients should not show a notification.
DontNotifyAction ActionKind = "dont_notify"
// CoalesceAction tells the clients to show a notification, and
// tells both servers and clients that multiple events can be
// coalesced into a single notification. The behaviour is
// implementation-specific.
CoalesceAction ActionKind = "coalesce"
// SetTweakAction uses the Tweak and Value fields to add a
// tweak. Multiple SetTweakAction can be provided in a rule,
// combined with NotifyAction or CoalesceAction.
SetTweakAction ActionKind = "set_tweak"
)
// A TweakKey describes a property to be modified/tweaked for events
// that match the rule.
type TweakKey string
const (
UnknownTweak TweakKey = ""
// SoundTweak describes which sound to play. Using "default" means
// "enable sound".
SoundTweak TweakKey = "sound"
// HighlightTweak asks the clients to highlight the conversation.
HighlightTweak TweakKey = "highlight"
)

View file

@ -0,0 +1,39 @@
package pushrules
import (
"encoding/json"
"fmt"
"testing"
"github.com/google/go-cmp/cmp"
)
func TestActionJSON(t *testing.T) {
tsts := []struct {
Want Action
}{
{Action{Kind: NotifyAction}},
{Action{Kind: DontNotifyAction}},
{Action{Kind: CoalesceAction}},
{Action{Kind: SetTweakAction}},
{Action{Kind: SetTweakAction, Tweak: SoundTweak, Value: "default"}},
{Action{Kind: SetTweakAction, Tweak: HighlightTweak}},
{Action{Kind: SetTweakAction, Tweak: HighlightTweak, Value: "false"}},
}
for _, tst := range tsts {
t.Run(fmt.Sprintf("%+v", tst.Want), func(t *testing.T) {
bs, err := json.Marshal(&tst.Want)
if err != nil {
t.Fatalf("Marshal failed: %v", err)
}
var got Action
if err := json.Unmarshal(bs, &got); err != nil {
t.Fatalf("Unmarshal failed: %v", err)
}
if diff := cmp.Diff(tst.Want, got); diff != "" {
t.Errorf("+got -want:\n%s", diff)
}
})
}
}

View file

@ -0,0 +1,49 @@
package pushrules
// A Condition dictates extra conditions for a matching rules. See
// ConditionKind.
type Condition struct {
// Kind is the primary discriminator for the condition
// type. Required.
Kind ConditionKind `json:"kind"`
// Key indicates the dot-separated path of Event fields to
// match. Required for EventMatchCondition and
// SenderNotificationPermissionCondition.
Key string `json:"key,omitempty"`
// Pattern indicates the value pattern that must match. Required
// for EventMatchCondition.
Pattern string `json:"pattern,omitempty"`
// Is indicates the condition that must be fulfilled. Required for
// RoomMemberCountCondition.
Is string `json:"is,omitempty"`
}
// ConditionKind represents a kind of condition.
//
// SPEC: Unrecognised conditions MUST NOT match any events,
// effectively making the push rule disabled.
type ConditionKind string
const (
UnknownCondition ConditionKind = ""
// EventMatchCondition indicates the condition looks for a key
// path and matches a pattern. How paths that don't reference a
// simple value match against rules is implementation-specific.
EventMatchCondition ConditionKind = "event_match"
// ContainsDisplayNameCondition indicates the current user's
// display name must be found in the content body.
ContainsDisplayNameCondition ConditionKind = "contains_display_name"
// RoomMemberCountCondition matches a simple arithmetic comparison
// against the total number of members in a room.
RoomMemberCountCondition ConditionKind = "room_member_count"
// SenderNotificationPermissionCondition compares power level for
// the sender in the event's room.
SenderNotificationPermissionCondition ConditionKind = "sender_notification_permission"
)

View file

@ -0,0 +1,23 @@
package pushrules
import (
"github.com/matrix-org/gomatrixserverlib"
)
// DefaultAccountRuleSets is the complete set of default push rules
// for an account.
func DefaultAccountRuleSets(localpart string, serverName gomatrixserverlib.ServerName) *AccountRuleSets {
return &AccountRuleSets{
Global: *DefaultGlobalRuleSet(localpart, serverName),
}
}
// DefaultGlobalRuleSet returns the default ruleset for a given (fully
// qualified) MXID.
func DefaultGlobalRuleSet(localpart string, serverName gomatrixserverlib.ServerName) *RuleSet {
return &RuleSet{
Override: defaultOverrideRules("@" + localpart + ":" + string(serverName)),
Content: defaultContentRules(localpart),
Underride: defaultUnderrideRules,
}
}

View file

@ -0,0 +1,33 @@
package pushrules
func defaultContentRules(localpart string) []*Rule {
return []*Rule{
mRuleContainsUserNameDefinition(localpart),
}
}
const (
MRuleContainsUserName = ".m.rule.contains_user_name"
)
func mRuleContainsUserNameDefinition(localpart string) *Rule {
return &Rule{
RuleID: MRuleContainsUserName,
Default: true,
Enabled: true,
Pattern: localpart,
Actions: []*Action{
{Kind: NotifyAction},
{
Kind: SetTweakAction,
Tweak: SoundTweak,
Value: "default",
},
{
Kind: SetTweakAction,
Tweak: HighlightTweak,
Value: true,
},
},
}
}

View file

@ -0,0 +1,165 @@
package pushrules
func defaultOverrideRules(userID string) []*Rule {
return []*Rule{
&mRuleMasterDefinition,
&mRuleSuppressNoticesDefinition,
mRuleInviteForMeDefinition(userID),
&mRuleMemberEventDefinition,
&mRuleContainsDisplayNameDefinition,
&mRuleTombstoneDefinition,
&mRuleRoomNotifDefinition,
}
}
const (
MRuleMaster = ".m.rule.master"
MRuleSuppressNotices = ".m.rule.suppress_notices"
MRuleInviteForMe = ".m.rule.invite_for_me"
MRuleMemberEvent = ".m.rule.member_event"
MRuleContainsDisplayName = ".m.rule.contains_display_name"
MRuleTombstone = ".m.rule.tombstone"
MRuleRoomNotif = ".m.rule.roomnotif"
)
var (
mRuleMasterDefinition = Rule{
RuleID: MRuleMaster,
Default: true,
Enabled: false,
Conditions: []*Condition{},
Actions: []*Action{{Kind: DontNotifyAction}},
}
mRuleSuppressNoticesDefinition = Rule{
RuleID: MRuleSuppressNotices,
Default: true,
Enabled: true,
Conditions: []*Condition{
{
Kind: EventMatchCondition,
Key: "content.msgtype",
Pattern: "m.notice",
},
},
Actions: []*Action{{Kind: DontNotifyAction}},
}
mRuleMemberEventDefinition = Rule{
RuleID: MRuleMemberEvent,
Default: true,
Enabled: true,
Conditions: []*Condition{
{
Kind: EventMatchCondition,
Key: "type",
Pattern: "m.room.member",
},
},
Actions: []*Action{{Kind: DontNotifyAction}},
}
mRuleContainsDisplayNameDefinition = Rule{
RuleID: MRuleContainsDisplayName,
Default: true,
Enabled: true,
Conditions: []*Condition{{Kind: ContainsDisplayNameCondition}},
Actions: []*Action{
{Kind: NotifyAction},
{
Kind: SetTweakAction,
Tweak: SoundTweak,
Value: "default",
},
{
Kind: SetTweakAction,
Tweak: HighlightTweak,
Value: true,
},
},
}
mRuleTombstoneDefinition = Rule{
RuleID: MRuleTombstone,
Default: true,
Enabled: true,
Conditions: []*Condition{
{
Kind: EventMatchCondition,
Key: "type",
Pattern: "m.room.tombstone",
},
{
Kind: EventMatchCondition,
Key: "state_key",
Pattern: "",
},
},
Actions: []*Action{
{Kind: NotifyAction},
{
Kind: SetTweakAction,
Tweak: HighlightTweak,
Value: false,
},
},
}
mRuleRoomNotifDefinition = Rule{
RuleID: MRuleRoomNotif,
Default: true,
Enabled: true,
Conditions: []*Condition{
{
Kind: EventMatchCondition,
Key: "content.body",
Pattern: "@room",
},
{
Kind: SenderNotificationPermissionCondition,
Key: "room",
},
},
Actions: []*Action{
{Kind: NotifyAction},
{
Kind: SetTweakAction,
Tweak: HighlightTweak,
Value: false,
},
},
}
)
func mRuleInviteForMeDefinition(userID string) *Rule {
return &Rule{
RuleID: MRuleInviteForMe,
Default: true,
Enabled: true,
Conditions: []*Condition{
{
Kind: EventMatchCondition,
Key: "type",
Pattern: "m.room.member",
},
{
Kind: EventMatchCondition,
Key: "content.membership",
Pattern: "invite",
},
{
Kind: EventMatchCondition,
Key: "state_key",
Pattern: userID,
},
},
Actions: []*Action{
{Kind: NotifyAction},
{
Kind: SetTweakAction,
Tweak: SoundTweak,
Value: "default",
},
{
Kind: SetTweakAction,
Tweak: HighlightTweak,
Value: false,
},
},
}
}

View file

@ -0,0 +1,119 @@
package pushrules
const (
MRuleCall = ".m.rule.call"
MRuleEncryptedRoomOneToOne = ".m.rule.encrypted_room_one_to_one"
MRuleRoomOneToOne = ".m.rule.room_one_to_one"
MRuleMessage = ".m.rule.message"
MRuleEncrypted = ".m.rule.encrypted"
)
var defaultUnderrideRules = []*Rule{
&mRuleCallDefinition,
&mRuleEncryptedRoomOneToOneDefinition,
&mRuleRoomOneToOneDefinition,
&mRuleMessageDefinition,
&mRuleEncryptedDefinition,
}
var (
mRuleCallDefinition = Rule{
RuleID: MRuleCall,
Default: true,
Enabled: true,
Conditions: []*Condition{
{
Kind: EventMatchCondition,
Key: "type",
Pattern: "m.call.invite",
},
},
Actions: []*Action{
{Kind: NotifyAction},
{
Kind: SetTweakAction,
Tweak: SoundTweak,
Value: "ring",
},
{
Kind: SetTweakAction,
Tweak: HighlightTweak,
Value: false,
},
},
}
mRuleEncryptedRoomOneToOneDefinition = Rule{
RuleID: MRuleEncryptedRoomOneToOne,
Default: true,
Enabled: true,
Conditions: []*Condition{
{
Kind: RoomMemberCountCondition,
Is: "2",
},
{
Kind: EventMatchCondition,
Key: "type",
Pattern: "m.room.encrypted",
},
},
Actions: []*Action{
{Kind: NotifyAction},
{
Kind: SetTweakAction,
Tweak: HighlightTweak,
Value: false,
},
},
}
mRuleRoomOneToOneDefinition = Rule{
RuleID: MRuleRoomOneToOne,
Default: true,
Enabled: true,
Conditions: []*Condition{
{
Kind: RoomMemberCountCondition,
Is: "2",
},
{
Kind: EventMatchCondition,
Key: "type",
Pattern: "m.room.message",
},
},
Actions: []*Action{
{Kind: NotifyAction},
{
Kind: SetTweakAction,
Tweak: HighlightTweak,
Value: false,
},
},
}
mRuleMessageDefinition = Rule{
RuleID: MRuleMessage,
Default: true,
Enabled: true,
Conditions: []*Condition{
{
Kind: EventMatchCondition,
Key: "type",
Pattern: "m.room.message",
},
},
Actions: []*Action{{Kind: NotifyAction}},
}
mRuleEncryptedDefinition = Rule{
RuleID: MRuleEncrypted,
Default: true,
Enabled: true,
Conditions: []*Condition{
{
Kind: EventMatchCondition,
Key: "type",
Pattern: "m.room.encrypted",
},
},
Actions: []*Action{{Kind: NotifyAction}},
}
)

View file

@ -0,0 +1,165 @@
package pushrules
import (
"encoding/json"
"fmt"
"strings"
"github.com/matrix-org/gomatrixserverlib"
)
// A RuleSetEvaluator encapsulates context to evaluate an event
// against a rule set.
type RuleSetEvaluator struct {
ec EvaluationContext
ruleSet []kindAndRules
}
// An EvaluationContext gives a RuleSetEvaluator access to the
// environment, for rules that require that.
type EvaluationContext interface {
// UserDisplayName returns the current user's display name.
UserDisplayName() string
// RoomMemberCount returns the number of members in the room of
// the current event.
RoomMemberCount() (int, error)
// HasPowerLevel returns whether the user has at least the given
// power in the room of the current event.
HasPowerLevel(userID, levelKey string) (bool, error)
}
// A kindAndRules is just here to simplify iteration of the (ordered)
// kinds of rules.
type kindAndRules struct {
Kind Kind
Rules []*Rule
}
// NewRuleSetEvaluator creates a new evaluator for the given rule set.
func NewRuleSetEvaluator(ec EvaluationContext, ruleSet *RuleSet) *RuleSetEvaluator {
return &RuleSetEvaluator{
ec: ec,
ruleSet: []kindAndRules{
{OverrideKind, ruleSet.Override},
{ContentKind, ruleSet.Content},
{RoomKind, ruleSet.Room},
{SenderKind, ruleSet.Sender},
{UnderrideKind, ruleSet.Underride},
},
}
}
// MatchEvent returns the first matching rule. Returns nil if there
// was no match rule.
func (rse *RuleSetEvaluator) MatchEvent(event *gomatrixserverlib.Event) (*Rule, error) {
// TODO: server-default rules have lower priority than user rules,
// but they are stored together with the user rules. It's a bit
// unclear what the specification (11.14.1.4 Predefined rules)
// means the ordering should be.
//
// The most reasonable interpretation is that default overrides
// still have lower priority than user content rules, so we
// iterate twice.
for _, rsat := range rse.ruleSet {
for _, defRules := range []bool{false, true} {
for _, rule := range rsat.Rules {
if rule.Default != defRules {
continue
}
ok, err := ruleMatches(rule, rsat.Kind, event, rse.ec)
if err != nil {
return nil, err
}
if ok {
return rule, nil
}
}
}
}
// No matching rule.
return nil, nil
}
func ruleMatches(rule *Rule, kind Kind, event *gomatrixserverlib.Event, ec EvaluationContext) (bool, error) {
if !rule.Enabled {
return false, nil
}
switch kind {
case OverrideKind, UnderrideKind:
for _, cond := range rule.Conditions {
ok, err := conditionMatches(cond, event, ec)
if err != nil {
return false, err
}
if !ok {
return false, nil
}
}
return true, nil
case ContentKind:
// TODO: "These configure behaviour for (unencrypted) messages
// that match certain patterns." - Does that mean "content.body"?
return patternMatches("content.body", rule.Pattern, event)
case RoomKind:
return rule.RuleID == event.RoomID(), nil
case SenderKind:
return rule.RuleID == event.Sender(), nil
default:
return false, nil
}
}
func conditionMatches(cond *Condition, event *gomatrixserverlib.Event, ec EvaluationContext) (bool, error) {
switch cond.Kind {
case EventMatchCondition:
return patternMatches(cond.Key, cond.Pattern, event)
case ContainsDisplayNameCondition:
return patternMatches("content.body", ec.UserDisplayName(), event)
case RoomMemberCountCondition:
cmp, err := parseRoomMemberCountCondition(cond.Is)
if err != nil {
return false, fmt.Errorf("parsing room_member_count condition: %w", err)
}
n, err := ec.RoomMemberCount()
if err != nil {
return false, fmt.Errorf("RoomMemberCount failed: %w", err)
}
return cmp(n), nil
case SenderNotificationPermissionCondition:
return ec.HasPowerLevel(event.Sender(), cond.Key)
default:
return false, nil
}
}
func patternMatches(key, pattern string, event *gomatrixserverlib.Event) (bool, error) {
re, err := globToRegexp(pattern)
if err != nil {
return false, err
}
var eventMap map[string]interface{}
if err = json.Unmarshal(event.JSON(), &eventMap); err != nil {
return false, fmt.Errorf("parsing event: %w", err)
}
v, err := lookupMapPath(strings.Split(key, "."), eventMap)
if err != nil {
// An unknown path is a benign error that shouldn't stop rule
// processing. It's just a non-match.
return false, nil
}
return re.MatchString(fmt.Sprint(v)), nil
}

View file

@ -0,0 +1,189 @@
package pushrules
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/matrix-org/gomatrixserverlib"
)
func TestRuleSetEvaluatorMatchEvent(t *testing.T) {
ev := mustEventFromJSON(t, `{}`)
defaultEnabled := &Rule{
RuleID: ".default.enabled",
Default: true,
Enabled: true,
}
userEnabled := &Rule{
RuleID: ".user.enabled",
Default: false,
Enabled: true,
}
userEnabled2 := &Rule{
RuleID: ".user.enabled.2",
Default: false,
Enabled: true,
}
tsts := []struct {
Name string
RuleSet RuleSet
Want *Rule
}{
{"empty", RuleSet{}, nil},
{"defaultCanWin", RuleSet{Override: []*Rule{defaultEnabled}}, defaultEnabled},
{"userWins", RuleSet{Override: []*Rule{defaultEnabled, userEnabled}}, userEnabled},
{"defaultOverrideWins", RuleSet{Override: []*Rule{defaultEnabled}, Underride: []*Rule{userEnabled}}, defaultEnabled},
{"overrideContent", RuleSet{Override: []*Rule{userEnabled}, Content: []*Rule{userEnabled2}}, userEnabled},
{"overrideRoom", RuleSet{Override: []*Rule{userEnabled}, Room: []*Rule{userEnabled2}}, userEnabled},
{"overrideSender", RuleSet{Override: []*Rule{userEnabled}, Sender: []*Rule{userEnabled2}}, userEnabled},
{"overrideUnderride", RuleSet{Override: []*Rule{userEnabled}, Underride: []*Rule{userEnabled2}}, userEnabled},
}
for _, tst := range tsts {
t.Run(tst.Name, func(t *testing.T) {
rse := NewRuleSetEvaluator(nil, &tst.RuleSet)
got, err := rse.MatchEvent(ev)
if err != nil {
t.Fatalf("MatchEvent failed: %v", err)
}
if diff := cmp.Diff(tst.Want, got); diff != "" {
t.Errorf("MatchEvent rule: +got -want:\n%s", diff)
}
})
}
}
func TestRuleMatches(t *testing.T) {
emptyRule := Rule{Enabled: true}
tsts := []struct {
Name string
Kind Kind
Rule Rule
EventJSON string
Want bool
}{
{"emptyOverride", OverrideKind, emptyRule, `{}`, true},
{"emptyContent", ContentKind, emptyRule, `{}`, false},
{"emptyRoom", RoomKind, emptyRule, `{}`, true},
{"emptySender", SenderKind, emptyRule, `{}`, true},
{"emptyUnderride", UnderrideKind, emptyRule, `{}`, true},
{"disabled", OverrideKind, Rule{}, `{}`, false},
{"overrideConditionMatch", OverrideKind, Rule{Enabled: true}, `{}`, true},
{"overrideConditionNoMatch", OverrideKind, Rule{Enabled: true, Conditions: []*Condition{{}}}, `{}`, false},
{"underrideConditionMatch", UnderrideKind, Rule{Enabled: true}, `{}`, true},
{"underrideConditionNoMatch", UnderrideKind, Rule{Enabled: true, Conditions: []*Condition{{}}}, `{}`, false},
{"contentMatch", ContentKind, Rule{Enabled: true, Pattern: "b"}, `{"content":{"body":"abc"}}`, true},
{"contentNoMatch", ContentKind, Rule{Enabled: true, Pattern: "d"}, `{"content":{"body":"abc"}}`, false},
{"roomMatch", RoomKind, Rule{Enabled: true, RuleID: "!room@example.com"}, `{"room_id":"!room@example.com"}`, true},
{"roomNoMatch", RoomKind, Rule{Enabled: true, RuleID: "!room@example.com"}, `{"room_id":"!otherroom@example.com"}`, false},
{"senderMatch", SenderKind, Rule{Enabled: true, RuleID: "@user@example.com"}, `{"sender":"@user@example.com"}`, true},
{"senderNoMatch", SenderKind, Rule{Enabled: true, RuleID: "@user@example.com"}, `{"sender":"@otheruser@example.com"}`, false},
}
for _, tst := range tsts {
t.Run(tst.Name, func(t *testing.T) {
got, err := ruleMatches(&tst.Rule, tst.Kind, mustEventFromJSON(t, tst.EventJSON), nil)
if err != nil {
t.Fatalf("ruleMatches failed: %v", err)
}
if got != tst.Want {
t.Errorf("ruleMatches: got %v, want %v", got, tst.Want)
}
})
}
}
func TestConditionMatches(t *testing.T) {
tsts := []struct {
Name string
Cond Condition
EventJSON string
Want bool
}{
{"empty", Condition{}, `{}`, false},
{"empty", Condition{Kind: "unknownstring"}, `{}`, false},
{"eventMatch", Condition{Kind: EventMatchCondition, Key: "content"}, `{"content":{}}`, true},
{"displayNameNoMatch", Condition{Kind: ContainsDisplayNameCondition}, `{"content":{"body":"something without displayname"}}`, false},
{"displayNameMatch", Condition{Kind: ContainsDisplayNameCondition}, `{"content":{"body":"hello Dear User, how are you?"}}`, true},
{"roomMemberCountLessNoMatch", Condition{Kind: RoomMemberCountCondition, Is: "<2"}, `{}`, false},
{"roomMemberCountLessMatch", Condition{Kind: RoomMemberCountCondition, Is: "<3"}, `{}`, true},
{"roomMemberCountLessEqualNoMatch", Condition{Kind: RoomMemberCountCondition, Is: "<=1"}, `{}`, false},
{"roomMemberCountLessEqualMatch", Condition{Kind: RoomMemberCountCondition, Is: "<=2"}, `{}`, true},
{"roomMemberCountEqualNoMatch", Condition{Kind: RoomMemberCountCondition, Is: "==1"}, `{}`, false},
{"roomMemberCountEqualMatch", Condition{Kind: RoomMemberCountCondition, Is: "==2"}, `{}`, true},
{"roomMemberCountGreaterEqualNoMatch", Condition{Kind: RoomMemberCountCondition, Is: ">=3"}, `{}`, false},
{"roomMemberCountGreaterEqualMatch", Condition{Kind: RoomMemberCountCondition, Is: ">=2"}, `{}`, true},
{"roomMemberCountGreaterNoMatch", Condition{Kind: RoomMemberCountCondition, Is: ">2"}, `{}`, false},
{"roomMemberCountGreaterMatch", Condition{Kind: RoomMemberCountCondition, Is: ">1"}, `{}`, true},
{"senderNotificationPermissionMatch", Condition{Kind: SenderNotificationPermissionCondition, Key: "powerlevel"}, `{"sender":"@poweruser:example.com"}`, true},
{"senderNotificationPermissionNoMatch", Condition{Kind: SenderNotificationPermissionCondition, Key: "powerlevel"}, `{"sender":"@nobody:example.com"}`, false},
}
for _, tst := range tsts {
t.Run(tst.Name, func(t *testing.T) {
got, err := conditionMatches(&tst.Cond, mustEventFromJSON(t, tst.EventJSON), &fakeEvaluationContext{})
if err != nil {
t.Fatalf("conditionMatches failed: %v", err)
}
if got != tst.Want {
t.Errorf("conditionMatches: got %v, want %v", got, tst.Want)
}
})
}
}
type fakeEvaluationContext struct{}
func (fakeEvaluationContext) UserDisplayName() string { return "Dear User" }
func (fakeEvaluationContext) RoomMemberCount() (int, error) { return 2, nil }
func (fakeEvaluationContext) HasPowerLevel(userID, levelKey string) (bool, error) {
return userID == "@poweruser:example.com" && levelKey == "powerlevel", nil
}
func TestPatternMatches(t *testing.T) {
tsts := []struct {
Name string
Key string
Pattern string
EventJSON string
Want bool
}{
{"empty", "", "", `{}`, false},
// Note that an empty pattern contains no wildcard characters,
// which implicitly means "*".
{"patternEmpty", "content", "", `{"content":{}}`, true},
{"literal", "content.creator", "acreator", `{"content":{"creator":"acreator"}}`, true},
{"substring", "content.creator", "reat", `{"content":{"creator":"acreator"}}`, true},
{"singlePattern", "content.creator", "acr?ator", `{"content":{"creator":"acreator"}}`, true},
{"multiPattern", "content.creator", "a*ea*r", `{"content":{"creator":"acreator"}}`, true},
{"patternNoSubstring", "content.creator", "r*t", `{"content":{"creator":"acreator"}}`, false},
}
for _, tst := range tsts {
t.Run(tst.Name, func(t *testing.T) {
got, err := patternMatches(tst.Key, tst.Pattern, mustEventFromJSON(t, tst.EventJSON))
if err != nil {
t.Fatalf("patternMatches failed: %v", err)
}
if got != tst.Want {
t.Errorf("patternMatches: got %v, want %v", got, tst.Want)
}
})
}
}
func mustEventFromJSON(t *testing.T, json string) *gomatrixserverlib.Event {
ev, err := gomatrixserverlib.NewEventFromTrustedJSON([]byte(json), false, gomatrixserverlib.RoomVersionV7)
if err != nil {
t.Fatal(err)
}
return ev
}

View file

@ -0,0 +1,71 @@
package pushrules
// An AccountRuleSets carries the rule sets associated with an
// account.
type AccountRuleSets struct {
Global RuleSet `json:"global"` // Required
}
// A RuleSet contains all the various push rules for an
// account. Listed in decreasing order of priority.
type RuleSet struct {
Override []*Rule `json:"override,omitempty"`
Content []*Rule `json:"content,omitempty"`
Room []*Rule `json:"room,omitempty"`
Sender []*Rule `json:"sender,omitempty"`
Underride []*Rule `json:"underride,omitempty"`
}
// A Rule contains matchers, conditions and final actions. While
// evaluating, at most one rule is considered matching.
//
// Kind and scope are part of the push rules request/responses, but
// not of the core data model.
type Rule struct {
// RuleID is either a free identifier, or the sender's MXID for
// SenderKind. Required.
RuleID string `json:"rule_id"`
// Default indicates whether this is a server-defined default, or
// a user-provided rule. Required.
//
// The server-default rules have the lowest priority.
Default bool `json:"default"`
// Enabled allows the user to disable rules while keeping them
// around. Required.
Enabled bool `json:"enabled"`
// Actions describe the desired outcome, should the rule
// match. Required.
Actions []*Action `json:"actions"`
// Conditions provide the rule's conditions for OverrideKind and
// UnderrideKind. Not allowed for other kinds.
Conditions []*Condition `json:"conditions"`
// Pattern is the body pattern to match for ContentKind. Required
// for that kind. The interpretation is the same as that of
// Condition.Pattern.
Pattern string `json:"pattern"`
}
// Scope only has one valid value. See also AccountRuleSets.
type Scope string
const (
UnknownScope Scope = ""
GlobalScope Scope = "global"
)
// Kind is the type of push rule. See also RuleSet.
type Kind string
const (
UnknownKind Kind = ""
OverrideKind Kind = "override"
ContentKind Kind = "content"
RoomKind Kind = "room"
SenderKind Kind = "sender"
UnderrideKind Kind = "underride"
)

125
internal/pushrules/util.go Normal file
View file

@ -0,0 +1,125 @@
package pushrules
import (
"fmt"
"regexp"
"strconv"
"strings"
)
// ActionsToTweaks converts a list of actions into a primary action
// kind and a tweaks map. Returns a nil map if it would have been
// empty.
func ActionsToTweaks(as []*Action) (ActionKind, map[string]interface{}, error) {
var kind ActionKind
tweaks := map[string]interface{}{}
for _, a := range as {
if a.Kind == SetTweakAction {
tweaks[string(a.Tweak)] = a.Value
continue
}
if kind != UnknownAction {
return UnknownAction, nil, fmt.Errorf("got multiple primary actions: already had %q, got %s", kind, a.Kind)
}
kind = a.Kind
}
if len(tweaks) == 0 {
tweaks = nil
}
return kind, tweaks, nil
}
// BoolTweakOr returns the named tweak as a boolean, and returns `def`
// on failure.
func BoolTweakOr(tweaks map[string]interface{}, key TweakKey, def bool) bool {
v, ok := tweaks[string(key)]
if !ok {
return def
}
b, ok := v.(bool)
if !ok {
return def
}
return b
}
// globToRegexp converts a Matrix glob-style pattern to a Regular expression.
func globToRegexp(pattern string) (*regexp.Regexp, error) {
// TODO: It's unclear which glob characters are supported. The only
// place this is discussed is for the unrelated "m.policy.rule.*"
// events. Assuming, the same: /[*?]/
if !strings.ContainsAny(pattern, "*?") {
pattern = "*" + pattern + "*"
}
// The defined syntax doesn't allow escaping the glob wildcard
// characters, which makes this a straight-forward
// replace-after-quote.
pattern = globNonMetaRegexp.ReplaceAllStringFunc(pattern, regexp.QuoteMeta)
pattern = strings.Replace(pattern, "*", ".*", -1)
pattern = strings.Replace(pattern, "?", ".", -1)
return regexp.Compile("^(" + pattern + ")$")
}
// globNonMetaRegexp are the characters that are not considered glob
// meta-characters (i.e. may need escaping).
var globNonMetaRegexp = regexp.MustCompile("[^*?]+")
// lookupMapPath traverses a hierarchical map structure, like the one
// produced by json.Unmarshal, to return the leaf value. Traversing
// arrays/slices is not supported, only objects/maps.
func lookupMapPath(path []string, m map[string]interface{}) (interface{}, error) {
if len(path) == 0 {
return nil, fmt.Errorf("empty path")
}
var v interface{} = m
for i, key := range path {
m, ok := v.(map[string]interface{})
if !ok {
return nil, fmt.Errorf("expected an object for path %q, but got %T", strings.Join(path[:i+1], "."), v)
}
v, ok = m[key]
if !ok {
return nil, fmt.Errorf("path not found: %s", strings.Join(path[:i+1], "."))
}
}
return v, nil
}
// parseRoomMemberCountCondition parses a string like "2", "==2", "<2"
// into a function that checks if the argument to it fulfils the
// condition.
func parseRoomMemberCountCondition(s string) (func(int) bool, error) {
var b int
var cmp = func(a int) bool { return a == b }
switch {
case strings.HasPrefix(s, "<="):
cmp = func(a int) bool { return a <= b }
s = s[2:]
case strings.HasPrefix(s, ">="):
cmp = func(a int) bool { return a >= b }
s = s[2:]
case strings.HasPrefix(s, "<"):
cmp = func(a int) bool { return a < b }
s = s[1:]
case strings.HasPrefix(s, ">"):
cmp = func(a int) bool { return a > b }
s = s[1:]
case strings.HasPrefix(s, "=="):
// Same cmp as the default.
s = s[2:]
}
v, err := strconv.ParseInt(s, 10, 64)
if err != nil {
return nil, err
}
b = int(v)
return cmp, nil
}

View file

@ -0,0 +1,169 @@
package pushrules
import (
"fmt"
"testing"
"github.com/google/go-cmp/cmp"
)
func TestActionsToTweaks(t *testing.T) {
tsts := []struct {
Name string
Input []*Action
WantKind ActionKind
WantTweaks map[string]interface{}
}{
{"empty", nil, UnknownAction, nil},
{"zero", []*Action{{}}, UnknownAction, nil},
{"onlyPrimary", []*Action{{Kind: NotifyAction}}, NotifyAction, nil},
{"onlyTweak", []*Action{{Kind: SetTweakAction, Tweak: HighlightTweak}}, UnknownAction, map[string]interface{}{"highlight": nil}},
{"onlyTweakWithValue", []*Action{{Kind: SetTweakAction, Tweak: SoundTweak, Value: "default"}}, UnknownAction, map[string]interface{}{"sound": "default"}},
{
"all",
[]*Action{
{Kind: CoalesceAction},
{Kind: SetTweakAction, Tweak: HighlightTweak},
{Kind: SetTweakAction, Tweak: SoundTweak, Value: "default"},
},
CoalesceAction,
map[string]interface{}{"highlight": nil, "sound": "default"},
},
}
for _, tst := range tsts {
t.Run(tst.Name, func(t *testing.T) {
gotKind, gotTweaks, err := ActionsToTweaks(tst.Input)
if err != nil {
t.Fatalf("ActionsToTweaks failed: %v", err)
}
if gotKind != tst.WantKind {
t.Errorf("kind: got %v, want %v", gotKind, tst.WantKind)
}
if diff := cmp.Diff(tst.WantTweaks, gotTweaks); diff != "" {
t.Errorf("tweaks: +got -want:\n%s", diff)
}
})
}
}
func TestBoolTweakOr(t *testing.T) {
tsts := []struct {
Name string
Input map[string]interface{}
Def bool
Want bool
}{
{"nil", nil, false, false},
{"nilValue", map[string]interface{}{"highlight": nil}, true, true},
{"false", map[string]interface{}{"highlight": false}, true, false},
{"true", map[string]interface{}{"highlight": true}, false, true},
}
for _, tst := range tsts {
t.Run(tst.Name, func(t *testing.T) {
got := BoolTweakOr(tst.Input, HighlightTweak, tst.Def)
if got != tst.Want {
t.Errorf("BoolTweakOr: got %v, want %v", got, tst.Want)
}
})
}
}
func TestGlobToRegexp(t *testing.T) {
tsts := []struct {
Input string
Want string
}{
{"", "^(.*.*)$"},
{"a", "^(.*a.*)$"},
{"a.b", "^(.*a\\.b.*)$"},
{"a?b", "^(a.b)$"},
{"a*b*", "^(a.*b.*)$"},
{"a*b?", "^(a.*b.)$"},
}
for _, tst := range tsts {
t.Run(tst.Want, func(t *testing.T) {
got, err := globToRegexp(tst.Input)
if err != nil {
t.Fatalf("globToRegexp failed: %v", err)
}
if got.String() != tst.Want {
t.Errorf("got %v, want %v", got.String(), tst.Want)
}
})
}
}
func TestLookupMapPath(t *testing.T) {
tsts := []struct {
Path []string
Root map[string]interface{}
Want interface{}
}{
{[]string{"a"}, map[string]interface{}{"a": "b"}, "b"},
{[]string{"a"}, map[string]interface{}{"a": 42}, 42},
{[]string{"a", "b"}, map[string]interface{}{"a": map[string]interface{}{"b": "c"}}, "c"},
}
for _, tst := range tsts {
t.Run(fmt.Sprint(tst.Path, "/", tst.Want), func(t *testing.T) {
got, err := lookupMapPath(tst.Path, tst.Root)
if err != nil {
t.Fatalf("lookupMapPath failed: %v", err)
}
if diff := cmp.Diff(tst.Want, got); diff != "" {
t.Errorf("+got -want:\n%s", diff)
}
})
}
}
func TestLookupMapPathInvalid(t *testing.T) {
tsts := []struct {
Path []string
Root map[string]interface{}
}{
{nil, nil},
{[]string{"a"}, nil},
{[]string{"a", "b"}, map[string]interface{}{"a": "c"}},
}
for _, tst := range tsts {
t.Run(fmt.Sprint(tst.Path), func(t *testing.T) {
got, err := lookupMapPath(tst.Path, tst.Root)
if err == nil {
t.Fatalf("lookupMapPath succeeded with %#v, but want failure", got)
}
})
}
}
func TestParseRoomMemberCountCondition(t *testing.T) {
tsts := []struct {
Input string
WantTrue []int
WantFalse []int
}{
{"1", []int{1}, []int{0, 2}},
{"==1", []int{1}, []int{0, 2}},
{"<1", []int{0}, []int{1, 2}},
{"<=1", []int{0, 1}, []int{2}},
{">1", []int{2}, []int{0, 1}},
{">=42", []int{42, 43}, []int{41}},
}
for _, tst := range tsts {
t.Run(tst.Input, func(t *testing.T) {
got, err := parseRoomMemberCountCondition(tst.Input)
if err != nil {
t.Fatalf("parseRoomMemberCountCondition failed: %v", err)
}
for _, v := range tst.WantTrue {
if !got(v) {
t.Errorf("parseRoomMemberCountCondition(%q)(%d): got false, want true", tst.Input, v)
}
}
for _, v := range tst.WantFalse {
if got(v) {
t.Errorf("parseRoomMemberCountCondition(%q)(%d): got true, want false", tst.Input, v)
}
}
})
}
}

View file

@ -0,0 +1,85 @@
package pushrules
import (
"fmt"
"regexp"
)
// ValidateRule checks the rule for errors. These follow from Sytests
// and the specification.
func ValidateRule(kind Kind, rule *Rule) []error {
var errs []error
if !validRuleIDRE.MatchString(rule.RuleID) {
errs = append(errs, fmt.Errorf("invalid rule ID: %s", rule.RuleID))
}
if len(rule.Actions) == 0 {
errs = append(errs, fmt.Errorf("missing actions"))
}
for _, action := range rule.Actions {
errs = append(errs, validateAction(action)...)
}
for _, cond := range rule.Conditions {
errs = append(errs, validateCondition(cond)...)
}
switch kind {
case OverrideKind, UnderrideKind:
// The empty list is allowed, but for JSON-encoding reasons,
// it must not be nil.
if rule.Conditions == nil {
errs = append(errs, fmt.Errorf("missing rule conditions"))
}
case ContentKind:
if rule.Pattern == "" {
errs = append(errs, fmt.Errorf("missing content rule pattern"))
}
case RoomKind, SenderKind:
// Do nothing.
default:
errs = append(errs, fmt.Errorf("invalid rule kind: %s", kind))
}
return errs
}
// validRuleIDRE is a regexp for valid IDs.
//
// TODO: the specification doesn't seem to say what the rule ID syntax
// is. A Sytest fails if it contains a backslash.
var validRuleIDRE = regexp.MustCompile(`^([^\\]+)$`)
// validateAction returns issues with an Action.
func validateAction(action *Action) []error {
var errs []error
switch action.Kind {
case NotifyAction, DontNotifyAction, CoalesceAction, SetTweakAction:
// Do nothing.
default:
errs = append(errs, fmt.Errorf("invalid rule action kind: %s", action.Kind))
}
return errs
}
// validateCondition returns issues with a Condition.
func validateCondition(cond *Condition) []error {
var errs []error
switch cond.Kind {
case EventMatchCondition, ContainsDisplayNameCondition, RoomMemberCountCondition, SenderNotificationPermissionCondition:
// Do nothing.
default:
errs = append(errs, fmt.Errorf("invalid rule condition kind: %s", cond.Kind))
}
return errs
}

View file

@ -0,0 +1,163 @@
package pushrules
import (
"strings"
"testing"
)
func TestValidateRuleNegatives(t *testing.T) {
tsts := []struct {
Name string
Kind Kind
Rule Rule
WantErrString string
}{
{"emptyRuleID", OverrideKind, Rule{}, "invalid rule ID"},
{"invalidKind", Kind("something else"), Rule{}, "invalid rule kind"},
{"ruleIDBackslash", OverrideKind, Rule{RuleID: "#foo\\:example.com"}, "invalid rule ID"},
{"noActions", OverrideKind, Rule{}, "missing actions"},
{"invalidAction", OverrideKind, Rule{Actions: []*Action{{}}}, "invalid rule action kind"},
{"invalidCondition", OverrideKind, Rule{Conditions: []*Condition{{}}}, "invalid rule condition kind"},
{"overrideNoCondition", OverrideKind, Rule{}, "missing rule conditions"},
{"underrideNoCondition", UnderrideKind, Rule{}, "missing rule conditions"},
{"contentNoPattern", ContentKind, Rule{}, "missing content rule pattern"},
}
for _, tst := range tsts {
t.Run(tst.Name, func(t *testing.T) {
errs := ValidateRule(tst.Kind, &tst.Rule)
var foundErr error
for _, err := range errs {
t.Logf("Got error %#v", err)
if strings.Contains(err.Error(), tst.WantErrString) {
foundErr = err
}
}
if foundErr == nil {
t.Errorf("errs: got %#v, want containing %q", errs, tst.WantErrString)
}
})
}
}
func TestValidateRulePositives(t *testing.T) {
tsts := []struct {
Name string
Kind Kind
Rule Rule
WantNoErrString string
}{
{"invalidKind", OverrideKind, Rule{}, "invalid rule kind"},
{"invalidActionNoActions", OverrideKind, Rule{}, "invalid rule action kind"},
{"invalidConditionNoConditions", OverrideKind, Rule{}, "invalid rule condition kind"},
{"contentNoCondition", ContentKind, Rule{}, "missing rule conditions"},
{"roomNoCondition", RoomKind, Rule{}, "missing rule conditions"},
{"senderNoCondition", SenderKind, Rule{}, "missing rule conditions"},
{"overrideNoPattern", OverrideKind, Rule{}, "missing content rule pattern"},
{"overrideEmptyConditions", OverrideKind, Rule{Conditions: []*Condition{}}, "missing rule conditions"},
}
for _, tst := range tsts {
t.Run(tst.Name, func(t *testing.T) {
errs := ValidateRule(tst.Kind, &tst.Rule)
for _, err := range errs {
t.Logf("Got error %#v", err)
if strings.Contains(err.Error(), tst.WantNoErrString) {
t.Errorf("errs: got %#v, want none containing %q", errs, tst.WantNoErrString)
}
}
})
}
}
func TestValidateActionNegatives(t *testing.T) {
tsts := []struct {
Name string
Action Action
WantErrString string
}{
{"emptyKind", Action{}, "invalid rule action kind"},
{"invalidKind", Action{Kind: ActionKind("something else")}, "invalid rule action kind"},
}
for _, tst := range tsts {
t.Run(tst.Name, func(t *testing.T) {
errs := validateAction(&tst.Action)
var foundErr error
for _, err := range errs {
t.Logf("Got error %#v", err)
if strings.Contains(err.Error(), tst.WantErrString) {
foundErr = err
}
}
if foundErr == nil {
t.Errorf("errs: got %#v, want containing %q", errs, tst.WantErrString)
}
})
}
}
func TestValidateActionPositives(t *testing.T) {
tsts := []struct {
Name string
Action Action
WantNoErrString string
}{
{"invalidKind", Action{Kind: NotifyAction}, "invalid rule action kind"},
}
for _, tst := range tsts {
t.Run(tst.Name, func(t *testing.T) {
errs := validateAction(&tst.Action)
for _, err := range errs {
t.Logf("Got error %#v", err)
if strings.Contains(err.Error(), tst.WantNoErrString) {
t.Errorf("errs: got %#v, want none containing %q", errs, tst.WantNoErrString)
}
}
})
}
}
func TestValidateConditionNegatives(t *testing.T) {
tsts := []struct {
Name string
Condition Condition
WantErrString string
}{
{"emptyKind", Condition{}, "invalid rule condition kind"},
{"invalidKind", Condition{Kind: ConditionKind("something else")}, "invalid rule condition kind"},
}
for _, tst := range tsts {
t.Run(tst.Name, func(t *testing.T) {
errs := validateCondition(&tst.Condition)
var foundErr error
for _, err := range errs {
t.Logf("Got error %#v", err)
if strings.Contains(err.Error(), tst.WantErrString) {
foundErr = err
}
}
if foundErr == nil {
t.Errorf("errs: got %#v, want containing %q", errs, tst.WantErrString)
}
})
}
}
func TestValidateConditionPositives(t *testing.T) {
tsts := []struct {
Name string
Condition Condition
WantNoErrString string
}{
{"invalidKind", Condition{Kind: EventMatchCondition}, "invalid rule condition kind"},
}
for _, tst := range tsts {
t.Run(tst.Name, func(t *testing.T) {
errs := validateCondition(&tst.Condition)
for _, err := range errs {
t.Logf("Got error %#v", err)
if strings.Contains(err.Error(), tst.WantNoErrString) {
t.Errorf("errs: got %#v, want none containing %q", errs, tst.WantNoErrString)
}
}
})
}
}

View file

@ -163,6 +163,7 @@ type StatementList []struct {
func (s StatementList) Prepare(db *sql.DB) (err error) { func (s StatementList) Prepare(db *sql.DB) (err error) {
for _, statement := range s { for _, statement := range s {
if *statement.Statement, err = db.Prepare(statement.SQL); err != nil { if *statement.Statement, err = db.Prepare(statement.SQL); err != nil {
err = fmt.Errorf("Error %q while preparing statement: %s", err, statement.SQL)
return return
} }
} }

BIN
q.sqlite Normal file

Binary file not shown.

63
run-sytest.sh Executable file
View file

@ -0,0 +1,63 @@
#!/bin/bash
#
# Runs SyTest either from Docker Hub, or from ../sytest. If it's run
# locally, the Docker image is rebuilt first.
#
# Logs are stored in ../sytestout/logs.
set -e
set -o pipefail
main() {
local tag=buster
local base_image=debian:$tag
local runargs=()
cd "$(dirname "$0")"
if [ -d ../sytest ]; then
local tmpdir
tmpdir="$(mktemp -d --tmpdir run-systest.XXXXXXXXXX)"
trap "rm -r '$tmpdir'" EXIT
if [ -z "$DISABLE_BUILDING_SYTEST" ]; then
echo "Re-building ../sytest Docker images..."
local status
(
cd ../sytest
docker build -f docker/base.Dockerfile --build-arg BASE_IMAGE="$base_image" --tag matrixdotorg/sytest:"$tag" .
docker build -f docker/dendrite.Dockerfile --build-arg SYTEST_IMAGE_TAG="$tag" --tag matrixdotorg/sytest-dendrite:latest .
) &>"$tmpdir/buildlog" || status=$?
if (( status != 0 )); then
# Docker is very verbose, and we don't really care about
# building SyTest. So we accumulate and only output on
# failure.
cat "$tmpdir/buildlog" >&2
return $status
fi
fi
runargs+=( -v "$PWD/../sytest:/sytest:ro" )
fi
if [ -n "$SYTEST_POSTGRES" ]; then
runargs+=( -e POSTGRES=1 )
fi
local sytestout=$PWD/../sytestout
mkdir -p "$sytestout"/{logs,cache/go-build,cache/go-pkg}
docker run \
--rm \
--name "sytest-dendrite-${LOGNAME}" \
-e LOGS_USER=$(id -u) \
-e LOGS_GROUP=$(id -g) \
-v "$PWD:/src/:ro" \
-v "$sytestout/logs:/logs/" \
-v "$sytestout/cache/go-build:/root/.cache/go-build" \
-v "$sytestout/cache/go-pkg:/gopath/pkg" \
"${runargs[@]}" \
matrixdotorg/sytest-dendrite:latest "$@"
}
main "$@"

View file

@ -30,6 +30,7 @@ import (
sentryhttp "github.com/getsentry/sentry-go/http" sentryhttp "github.com/getsentry/sentry-go/http"
"github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/internal/pushgateway"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/prometheus/client_golang/prometheus/promhttp" "github.com/prometheus/client_golang/prometheus/promhttp"
"go.uber.org/atomic" "go.uber.org/atomic"
@ -271,6 +272,11 @@ func (b *BaseDendrite) KeyServerHTTPClient() keyserverAPI.KeyInternalAPI {
return f return f
} }
// PushGatewayHTTPClient returns a new client for interacting with (external) Push Gateways.
func (b *BaseDendrite) PushGatewayHTTPClient() pushgateway.Client {
return pushgateway.NewHTTPClient(b.Cfg.UserAPI.PushGatewayDisableTLSValidation)
}
// CreateAccountsDB creates a new instance of the accounts database. Should only // CreateAccountsDB creates a new instance of the accounts database. Should only
// be called once per component. // be called once per component.
func (b *BaseDendrite) CreateAccountsDB() userdb.Database { func (b *BaseDendrite) CreateAccountsDB() userdb.Database {

View file

@ -205,6 +205,11 @@ user_api:
max_open_conns: 100 max_open_conns: 100
max_idle_conns: 2 max_idle_conns: 2
conn_max_lifetime: -1 conn_max_lifetime: -1
pusher_database:
connection_string: file:pushserver.db
max_open_conns: 100
max_idle_conns: 2
conn_max_lifetime: -1
tracing: tracing:
enabled: false enabled: false
jaeger: jaeger:

View file

@ -13,6 +13,9 @@ type UserAPI struct {
// The length of time an OpenID token is condidered valid in milliseconds // The length of time an OpenID token is condidered valid in milliseconds
OpenIDTokenLifetimeMS int64 `yaml:"openid_token_lifetime_ms"` OpenIDTokenLifetimeMS int64 `yaml:"openid_token_lifetime_ms"`
// Disable TLS validation on HTTPS calls to push gatways. NOT RECOMMENDED!
PushGatewayDisableTLSValidation bool `yaml:"push_gateway_disable_tls_validation"`
// The Account database stores the login details and account information // The Account database stores the login details and account information
// for local users. It is accessed by the UserAPI. // for local users. It is accessed by the UserAPI.
AccountDatabase DatabaseOptions `yaml:"account_database"` AccountDatabase DatabaseOptions `yaml:"account_database"`

View file

@ -18,7 +18,10 @@ var (
OutputKeyChangeEvent = "OutputKeyChangeEvent" OutputKeyChangeEvent = "OutputKeyChangeEvent"
OutputTypingEvent = "OutputTypingEvent" OutputTypingEvent = "OutputTypingEvent"
OutputClientData = "OutputClientData" OutputClientData = "OutputClientData"
OutputNotificationData = "OutputNotificationData"
OutputReceiptEvent = "OutputReceiptEvent" OutputReceiptEvent = "OutputReceiptEvent"
OutputStreamEvent = "OutputStreamEvent"
OutputReadUpdate = "OutputReadUpdate"
) )
var streams = []*nats.StreamConfig{ var streams = []*nats.StreamConfig{
@ -58,4 +61,19 @@ var streams = []*nats.StreamConfig{
Retention: nats.InterestPolicy, Retention: nats.InterestPolicy,
Storage: nats.FileStorage, Storage: nats.FileStorage,
}, },
{
Name: OutputNotificationData,
Retention: nats.InterestPolicy,
Storage: nats.FileStorage,
},
{
Name: OutputStreamEvent,
Retention: nats.InterestPolicy,
Storage: nats.FileStorage,
},
{
Name: OutputReadUpdate,
Retention: nats.InterestPolicy,
Storage: nats.FileStorage,
},
} }

View file

@ -60,8 +60,8 @@ func (m *Monolith) AddAllPublicRoutes(process *process.ProcessContext, csMux, ss
csMux, synapseMux, &m.Config.ClientAPI, m.AccountDB, csMux, synapseMux, &m.Config.ClientAPI, m.AccountDB,
m.FedClient, m.RoomserverAPI, m.FedClient, m.RoomserverAPI,
m.EDUInternalAPI, m.AppserviceAPI, transactions.New(), m.EDUInternalAPI, m.AppserviceAPI, transactions.New(),
m.FederationAPI, m.UserAPI, m.KeyAPI, m.ExtPublicRoomsProvider, m.FederationAPI, m.UserAPI, m.KeyAPI,
&m.Config.MSCs, m.ExtPublicRoomsProvider, &m.Config.MSCs,
) )
federationapi.AddPublicRoutes( federationapi.AddPublicRoutes(
ssMux, keyMux, wkMux, &m.Config.FederationAPI, m.UserAPI, m.FedClient, ssMux, keyMux, wkMux, &m.Config.FederationAPI, m.UserAPI, m.FedClient,

View file

@ -17,6 +17,7 @@ package consumers
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"github.com/getsentry/sentry-go" "github.com/getsentry/sentry-go"
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
@ -24,9 +25,12 @@ import (
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/producers"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go" "github.com/nats-io/nats.go"
"github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@ -39,6 +43,8 @@ type OutputClientDataConsumer struct {
db storage.Database db storage.Database
stream types.StreamProvider stream types.StreamProvider
notifier *notifier.Notifier notifier *notifier.Notifier
serverName gomatrixserverlib.ServerName
producer *producers.UserAPIReadProducer
} }
// NewOutputClientDataConsumer creates a new OutputClientData consumer. Call Start() to begin consuming from room servers. // NewOutputClientDataConsumer creates a new OutputClientData consumer. Call Start() to begin consuming from room servers.
@ -49,6 +55,7 @@ func NewOutputClientDataConsumer(
store storage.Database, store storage.Database,
notifier *notifier.Notifier, notifier *notifier.Notifier,
stream types.StreamProvider, stream types.StreamProvider,
producer *producers.UserAPIReadProducer,
) *OutputClientDataConsumer { ) *OutputClientDataConsumer {
return &OutputClientDataConsumer{ return &OutputClientDataConsumer{
ctx: process.Context(), ctx: process.Context(),
@ -58,6 +65,8 @@ func NewOutputClientDataConsumer(
db: store, db: store,
notifier: notifier, notifier: notifier,
stream: stream, stream: stream,
serverName: cfg.Matrix.ServerName,
producer: producer,
} }
} }
@ -100,8 +109,48 @@ func (s *OutputClientDataConsumer) onMessage(ctx context.Context, msg *nats.Msg)
}).Panicf("could not save account data") }).Panicf("could not save account data")
} }
if err = s.sendReadUpdate(ctx, userID, output); err != nil {
log.WithError(err).WithFields(logrus.Fields{
"user_id": userID,
"room_id": output.RoomID,
}).Errorf("Failed to generate read update")
sentry.CaptureException(err)
return false
}
s.stream.Advance(streamPos) s.stream.Advance(streamPos)
s.notifier.OnNewAccountData(userID, types.StreamingToken{AccountDataPosition: streamPos}) s.notifier.OnNewAccountData(userID, types.StreamingToken{AccountDataPosition: streamPos})
return true return true
} }
func (s *OutputClientDataConsumer) sendReadUpdate(ctx context.Context, userID string, output eventutil.AccountData) error {
if output.Type != "m.fully_read" || output.ReadMarker == nil {
return nil
}
_, serverName, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
return fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
}
if serverName != s.serverName {
return nil
}
var readPos types.StreamPosition
var fullyReadPos types.StreamPosition
if output.ReadMarker.Read != "" {
if _, readPos, err = s.db.PositionInTopology(ctx, output.ReadMarker.Read); err != nil {
return fmt.Errorf("s.db.PositionInTopology (Read): %w", err)
}
}
if output.ReadMarker.FullyRead != "" {
if _, fullyReadPos, err = s.db.PositionInTopology(ctx, output.ReadMarker.FullyRead); err != nil {
return fmt.Errorf("s.db.PositionInTopology (FullyRead): %w", err)
}
}
if readPos > 0 || fullyReadPos > 0 {
if err := s.producer.SendReadUpdate(userID, output.RoomID, readPos, fullyReadPos); err != nil {
return fmt.Errorf("s.producer.SendReadUpdate: %w", err)
}
}
return nil
}

View file

@ -17,6 +17,7 @@ package consumers
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"github.com/getsentry/sentry-go" "github.com/getsentry/sentry-go"
"github.com/matrix-org/dendrite/eduserver/api" "github.com/matrix-org/dendrite/eduserver/api"
@ -24,9 +25,12 @@ import (
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/producers"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go" "github.com/nats-io/nats.go"
"github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@ -39,6 +43,8 @@ type OutputReceiptEventConsumer struct {
db storage.Database db storage.Database
stream types.StreamProvider stream types.StreamProvider
notifier *notifier.Notifier notifier *notifier.Notifier
serverName gomatrixserverlib.ServerName
producer *producers.UserAPIReadProducer
} }
// NewOutputReceiptEventConsumer creates a new OutputReceiptEventConsumer. // NewOutputReceiptEventConsumer creates a new OutputReceiptEventConsumer.
@ -50,6 +56,7 @@ func NewOutputReceiptEventConsumer(
store storage.Database, store storage.Database,
notifier *notifier.Notifier, notifier *notifier.Notifier,
stream types.StreamProvider, stream types.StreamProvider,
producer *producers.UserAPIReadProducer,
) *OutputReceiptEventConsumer { ) *OutputReceiptEventConsumer {
return &OutputReceiptEventConsumer{ return &OutputReceiptEventConsumer{
ctx: process.Context(), ctx: process.Context(),
@ -59,6 +66,8 @@ func NewOutputReceiptEventConsumer(
db: store, db: store,
notifier: notifier, notifier: notifier,
stream: stream, stream: stream,
serverName: cfg.Matrix.ServerName,
producer: producer,
} }
} }
@ -92,8 +101,42 @@ func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msg *nats.Ms
return true return true
} }
if err = s.sendReadUpdate(ctx, output); err != nil {
log.WithError(err).WithFields(logrus.Fields{
"user_id": output.UserID,
"room_id": output.RoomID,
}).Errorf("Failed to generate read update")
sentry.CaptureException(err)
return false
}
s.stream.Advance(streamPos) s.stream.Advance(streamPos)
s.notifier.OnNewReceipt(output.RoomID, types.StreamingToken{ReceiptPosition: streamPos}) s.notifier.OnNewReceipt(output.RoomID, types.StreamingToken{ReceiptPosition: streamPos})
return true return true
} }
func (s *OutputReceiptEventConsumer) sendReadUpdate(ctx context.Context, output api.OutputReceiptEvent) error {
if output.Type != "m.read" {
return nil
}
_, serverName, err := gomatrixserverlib.SplitID('@', output.UserID)
if err != nil {
return fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
}
if serverName != s.serverName {
return nil
}
var readPos types.StreamPosition
if output.EventID != "" {
if _, readPos, err = s.db.PositionInTopology(ctx, output.EventID); err != nil {
return fmt.Errorf("s.db.PositionInTopology (Read): %w", err)
}
}
if readPos > 0 {
if err := s.producer.SendReadUpdate(output.UserID, output.RoomID, readPos, 0); err != nil {
return fmt.Errorf("s.producer.SendReadUpdate: %w", err)
}
}
return nil
}

View file

@ -26,6 +26,7 @@ import (
"github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/producers"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -45,6 +46,7 @@ type OutputRoomEventConsumer struct {
pduStream types.StreamProvider pduStream types.StreamProvider
inviteStream types.StreamProvider inviteStream types.StreamProvider
notifier *notifier.Notifier notifier *notifier.Notifier
producer *producers.UserAPIStreamEventProducer
} }
// NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers. // NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers.
@ -57,6 +59,7 @@ func NewOutputRoomEventConsumer(
pduStream types.StreamProvider, pduStream types.StreamProvider,
inviteStream types.StreamProvider, inviteStream types.StreamProvider,
rsAPI api.RoomserverInternalAPI, rsAPI api.RoomserverInternalAPI,
producer *producers.UserAPIStreamEventProducer,
) *OutputRoomEventConsumer { ) *OutputRoomEventConsumer {
return &OutputRoomEventConsumer{ return &OutputRoomEventConsumer{
ctx: process.Context(), ctx: process.Context(),
@ -69,6 +72,7 @@ func NewOutputRoomEventConsumer(
pduStream: pduStream, pduStream: pduStream,
inviteStream: inviteStream, inviteStream: inviteStream,
rsAPI: rsAPI, rsAPI: rsAPI,
producer: producer,
} }
} }
@ -194,6 +198,12 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent(
return nil return nil
} }
if err = s.producer.SendStreamEvent(ev.RoomID(), ev, pduPos); err != nil {
log.WithError(err).Errorf("Failed to send stream output event for event %s", ev.EventID())
sentry.CaptureException(err)
return err
}
if pduPos, err = s.notifyJoinedPeeks(ctx, ev, pduPos); err != nil { if pduPos, err = s.notifyJoinedPeeks(ctx, ev, pduPos); err != nil {
log.WithError(err).Errorf("Failed to notifyJoinedPeeks for PDU pos %d", pduPos) log.WithError(err).Errorf("Failed to notifyJoinedPeeks for PDU pos %d", pduPos)
sentry.CaptureException(err) sentry.CaptureException(err)

View file

@ -0,0 +1,110 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package consumers
import (
"context"
"encoding/json"
"github.com/getsentry/sentry-go"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
)
// OutputNotificationDataConsumer consumes events that originated in
// the Push server.
type OutputNotificationDataConsumer struct {
ctx context.Context
jetstream nats.JetStreamContext
durable string
topic string
db storage.Database
notifier *notifier.Notifier
stream types.StreamProvider
}
// NewOutputNotificationDataConsumer creates a new consumer. Call
// Start() to begin consuming.
func NewOutputNotificationDataConsumer(
process *process.ProcessContext,
cfg *config.SyncAPI,
js nats.JetStreamContext,
store storage.Database,
notifier *notifier.Notifier,
stream types.StreamProvider,
) *OutputNotificationDataConsumer {
s := &OutputNotificationDataConsumer{
ctx: process.Context(),
jetstream: js,
durable: cfg.Matrix.JetStream.Durable("SyncAPINotificationDataConsumer"),
topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputNotificationData),
db: store,
notifier: notifier,
stream: stream,
}
return s
}
// Start starts consumption.
func (s *OutputNotificationDataConsumer) Start() error {
return jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
nats.DeliverAll(), nats.ManualAck(),
)
}
// onMessage is called when the Sync server receives a new event from
// the push server. It is not safe for this function to be called from
// multiple goroutines, or else the sync stream position may race and
// be incorrectly calculated.
func (s *OutputNotificationDataConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
userID := string(msg.Header.Get(jetstream.UserID))
// Parse out the event JSON
var data eventutil.NotificationData
if err := json.Unmarshal(msg.Data, &data); err != nil {
sentry.CaptureException(err)
log.WithField("user_id", userID).WithError(err).Error("user API consumer: message parse failure")
return true
}
streamPos, err := s.db.UpsertRoomUnreadNotificationCounts(ctx, userID, data.RoomID, data.UnreadNotificationCount, data.UnreadHighlightCount)
if err != nil {
sentry.CaptureException(err)
log.WithFields(log.Fields{
"user_id": userID,
"room_id": data.RoomID,
}).WithError(err).Error("Could not save notification counts")
return false
}
s.stream.Advance(streamPos)
s.notifier.OnNewNotificationData(userID, types.StreamingToken{NotificationDataPosition: streamPos})
log.WithFields(log.Fields{
"user_id": userID,
"room_id": data.RoomID,
"streamPos": streamPos,
}).Trace("Received notification data from user API")
return true
}

View file

@ -217,6 +217,17 @@ func (n *Notifier) OnNewInvite(
n.wakeupUsers([]string{wakeUserID}, nil, n.currPos) n.wakeupUsers([]string{wakeUserID}, nil, n.currPos)
} }
func (n *Notifier) OnNewNotificationData(
userID string,
posUpdate types.StreamingToken,
) {
n.streamLock.Lock()
defer n.streamLock.Unlock()
n.currPos.ApplyUpdates(posUpdate)
n.wakeupUsers([]string{userID}, nil, n.currPos)
}
// GetListener returns a UserStreamListener that can be used to wait for // GetListener returns a UserStreamListener that can be used to wait for
// updates for a user. Must be closed. // updates for a user. Must be closed.
// notify for anything before sincePos // notify for anything before sincePos

View file

@ -219,7 +219,7 @@ func TestEDUWakeup(t *testing.T) {
go func() { go func() {
pos, err := waitForEvents(n, newTestSyncRequest(bob, bobDev, syncPositionAfter)) pos, err := waitForEvents(n, newTestSyncRequest(bob, bobDev, syncPositionAfter))
if err != nil { if err != nil {
t.Errorf("TestNewInviteEventForUser error: %w", err) t.Errorf("TestNewInviteEventForUser error: %v", err)
} }
mustEqualPositions(t, pos, syncPositionNewEDU) mustEqualPositions(t, pos, syncPositionNewEDU)
wg.Done() wg.Done()

View file

@ -0,0 +1,62 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package producers
import (
"encoding/json"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
)
// UserAPIProducer produces events for the user API server to consume
type UserAPIReadProducer struct {
Topic string
JetStream nats.JetStreamContext
}
// SendData sends account data to the user API server
func (p *UserAPIReadProducer) SendReadUpdate(userID, roomID string, readPos, fullyReadPos types.StreamPosition) error {
m := &nats.Msg{
Subject: p.Topic,
Header: nats.Header{},
}
m.Header.Set(jetstream.UserID, userID)
m.Header.Set(jetstream.RoomID, roomID)
data := types.ReadUpdate{
UserID: userID,
RoomID: roomID,
Read: readPos,
FullyRead: fullyReadPos,
}
var err error
m.Data, err = json.Marshal(data)
if err != nil {
return err
}
log.WithFields(log.Fields{
"user_id": userID,
"room_id": roomID,
"read_pos": readPos,
"fully_read_pos": fullyReadPos,
}).Tracef("Producing to topic '%s'", p.Topic)
_, err = p.JetStream.PublishMsg(m)
return err
}

View file

@ -0,0 +1,60 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package producers
import (
"encoding/json"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
)
// UserAPIProducer produces events for the user API server to consume
type UserAPIStreamEventProducer struct {
Topic string
JetStream nats.JetStreamContext
}
// SendData sends account data to the user API server
func (p *UserAPIStreamEventProducer) SendStreamEvent(roomID string, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition) error {
m := &nats.Msg{
Subject: p.Topic,
Header: nats.Header{},
}
m.Header.Set(jetstream.RoomID, roomID)
data := types.StreamedEvent{
Event: event,
StreamPosition: pos,
}
var err error
m.Data, err = json.Marshal(data)
if err != nil {
return err
}
log.WithFields(log.Fields{
"room_id": roomID,
"event_id": event.EventID(),
"event_type": event.Type(),
"stream_pos": pos,
}).Tracef("Producing to topic '%s'", p.Topic)
_, err = p.JetStream.PublishMsg(m)
return err
}

View file

@ -18,6 +18,7 @@ import (
"context" "context"
eduAPI "github.com/matrix-org/dendrite/eduserver/api" eduAPI "github.com/matrix-org/dendrite/eduserver/api"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
@ -31,6 +32,7 @@ type Database interface {
MaxStreamPositionForInvites(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForInvites(ctx context.Context) (types.StreamPosition, error)
MaxStreamPositionForAccountData(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForAccountData(ctx context.Context) (types.StreamPosition, error)
MaxStreamPositionForSendToDeviceMessages(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForSendToDeviceMessages(ctx context.Context) (types.StreamPosition, error)
MaxStreamPositionForNotificationData(ctx context.Context) (types.StreamPosition, error)
CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error)
GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error) GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error)
@ -138,6 +140,12 @@ type Database interface {
// GetRoomReceipts gets all receipts for a given roomID // GetRoomReceipts gets all receipts for a given roomID
GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]eduAPI.OutputReceiptEvent, error) GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]eduAPI.OutputReceiptEvent, error)
// UpsertRoomUnreadNotificationCounts updates the notification statistics about a (user, room) key.
UpsertRoomUnreadNotificationCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error)
// GetUserUnreadNotificationCounts returns statistics per room a user is interested in.
GetUserUnreadNotificationCounts(ctx context.Context, userID string, from, to types.StreamPosition) (map[string]*eventutil.NotificationData, error)
SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error)
SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) SelectContextBeforeEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error)
SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) SelectContextAfterEvent(ctx context.Context, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error)

View file

@ -0,0 +1,108 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
)
func NewPostgresNotificationDataTable(db *sql.DB) (tables.NotificationData, error) {
_, err := db.Exec(notificationDataSchema)
if err != nil {
return nil, err
}
r := &notificationDataStatements{}
return r, sqlutil.StatementList{
{&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL},
{&r.selectUserUnreadCounts, selectUserUnreadNotificationCountsSQL},
{&r.selectMaxID, selectMaxNotificationIDSQL},
}.Prepare(db)
}
type notificationDataStatements struct {
upsertRoomUnreadCounts *sql.Stmt
selectUserUnreadCounts *sql.Stmt
selectMaxID *sql.Stmt
}
const notificationDataSchema = `
CREATE TABLE IF NOT EXISTS syncapi_notification_data (
id BIGSERIAL PRIMARY KEY,
user_id TEXT NOT NULL,
room_id TEXT NOT NULL,
notification_count BIGINT NOT NULL DEFAULT 0,
highlight_count BIGINT NOT NULL DEFAULT 0,
CONSTRAINT syncapi_notification_data_unique UNIQUE (user_id, room_id)
);`
const upsertRoomUnreadNotificationCountsSQL = `INSERT INTO syncapi_notification_data
(user_id, room_id, notification_count, highlight_count)
VALUES ($1, $2, $3, $4)
ON CONFLICT (user_id, room_id)
DO UPDATE SET notification_count = $3, highlight_count = $4
RETURNING id`
const selectUserUnreadNotificationCountsSQL = `SELECT
id, room_id, notification_count, highlight_count
FROM syncapi_notification_data
WHERE
user_id = $1 AND
id BETWEEN $2 + 1 AND $3`
const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data`
func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (pos types.StreamPosition, err error) {
err = r.upsertRoomUnreadCounts.QueryRowContext(ctx, userID, roomID, notificationCount, highlightCount).Scan(&pos)
return
}
func (r *notificationDataStatements) SelectUserUnreadCounts(ctx context.Context, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error) {
rows, err := r.selectUserUnreadCounts.QueryContext(ctx, userID, fromExcl, toIncl)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "SelectUserUnreadCounts: rows.close() failed")
roomCounts := map[string]*eventutil.NotificationData{}
for rows.Next() {
var id types.StreamPosition
var roomID string
var notificationCount, highlightCount int
if err = rows.Scan(&id, &roomID, &notificationCount, &highlightCount); err != nil {
return nil, err
}
roomCounts[roomID] = &eventutil.NotificationData{
RoomID: roomID,
UnreadNotificationCount: notificationCount,
UnreadHighlightCount: highlightCount,
}
}
return roomCounts, rows.Err()
}
func (r *notificationDataStatements) SelectMaxID(ctx context.Context) (int64, error) {
var id int64
err := r.selectMaxID.QueryRowContext(ctx).Scan(&id)
return id, err
}

View file

@ -90,6 +90,10 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
if err != nil { if err != nil {
return nil, err return nil, err
} }
notificationData, err := NewPostgresNotificationDataTable(d.db)
if err != nil {
return nil, err
}
m := sqlutil.NewMigrations() m := sqlutil.NewMigrations()
deltas.LoadFixSequences(m) deltas.LoadFixSequences(m)
deltas.LoadRemoveSendToDeviceSentColumn(m) deltas.LoadRemoveSendToDeviceSentColumn(m)
@ -110,6 +114,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
SendToDevice: sendToDevice, SendToDevice: sendToDevice,
Receipts: receipts, Receipts: receipts,
Memberships: memberships, Memberships: memberships,
NotificationData: notificationData,
} }
return &d, nil return &d, nil
} }

View file

@ -48,6 +48,7 @@ type Database struct {
Filter tables.Filter Filter tables.Filter
Receipts tables.Receipts Receipts tables.Receipts
Memberships tables.Memberships Memberships tables.Memberships
NotificationData tables.NotificationData
} }
func (d *Database) readOnlySnapshot(ctx context.Context) (*sql.Tx, error) { func (d *Database) readOnlySnapshot(ctx context.Context) (*sql.Tx, error) {
@ -102,6 +103,14 @@ func (d *Database) MaxStreamPositionForAccountData(ctx context.Context) (types.S
return types.StreamPosition(id), nil return types.StreamPosition(id), nil
} }
func (d *Database) MaxStreamPositionForNotificationData(ctx context.Context) (types.StreamPosition, error) {
id, err := d.NotificationData.SelectMaxID(ctx)
if err != nil {
return 0, fmt.Errorf("d.NotificationData.SelectMaxID: %w", err)
}
return types.StreamPosition(id), nil
}
func (d *Database) CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) { func (d *Database) CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) {
return d.CurrentRoomState.SelectCurrentState(ctx, nil, roomID, stateFilterPart, excludeEventIDs) return d.CurrentRoomState.SelectCurrentState(ctx, nil, roomID, stateFilterPart, excludeEventIDs)
} }
@ -956,6 +965,18 @@ func (d *Database) GetRoomReceipts(ctx context.Context, roomIDs []string, stream
return receipts, err return receipts, err
} }
func (d *Database) UpsertRoomUnreadNotificationCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (pos types.StreamPosition, err error) {
err = d.Writer.Do(nil, nil, func(_ *sql.Tx) error {
pos, err = d.NotificationData.UpsertRoomUnreadCounts(ctx, userID, roomID, notificationCount, highlightCount)
return err
})
return
}
func (d *Database) GetUserUnreadNotificationCounts(ctx context.Context, userID string, from, to types.StreamPosition) (map[string]*eventutil.NotificationData, error) {
return d.NotificationData.SelectUserUnreadCounts(ctx, userID, from, to)
}
func (s *Database) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) { func (s *Database) SelectContextEvent(ctx context.Context, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) {
return s.OutputEvents.SelectContextEvent(ctx, nil, roomID, eventID) return s.OutputEvents.SelectContextEvent(ctx, nil, roomID, eventID)
} }

View file

@ -0,0 +1,108 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
)
func NewSqliteNotificationDataTable(db *sql.DB) (tables.NotificationData, error) {
_, err := db.Exec(notificationDataSchema)
if err != nil {
return nil, err
}
r := &notificationDataStatements{}
return r, sqlutil.StatementList{
{&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL},
{&r.selectUserUnreadCounts, selectUserUnreadNotificationCountsSQL},
{&r.selectMaxID, selectMaxNotificationIDSQL},
}.Prepare(db)
}
type notificationDataStatements struct {
upsertRoomUnreadCounts *sql.Stmt
selectUserUnreadCounts *sql.Stmt
selectMaxID *sql.Stmt
}
const notificationDataSchema = `
CREATE TABLE IF NOT EXISTS syncapi_notification_data (
id INTEGER PRIMARY KEY,
user_id TEXT NOT NULL,
room_id TEXT NOT NULL,
notification_count BIGINT NOT NULL DEFAULT 0,
highlight_count BIGINT NOT NULL DEFAULT 0,
CONSTRAINT syncapi_notifications_unique UNIQUE (user_id, room_id)
);`
const upsertRoomUnreadNotificationCountsSQL = `INSERT INTO syncapi_notification_data
(user_id, room_id, notification_count, highlight_count)
VALUES ($1, $2, $3, $4)
ON CONFLICT (user_id, room_id)
DO UPDATE SET notification_count = $3, highlight_count = $4
RETURNING id`
const selectUserUnreadNotificationCountsSQL = `SELECT
id, room_id, notification_count, highlight_count
FROM syncapi_notification_data
WHERE
user_id = $1 AND
id BETWEEN $2 + 1 AND $3`
const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data`
func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (pos types.StreamPosition, err error) {
err = r.upsertRoomUnreadCounts.QueryRowContext(ctx, userID, roomID, notificationCount, highlightCount).Scan(&pos)
return
}
func (r *notificationDataStatements) SelectUserUnreadCounts(ctx context.Context, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error) {
rows, err := r.selectUserUnreadCounts.QueryContext(ctx, userID, fromExcl, toIncl)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "SelectUserUnreadCounts: rows.close() failed")
roomCounts := map[string]*eventutil.NotificationData{}
for rows.Next() {
var id types.StreamPosition
var roomID string
var notificationCount, highlightCount int
if err = rows.Scan(&id, &roomID, &notificationCount, &highlightCount); err != nil {
return nil, err
}
roomCounts[roomID] = &eventutil.NotificationData{
RoomID: roomID,
UnreadNotificationCount: notificationCount,
UnreadHighlightCount: highlightCount,
}
}
return roomCounts, rows.Err()
}
func (r *notificationDataStatements) SelectMaxID(ctx context.Context) (int64, error) {
var id int64
err := r.selectMaxID.QueryRowContext(ctx).Scan(&id)
return id, err
}

View file

@ -62,16 +62,19 @@ const selectEventsSQL = "" +
const selectRecentEventsSQL = "" + const selectRecentEventsSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3" " WHERE room_id = $1 AND id > $2 AND id <= $3"
// WHEN, ORDER BY and LIMIT are appended by prepareWithFilters // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
const selectRecentEventsForSyncSQL = "" + const selectRecentEventsForSyncSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3 AND exclude_from_sync = FALSE" " WHERE room_id = $1 AND id > $2 AND id <= $3 AND exclude_from_sync = FALSE"
// WHEN, ORDER BY and LIMIT are appended by prepareWithFilters // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
const selectEarlyEventsSQL = "" + const selectEarlyEventsSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3" " WHERE room_id = $1 AND id > $2 AND id <= $3"
// WHEN, ORDER BY and LIMIT are appended by prepareWithFilters // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
const selectMaxEventIDSQL = "" + const selectMaxEventIDSQL = "" +
@ -85,6 +88,7 @@ const selectStateInRangeSQL = "" +
" FROM syncapi_output_room_events" + " FROM syncapi_output_room_events" +
" WHERE (id > $1 AND id <= $2)" + " WHERE (id > $1 AND id <= $2)" +
" AND ((add_state_ids IS NOT NULL AND add_state_ids != '') OR (remove_state_ids IS NOT NULL AND remove_state_ids != ''))" " AND ((add_state_ids IS NOT NULL AND add_state_ids != '') OR (remove_state_ids IS NOT NULL AND remove_state_ids != ''))"
// WHEN, ORDER BY and LIMIT are appended by prepareWithFilters // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
const deleteEventsForRoomSQL = "" + const deleteEventsForRoomSQL = "" +
@ -95,10 +99,12 @@ const selectContextEventSQL = "" +
const selectContextBeforeEventSQL = "" + const selectContextBeforeEventSQL = "" +
"SELECT headered_event_json FROM syncapi_output_room_events WHERE room_id = $1 AND id < $2" "SELECT headered_event_json FROM syncapi_output_room_events WHERE room_id = $1 AND id < $2"
// WHEN, ORDER BY and LIMIT are appended by prepareWithFilters // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
const selectContextAfterEventSQL = "" + const selectContextAfterEventSQL = "" +
"SELECT id, headered_event_json FROM syncapi_output_room_events WHERE room_id = $1 AND id > $2" "SELECT id, headered_event_json FROM syncapi_output_room_events WHERE room_id = $1 AND id > $2"
// WHEN, ORDER BY and LIMIT are appended by prepareWithFilters // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
type outputRoomEventsStatements struct { type outputRoomEventsStatements struct {

View file

@ -100,6 +100,10 @@ func (d *SyncServerDatasource) prepare(dbProperties *config.DatabaseOptions) (er
if err != nil { if err != nil {
return err return err
} }
notificationData, err := NewSqliteNotificationDataTable(d.db)
if err != nil {
return err
}
m := sqlutil.NewMigrations() m := sqlutil.NewMigrations()
deltas.LoadFixSequences(m) deltas.LoadFixSequences(m)
deltas.LoadRemoveSendToDeviceSentColumn(m) deltas.LoadRemoveSendToDeviceSentColumn(m)
@ -120,6 +124,7 @@ func (d *SyncServerDatasource) prepare(dbProperties *config.DatabaseOptions) (er
SendToDevice: sendToDevice, SendToDevice: sendToDevice,
Receipts: receipts, Receipts: receipts,
Memberships: memberships, Memberships: memberships,
NotificationData: notificationData,
} }
return nil return nil
} }

View file

@ -19,6 +19,7 @@ import (
"database/sql" "database/sql"
eduAPI "github.com/matrix-org/dendrite/eduserver/api" eduAPI "github.com/matrix-org/dendrite/eduserver/api"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -171,3 +172,9 @@ type Memberships interface {
UpsertMembership(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error UpsertMembership(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error
SelectMembership(ctx context.Context, txn *sql.Tx, roomID, userID, memberships []string) (eventID string, streamPos, topologyPos types.StreamPosition, err error) SelectMembership(ctx context.Context, txn *sql.Tx, roomID, userID, memberships []string) (eventID string, streamPos, topologyPos types.StreamPosition, err error)
} }
type NotificationData interface {
UpsertRoomUnreadCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error)
SelectUserUnreadCounts(ctx context.Context, userID string, fromExcl, toIncl types.StreamPosition) (map[string]*eventutil.NotificationData, error)
SelectMaxID(ctx context.Context) (int64, error)
}

View file

@ -0,0 +1,55 @@
package streams
import (
"context"
"github.com/matrix-org/dendrite/syncapi/types"
)
type NotificationDataStreamProvider struct {
StreamProvider
}
func (p *NotificationDataStreamProvider) Setup() {
p.StreamProvider.Setup()
id, err := p.DB.MaxStreamPositionForNotificationData(context.Background())
if err != nil {
panic(err)
}
p.latest = id
}
func (p *NotificationDataStreamProvider) CompleteSync(
ctx context.Context,
req *types.SyncRequest,
) types.StreamPosition {
return p.IncrementalSync(ctx, req, 0, p.LatestPosition(ctx))
}
func (p *NotificationDataStreamProvider) IncrementalSync(
ctx context.Context,
req *types.SyncRequest,
from, to types.StreamPosition,
) types.StreamPosition {
// We want counts for all possible rooms, so always start from zero.
countsByRoom, err := p.DB.GetUserUnreadNotificationCounts(ctx, req.Device.UserID, from, to)
if err != nil {
req.Log.WithError(err).Error("GetUserUnreadNotificationCounts failed")
return from
}
// We're merely decorating existing rooms. Note that the Join map
// values are not pointers.
for roomID, jr := range req.Response.Rooms.Join {
counts := countsByRoom[roomID]
if counts == nil {
continue
}
jr.UnreadNotifications.HighlightCount = counts.UnreadHighlightCount
jr.UnreadNotifications.NotificationCount = counts.UnreadNotificationCount
req.Response.Rooms.Join[roomID] = jr
}
return to
}

View file

@ -19,6 +19,7 @@ type Streams struct {
SendToDeviceStreamProvider types.StreamProvider SendToDeviceStreamProvider types.StreamProvider
AccountDataStreamProvider types.StreamProvider AccountDataStreamProvider types.StreamProvider
DeviceListStreamProvider types.StreamProvider DeviceListStreamProvider types.StreamProvider
NotificationDataStreamProvider types.StreamProvider
} }
func NewSyncStreamProviders( func NewSyncStreamProviders(
@ -47,6 +48,9 @@ func NewSyncStreamProviders(
StreamProvider: StreamProvider{DB: d}, StreamProvider: StreamProvider{DB: d},
userAPI: userAPI, userAPI: userAPI,
}, },
NotificationDataStreamProvider: &NotificationDataStreamProvider{
StreamProvider: StreamProvider{DB: d},
},
DeviceListStreamProvider: &DeviceListStreamProvider{ DeviceListStreamProvider: &DeviceListStreamProvider{
StreamProvider: StreamProvider{DB: d}, StreamProvider: StreamProvider{DB: d},
rsAPI: rsAPI, rsAPI: rsAPI,
@ -60,6 +64,7 @@ func NewSyncStreamProviders(
streams.InviteStreamProvider.Setup() streams.InviteStreamProvider.Setup()
streams.SendToDeviceStreamProvider.Setup() streams.SendToDeviceStreamProvider.Setup()
streams.AccountDataStreamProvider.Setup() streams.AccountDataStreamProvider.Setup()
streams.NotificationDataStreamProvider.Setup()
streams.DeviceListStreamProvider.Setup() streams.DeviceListStreamProvider.Setup()
return streams return streams
@ -73,6 +78,7 @@ func (s *Streams) Latest(ctx context.Context) types.StreamingToken {
InvitePosition: s.InviteStreamProvider.LatestPosition(ctx), InvitePosition: s.InviteStreamProvider.LatestPosition(ctx),
SendToDevicePosition: s.SendToDeviceStreamProvider.LatestPosition(ctx), SendToDevicePosition: s.SendToDeviceStreamProvider.LatestPosition(ctx),
AccountDataPosition: s.AccountDataStreamProvider.LatestPosition(ctx), AccountDataPosition: s.AccountDataStreamProvider.LatestPosition(ctx),
NotificationDataPosition: s.NotificationDataStreamProvider.LatestPosition(ctx),
DeviceListPosition: s.DeviceListStreamProvider.LatestPosition(ctx), DeviceListPosition: s.DeviceListStreamProvider.LatestPosition(ctx),
} }
} }

View file

@ -189,7 +189,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
currentPos.ApplyUpdates(userStreamListener.GetSyncPosition()) currentPos.ApplyUpdates(userStreamListener.GetSyncPosition())
} }
} else { } else {
syncReq.Log.Debugln("Responding to sync immediately") syncReq.Log.WithField("currentPos", currentPos).Debugln("Responding to sync immediately")
} }
if syncReq.Since.IsEmpty() { if syncReq.Since.IsEmpty() {
@ -213,6 +213,9 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
AccountDataPosition: rp.streams.AccountDataStreamProvider.CompleteSync( AccountDataPosition: rp.streams.AccountDataStreamProvider.CompleteSync(
syncReq.Context, syncReq, syncReq.Context, syncReq,
), ),
NotificationDataPosition: rp.streams.NotificationDataStreamProvider.CompleteSync(
syncReq.Context, syncReq,
),
DeviceListPosition: rp.streams.DeviceListStreamProvider.CompleteSync( DeviceListPosition: rp.streams.DeviceListStreamProvider.CompleteSync(
syncReq.Context, syncReq, syncReq.Context, syncReq,
), ),
@ -244,6 +247,10 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
syncReq.Context, syncReq, syncReq.Context, syncReq,
syncReq.Since.AccountDataPosition, currentPos.AccountDataPosition, syncReq.Since.AccountDataPosition, currentPos.AccountDataPosition,
), ),
NotificationDataPosition: rp.streams.NotificationDataStreamProvider.IncrementalSync(
syncReq.Context, syncReq,
syncReq.Since.NotificationDataPosition, currentPos.NotificationDataPosition,
),
DeviceListPosition: rp.streams.DeviceListStreamProvider.IncrementalSync( DeviceListPosition: rp.streams.DeviceListStreamProvider.IncrementalSync(
syncReq.Context, syncReq, syncReq.Context, syncReq,
syncReq.Since.DeviceListPosition, currentPos.DeviceListPosition, syncReq.Since.DeviceListPosition, currentPos.DeviceListPosition,

View file

@ -31,6 +31,7 @@ import (
"github.com/matrix-org/dendrite/syncapi/consumers" "github.com/matrix-org/dendrite/syncapi/consumers"
"github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/dendrite/syncapi/notifier"
"github.com/matrix-org/dendrite/syncapi/producers"
"github.com/matrix-org/dendrite/syncapi/routing" "github.com/matrix-org/dendrite/syncapi/routing"
"github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/streams" "github.com/matrix-org/dendrite/syncapi/streams"
@ -64,6 +65,18 @@ func AddPublicRoutes(
requestPool := sync.NewRequestPool(syncDB, cfg, userAPI, keyAPI, rsAPI, streams, notifier) requestPool := sync.NewRequestPool(syncDB, cfg, userAPI, keyAPI, rsAPI, streams, notifier)
userAPIStreamEventProducer := &producers.UserAPIStreamEventProducer{
JetStream: js,
Topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputStreamEvent),
}
userAPIReadUpdateProducer := &producers.UserAPIReadProducer{
JetStream: js,
Topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputReadUpdate),
}
_ = userAPIReadUpdateProducer
keyChangeConsumer := consumers.NewOutputKeyChangeEventConsumer( keyChangeConsumer := consumers.NewOutputKeyChangeEventConsumer(
process, cfg, cfg.Matrix.JetStream.TopicFor(jetstream.OutputKeyChangeEvent), process, cfg, cfg.Matrix.JetStream.TopicFor(jetstream.OutputKeyChangeEvent),
js, keyAPI, rsAPI, syncDB, notifier, js, keyAPI, rsAPI, syncDB, notifier,
@ -75,7 +88,7 @@ func AddPublicRoutes(
roomConsumer := consumers.NewOutputRoomEventConsumer( roomConsumer := consumers.NewOutputRoomEventConsumer(
process, cfg, js, syncDB, notifier, streams.PDUStreamProvider, process, cfg, js, syncDB, notifier, streams.PDUStreamProvider,
streams.InviteStreamProvider, rsAPI, streams.InviteStreamProvider, rsAPI, userAPIStreamEventProducer,
) )
if err = roomConsumer.Start(); err != nil { if err = roomConsumer.Start(); err != nil {
logrus.WithError(err).Panicf("failed to start room server consumer") logrus.WithError(err).Panicf("failed to start room server consumer")
@ -83,11 +96,19 @@ func AddPublicRoutes(
clientConsumer := consumers.NewOutputClientDataConsumer( clientConsumer := consumers.NewOutputClientDataConsumer(
process, cfg, js, syncDB, notifier, streams.AccountDataStreamProvider, process, cfg, js, syncDB, notifier, streams.AccountDataStreamProvider,
userAPIReadUpdateProducer,
) )
if err = clientConsumer.Start(); err != nil { if err = clientConsumer.Start(); err != nil {
logrus.WithError(err).Panicf("failed to start client data consumer") logrus.WithError(err).Panicf("failed to start client data consumer")
} }
notificationConsumer := consumers.NewOutputNotificationDataConsumer(
process, cfg, js, syncDB, notifier, streams.NotificationDataStreamProvider,
)
if err = notificationConsumer.Start(); err != nil {
logrus.WithError(err).Panicf("failed to start notification data consumer")
}
typingConsumer := consumers.NewOutputTypingEventConsumer( typingConsumer := consumers.NewOutputTypingEventConsumer(
process, cfg, js, syncDB, eduCache, notifier, streams.TypingStreamProvider, process, cfg, js, syncDB, eduCache, notifier, streams.TypingStreamProvider,
) )
@ -104,6 +125,7 @@ func AddPublicRoutes(
receiptConsumer := consumers.NewOutputReceiptEventConsumer( receiptConsumer := consumers.NewOutputReceiptEventConsumer(
process, cfg, js, syncDB, notifier, streams.ReceiptStreamProvider, process, cfg, js, syncDB, notifier, streams.ReceiptStreamProvider,
userAPIReadUpdateProducer,
) )
if err = receiptConsumer.Start(); err != nil { if err = receiptConsumer.Start(); err != nil {
logrus.WithError(err).Panicf("failed to start receipts consumer") logrus.WithError(err).Panicf("failed to start receipts consumer")

View file

@ -102,6 +102,7 @@ type StreamingToken struct {
InvitePosition StreamPosition InvitePosition StreamPosition
AccountDataPosition StreamPosition AccountDataPosition StreamPosition
DeviceListPosition StreamPosition DeviceListPosition StreamPosition
NotificationDataPosition StreamPosition
} }
// This will be used as a fallback by json.Marshal. // This will be used as a fallback by json.Marshal.
@ -117,10 +118,11 @@ func (s *StreamingToken) UnmarshalText(text []byte) (err error) {
func (t StreamingToken) String() string { func (t StreamingToken) String() string {
posStr := fmt.Sprintf( posStr := fmt.Sprintf(
"s%d_%d_%d_%d_%d_%d_%d", "s%d_%d_%d_%d_%d_%d_%d_%d",
t.PDUPosition, t.TypingPosition, t.PDUPosition, t.TypingPosition,
t.ReceiptPosition, t.SendToDevicePosition, t.ReceiptPosition, t.SendToDevicePosition,
t.InvitePosition, t.AccountDataPosition, t.DeviceListPosition, t.InvitePosition, t.AccountDataPosition,
t.DeviceListPosition, t.NotificationDataPosition,
) )
return posStr return posStr
} }
@ -142,12 +144,14 @@ func (t *StreamingToken) IsAfter(other StreamingToken) bool {
return true return true
case t.DeviceListPosition > other.DeviceListPosition: case t.DeviceListPosition > other.DeviceListPosition:
return true return true
case t.NotificationDataPosition > other.NotificationDataPosition:
return true
} }
return false return false
} }
func (t *StreamingToken) IsEmpty() bool { func (t *StreamingToken) IsEmpty() bool {
return t == nil || t.PDUPosition+t.TypingPosition+t.ReceiptPosition+t.SendToDevicePosition+t.InvitePosition+t.AccountDataPosition+t.DeviceListPosition == 0 return t == nil || t.PDUPosition+t.TypingPosition+t.ReceiptPosition+t.SendToDevicePosition+t.InvitePosition+t.AccountDataPosition+t.DeviceListPosition+t.NotificationDataPosition == 0
} }
// WithUpdates returns a copy of the StreamingToken with updates applied from another StreamingToken. // WithUpdates returns a copy of the StreamingToken with updates applied from another StreamingToken.
@ -185,6 +189,9 @@ func (t *StreamingToken) ApplyUpdates(other StreamingToken) {
if other.DeviceListPosition > t.DeviceListPosition { if other.DeviceListPosition > t.DeviceListPosition {
t.DeviceListPosition = other.DeviceListPosition t.DeviceListPosition = other.DeviceListPosition
} }
if other.NotificationDataPosition > t.NotificationDataPosition {
t.NotificationDataPosition = other.NotificationDataPosition
}
} }
type TopologyToken struct { type TopologyToken struct {
@ -277,7 +284,7 @@ func NewStreamTokenFromString(tok string) (token StreamingToken, err error) {
// s478_0_0_0_0_13.dl-0-2 but we have now removed partitioned stream positions // s478_0_0_0_0_13.dl-0-2 but we have now removed partitioned stream positions
tok = strings.Split(tok, ".")[0] tok = strings.Split(tok, ".")[0]
parts := strings.Split(tok[1:], "_") parts := strings.Split(tok[1:], "_")
var positions [7]StreamPosition var positions [8]StreamPosition
for i, p := range parts { for i, p := range parts {
if i >= len(positions) { if i >= len(positions) {
break break
@ -298,6 +305,7 @@ func NewStreamTokenFromString(tok string) (token StreamingToken, err error) {
InvitePosition: positions[4], InvitePosition: positions[4],
AccountDataPosition: positions[5], AccountDataPosition: positions[5],
DeviceListPosition: positions[6], DeviceListPosition: positions[6],
NotificationDataPosition: positions[7],
} }
return token, nil return token, nil
} }
@ -383,6 +391,10 @@ type JoinResponse struct {
AccountData struct { AccountData struct {
Events []gomatrixserverlib.ClientEvent `json:"events"` Events []gomatrixserverlib.ClientEvent `json:"events"`
} `json:"account_data"` } `json:"account_data"`
UnreadNotifications struct {
HighlightCount int `json:"highlight_count"`
NotificationCount int `json:"notification_count"`
} `json:"unread_notifications"`
} }
// NewJoinResponse creates an empty response with initialised arrays. // NewJoinResponse creates an empty response with initialised arrays.
@ -462,3 +474,16 @@ type Peek struct {
New bool New bool
Deleted bool Deleted bool
} }
type ReadUpdate struct {
UserID string `json:"user_id"`
RoomID string `json:"room_id"`
Read StreamPosition `json:"read,omitempty"`
FullyRead StreamPosition `json:"fully_read,omitempty"`
}
// StreamEvent is the same as gomatrixserverlib.Event but also has the PDU stream position for this event.
type StreamedEvent struct {
Event *gomatrixserverlib.HeaderedEvent `json:"event"`
StreamPosition StreamPosition `json:"stream_position"`
}

View file

@ -9,9 +9,9 @@ import (
func TestSyncTokens(t *testing.T) { func TestSyncTokens(t *testing.T) {
shouldPass := map[string]string{ shouldPass := map[string]string{
"s4_0_0_0_0_0_0": StreamingToken{4, 0, 0, 0, 0, 0, 0}.String(), "s4_0_0_0_0_0_0_0": StreamingToken{4, 0, 0, 0, 0, 0, 0, 0}.String(),
"s3_1_0_0_0_0_2": StreamingToken{3, 1, 0, 0, 0, 0, 2}.String(), "s3_1_0_0_0_0_2_0": StreamingToken{3, 1, 0, 0, 0, 0, 2, 0}.String(),
"s3_1_2_3_5_0_0": StreamingToken{3, 1, 2, 3, 5, 0, 0}.String(), "s3_1_2_3_5_0_0_0": StreamingToken{3, 1, 2, 3, 5, 0, 0, 0}.String(),
"t3_1": TopologyToken{3, 1}.String(), "t3_1": TopologyToken{3, 1}.String(),
} }

View file

@ -30,3 +30,9 @@ Local device key changes appear in /keys/changes
Remove group category Remove group category
Remove group role Remove group role
# Flakey
AS-ghosted users can use rooms themselves
# Flakey, need additional investigation
Messages that notify from another user increment notification_count
Messages that highlight from another user increment unread highlight count

View file

@ -588,6 +588,59 @@ User can invite remote user to room with version 9
Remote user can backfill in a room with version 9 Remote user can backfill in a room with version 9
Can reject invites over federation for rooms with version 9 Can reject invites over federation for rooms with version 9
Can receive redactions from regular users over federation in room version 9 Can receive redactions from regular users over federation in room version 9
Pushers created with a different access token are deleted on password change
Pushers created with a the same access token are not deleted on password change
Can fetch a user's pushers
Can add global push rule for room
Can add global push rule for sender
Can add global push rule for content
Can add global push rule for override
Can add global push rule for underride
Can add global push rule for content
New rules appear before old rules by default
Can add global push rule before an existing rule
Can add global push rule after an existing rule
Can delete a push rule
Can disable a push rule
Adding the same push rule twice is idempotent
Can change the actions of default rules
Can change the actions of a user specified rule
Adding a push rule wakes up an incremental /sync
Disabling a push rule wakes up an incremental /sync
Enabling a push rule wakes up an incremental /sync
Setting actions for a push rule wakes up an incremental /sync
Can enable/disable default rules
Trying to add push rule with missing template fails with 400
Trying to add push rule with missing rule_id fails with 400
Trying to add push rule with empty rule_id fails with 400
Trying to add push rule with invalid template fails with 400
Trying to add push rule with rule_id with slashes fails with 400
Trying to add push rule with override rule without conditions fails with 400
Trying to add push rule with underride rule without conditions fails with 400
Trying to add push rule with condition without kind fails with 400
Trying to add push rule with content rule without pattern fails with 400
Trying to add push rule with no actions fails with 400
Trying to add push rule with invalid action fails with 400
Trying to add push rule with invalid attr fails with 400
Trying to add push rule with invalid value for enabled fails with 400
Trying to get push rules with no trailing slash fails with 400
Trying to get push rules with scope without trailing slash fails with 400
Trying to get push rules with template without tailing slash fails with 400
Trying to get push rules with unknown scope fails with 400
Trying to get push rules with unknown template fails with 400
Trying to get push rules with unknown attribute fails with 400
Getting push rules doesn't corrupt the cache SYN-390
Test that a message is pushed
Invites are pushed
Rooms with names are correctly named in pushes
Rooms with canonical alias are correctly named in pushed
Rooms with many users are correctly pushed
Don't get pushed for rooms you've muted
Rejected events are not pushed
Test that rejected pushers are removed.
Notifications can be viewed with GET /notifications
Trying to add push rule with no scope fails with 400
Trying to add push rule with invalid scope fails with 400
Forward extremities remain so even after the next events are populated as outliers Forward extremities remain so even after the next events are populated as outliers
If a device list update goes missing, the server resyncs on the next one If a device list update goes missing, the server resyncs on the next one
uploading self-signing key notifies over federation uploading self-signing key notifies over federation
@ -607,4 +660,4 @@ registration accepts non-ascii passwords
registration with inhibit_login inhibits login registration with inhibit_login inhibits login
The operation must be consistent through an interactive authentication session The operation must be consistent through an interactive authentication session
Multiple calls to /sync should not cause 500 errors Multiple calls to /sync should not cause 500 errors
/context/ with lazy_load_members filter works

View file

@ -21,6 +21,7 @@ import (
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal/pushrules"
) )
// UserInternalAPI is the internal API for information about users and devices. // UserInternalAPI is the internal API for information about users and devices.
@ -28,6 +29,7 @@ type UserInternalAPI interface {
LoginTokenInternalAPI LoginTokenInternalAPI
InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error
PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error
PerformPasswordUpdate(ctx context.Context, req *PerformPasswordUpdateRequest, res *PerformPasswordUpdateResponse) error PerformPasswordUpdate(ctx context.Context, req *PerformPasswordUpdateRequest, res *PerformPasswordUpdateResponse) error
PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error
@ -37,6 +39,10 @@ type UserInternalAPI interface {
PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error
PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error
PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse) error PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse) error
PerformPusherSet(ctx context.Context, req *PerformPusherSetRequest, res *struct{}) error
PerformPusherDeletion(ctx context.Context, req *PerformPusherDeletionRequest, res *struct{}) error
PerformPushRulesPut(ctx context.Context, req *PerformPushRulesPutRequest, res *struct{}) error
QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse)
QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error
QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error
@ -45,6 +51,9 @@ type UserInternalAPI interface {
QueryDeviceInfos(ctx context.Context, req *QueryDeviceInfosRequest, res *QueryDeviceInfosResponse) error QueryDeviceInfos(ctx context.Context, req *QueryDeviceInfosRequest, res *QueryDeviceInfosResponse) error
QuerySearchProfiles(ctx context.Context, req *QuerySearchProfilesRequest, res *QuerySearchProfilesResponse) error QuerySearchProfiles(ctx context.Context, req *QuerySearchProfilesRequest, res *QuerySearchProfilesResponse) error
QueryOpenIDToken(ctx context.Context, req *QueryOpenIDTokenRequest, res *QueryOpenIDTokenResponse) error QueryOpenIDToken(ctx context.Context, req *QueryOpenIDTokenRequest, res *QueryOpenIDTokenResponse) error
QueryPushers(ctx context.Context, req *QueryPushersRequest, res *QueryPushersResponse) error
QueryPushRules(ctx context.Context, req *QueryPushRulesRequest, res *QueryPushRulesResponse) error
QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error
} }
type PerformKeyBackupRequest struct { type PerformKeyBackupRequest struct {
@ -424,3 +433,77 @@ const (
// AccountTypeAppService indicates this is an appservice account // AccountTypeAppService indicates this is an appservice account
AccountTypeAppService AccountType = 4 AccountTypeAppService AccountType = 4
) )
type QueryPushersRequest struct {
Localpart string
}
type QueryPushersResponse struct {
Pushers []Pusher `json:"pushers"`
}
type PerformPusherSetRequest struct {
Pusher // Anonymous field because that's how clientapi unmarshals it.
Localpart string
Append bool `json:"append"`
}
type PerformPusherDeletionRequest struct {
Localpart string
SessionID int64
}
// Pusher represents a push notification subscriber
type Pusher struct {
SessionID int64 `json:"session_id,omitempty"`
PushKey string `json:"pushkey"`
PushKeyTS gomatrixserverlib.Timestamp `json:"pushkey_ts,omitempty"`
Kind PusherKind `json:"kind"`
AppID string `json:"app_id"`
AppDisplayName string `json:"app_display_name"`
DeviceDisplayName string `json:"device_display_name"`
ProfileTag string `json:"profile_tag"`
Language string `json:"lang"`
Data map[string]interface{} `json:"data"`
}
type PusherKind string
const (
EmailKind PusherKind = "email"
HTTPKind PusherKind = "http"
)
type PerformPushRulesPutRequest struct {
UserID string `json:"user_id"`
RuleSets *pushrules.AccountRuleSets `json:"rule_sets"`
}
type QueryPushRulesRequest struct {
UserID string `json:"user_id"`
}
type QueryPushRulesResponse struct {
RuleSets *pushrules.AccountRuleSets `json:"rule_sets"`
}
type QueryNotificationsRequest struct {
Localpart string `json:"localpart"` // Required.
From string `json:"from,omitempty"`
Limit int `json:"limit,omitempty"`
Only string `json:"only,omitempty"`
}
type QueryNotificationsResponse struct {
NextToken string `json:"next_token"`
Notifications []*Notification `json:"notifications"` // Required.
}
type Notification struct {
Actions []*pushrules.Action `json:"actions"` // Required.
Event gomatrixserverlib.ClientEvent `json:"event"` // Required.
ProfileTag string `json:"profile_tag"` // Required by Sytest, but actually optional.
Read bool `json:"read"` // Required.
RoomID string `json:"room_id"` // Required.
TS gomatrixserverlib.Timestamp `json:"ts"` // Required.
}

View file

@ -79,6 +79,21 @@ func (t *UserInternalAPITrace) PerformKeyBackup(ctx context.Context, req *Perfor
util.GetLogger(ctx).Infof("PerformKeyBackup req=%+v res=%+v", js(req), js(res)) util.GetLogger(ctx).Infof("PerformKeyBackup req=%+v res=%+v", js(req), js(res))
return err return err
} }
func (t *UserInternalAPITrace) PerformPusherSet(ctx context.Context, req *PerformPusherSetRequest, res *struct{}) error {
err := t.Impl.PerformPusherSet(ctx, req, res)
util.GetLogger(ctx).Infof("PerformPusherSet req=%+v res=%+v", js(req), js(res))
return err
}
func (t *UserInternalAPITrace) PerformPusherDeletion(ctx context.Context, req *PerformPusherDeletionRequest, res *struct{}) error {
err := t.Impl.PerformPusherDeletion(ctx, req, res)
util.GetLogger(ctx).Infof("PerformPusherDeletion req=%+v res=%+v", js(req), js(res))
return err
}
func (t *UserInternalAPITrace) PerformPushRulesPut(ctx context.Context, req *PerformPushRulesPutRequest, res *struct{}) error {
err := t.Impl.PerformPushRulesPut(ctx, req, res)
util.GetLogger(ctx).Infof("PerformPushRulesPut req=%+v res=%+v", js(req), js(res))
return err
}
func (t *UserInternalAPITrace) QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) { func (t *UserInternalAPITrace) QueryKeyBackup(ctx context.Context, req *QueryKeyBackupRequest, res *QueryKeyBackupResponse) {
t.Impl.QueryKeyBackup(ctx, req, res) t.Impl.QueryKeyBackup(ctx, req, res)
util.GetLogger(ctx).Infof("QueryKeyBackup req=%+v res=%+v", js(req), js(res)) util.GetLogger(ctx).Infof("QueryKeyBackup req=%+v res=%+v", js(req), js(res))
@ -118,6 +133,21 @@ func (t *UserInternalAPITrace) QueryOpenIDToken(ctx context.Context, req *QueryO
util.GetLogger(ctx).Infof("QueryOpenIDToken req=%+v res=%+v", js(req), js(res)) util.GetLogger(ctx).Infof("QueryOpenIDToken req=%+v res=%+v", js(req), js(res))
return err return err
} }
func (t *UserInternalAPITrace) QueryPushers(ctx context.Context, req *QueryPushersRequest, res *QueryPushersResponse) error {
err := t.Impl.QueryPushers(ctx, req, res)
util.GetLogger(ctx).Infof("QueryPushers req=%+v res=%+v", js(req), js(res))
return err
}
func (t *UserInternalAPITrace) QueryPushRules(ctx context.Context, req *QueryPushRulesRequest, res *QueryPushRulesResponse) error {
err := t.Impl.QueryPushRules(ctx, req, res)
util.GetLogger(ctx).Infof("QueryPushRules req=%+v res=%+v", js(req), js(res))
return err
}
func (t *UserInternalAPITrace) QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error {
err := t.Impl.QueryNotifications(ctx, req, res)
util.GetLogger(ctx).Infof("QueryNotifications req=%+v res=%+v", js(req), js(res))
return err
}
func js(thing interface{}) string { func js(thing interface{}) string {
b, err := json.Marshal(thing) b, err := json.Marshal(thing)

View file

@ -0,0 +1,136 @@
package consumers
import (
"context"
"encoding/json"
"github.com/matrix-org/dendrite/internal/pushgateway"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/types"
uapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/producers"
"github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/dendrite/userapi/util"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
)
type OutputReadUpdateConsumer struct {
ctx context.Context
cfg *config.UserAPI
jetstream nats.JetStreamContext
durable string
db storage.Database
pgClient pushgateway.Client
ServerName gomatrixserverlib.ServerName
topic string
userAPI uapi.UserInternalAPI
syncProducer *producers.SyncAPI
}
func NewOutputReadUpdateConsumer(
process *process.ProcessContext,
cfg *config.UserAPI,
js nats.JetStreamContext,
store storage.Database,
pgClient pushgateway.Client,
userAPI uapi.UserInternalAPI,
syncProducer *producers.SyncAPI,
) *OutputReadUpdateConsumer {
return &OutputReadUpdateConsumer{
ctx: process.Context(),
cfg: cfg,
jetstream: js,
db: store,
ServerName: cfg.Matrix.ServerName,
durable: cfg.Matrix.JetStream.Durable("UserAPISyncAPIReadUpdateConsumer"),
topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputReadUpdate),
pgClient: pgClient,
userAPI: userAPI,
syncProducer: syncProducer,
}
}
func (s *OutputReadUpdateConsumer) Start() error {
if err := jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
nats.DeliverAll(), nats.ManualAck(),
); err != nil {
return err
}
return nil
}
func (s *OutputReadUpdateConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
var read types.ReadUpdate
if err := json.Unmarshal(msg.Data, &read); err != nil {
log.WithError(err).Error("userapi clientapi consumer: message parse failure")
return true
}
if read.FullyRead == 0 && read.Read == 0 {
return true
}
userID := string(msg.Header.Get(jetstream.UserID))
roomID := string(msg.Header.Get(jetstream.RoomID))
localpart, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
log.WithError(err).Error("userapi clientapi consumer: SplitID failure")
return true
}
if domain != s.ServerName {
log.Error("userapi clientapi consumer: not a local user")
return true
}
log := log.WithFields(log.Fields{
"room_id": roomID,
"user_id": userID,
})
log.Tracef("Received read update from sync API: %#v", read)
if read.Read > 0 {
updated, err := s.db.SetNotificationsRead(ctx, localpart, roomID, int64(read.Read), true)
if err != nil {
log.WithError(err).Error("userapi EDU consumer")
return false
}
if updated {
if err = s.syncProducer.GetAndSendNotificationData(ctx, userID, roomID); err != nil {
log.WithError(err).Error("userapi EDU consumer: GetAndSendNotificationData failed")
return false
}
if err = util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, s.db); err != nil {
log.WithError(err).Error("userapi EDU consumer: NotifyUserCounts failed")
return false
}
}
}
if read.FullyRead > 0 {
deleted, err := s.db.DeleteNotificationsUpTo(ctx, localpart, roomID, int64(read.FullyRead))
if err != nil {
log.WithError(err).Errorf("userapi clientapi consumer: DeleteNotificationsUpTo failed")
return false
}
if deleted {
if err := util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, s.db); err != nil {
log.WithError(err).Error("userapi clientapi consumer: NotifyUserCounts failed")
return false
}
if err := s.syncProducer.GetAndSendNotificationData(ctx, userID, read.RoomID); err != nil {
log.WithError(err).Errorf("userapi clientapi consumer: GetAndSendNotificationData failed")
return false
}
}
}
return true
}

View file

@ -0,0 +1,588 @@
package consumers
import (
"context"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/internal/pushgateway"
"github.com/matrix-org/dendrite/internal/pushrules"
rsapi "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/producers"
"github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/dendrite/userapi/util"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
)
type OutputStreamEventConsumer struct {
ctx context.Context
cfg *config.UserAPI
userAPI api.UserInternalAPI
rsAPI rsapi.RoomserverInternalAPI
jetstream nats.JetStreamContext
durable string
db storage.Database
topic string
pgClient pushgateway.Client
syncProducer *producers.SyncAPI
}
func NewOutputStreamEventConsumer(
process *process.ProcessContext,
cfg *config.UserAPI,
js nats.JetStreamContext,
store storage.Database,
pgClient pushgateway.Client,
userAPI api.UserInternalAPI,
rsAPI rsapi.RoomserverInternalAPI,
syncProducer *producers.SyncAPI,
) *OutputStreamEventConsumer {
return &OutputStreamEventConsumer{
ctx: process.Context(),
cfg: cfg,
jetstream: js,
db: store,
durable: cfg.Matrix.JetStream.Durable("UserAPISyncAPIStreamEventConsumer"),
topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputStreamEvent),
pgClient: pgClient,
userAPI: userAPI,
rsAPI: rsAPI,
syncProducer: syncProducer,
}
}
func (s *OutputStreamEventConsumer) Start() error {
if err := jetstream.JetStreamConsumer(
s.ctx, s.jetstream, s.topic, s.durable, s.onMessage,
nats.DeliverAll(), nats.ManualAck(),
); err != nil {
return err
}
return nil
}
func (s *OutputStreamEventConsumer) onMessage(ctx context.Context, msg *nats.Msg) bool {
var output types.StreamedEvent
output.Event = &gomatrixserverlib.HeaderedEvent{}
if err := json.Unmarshal(msg.Data, &output); err != nil {
log.WithError(err).Errorf("userapi consumer: message parse failure")
return true
}
if output.Event.Event == nil {
log.Errorf("userapi consumer: expected event")
return true
}
log.WithFields(log.Fields{
"event_id": output.Event.EventID(),
"event_type": output.Event.Type(),
"stream_pos": output.StreamPosition,
}).Tracef("Received message from sync API: %#v", output)
if err := s.processMessage(ctx, output.Event, int64(output.StreamPosition)); err != nil {
log.WithFields(log.Fields{
"event_id": output.Event.EventID(),
}).WithError(err).Errorf("userapi consumer: process room event failure")
}
return true
}
func (s *OutputStreamEventConsumer) processMessage(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, pos int64) error {
members, roomSize, err := s.localRoomMembers(ctx, event.RoomID())
if err != nil {
return fmt.Errorf("s.localRoomMembers: %w", err)
}
if event.Type() == gomatrixserverlib.MRoomMember {
cevent := gomatrixserverlib.HeaderedToClientEvent(event, gomatrixserverlib.FormatAll)
var member *localMembership
member, err = newLocalMembership(&cevent)
if err != nil {
return fmt.Errorf("newLocalMembership: %w", err)
}
if member.Membership == gomatrixserverlib.Invite && member.Domain == s.cfg.Matrix.ServerName {
// localRoomMembers only adds joined members. An invite
// should also be pushed to the target user.
members = append(members, member)
}
}
// TODO: run in parallel with localRoomMembers.
roomName, err := s.roomName(ctx, event)
if err != nil {
return fmt.Errorf("s.roomName: %w", err)
}
log.WithFields(log.Fields{
"event_id": event.EventID(),
"room_id": event.RoomID(),
"num_members": len(members),
"room_size": roomSize,
}).Tracef("Notifying members")
// Notification.UserIsTarget is a per-member field, so we
// cannot group all users in a single request.
//
// TODO: does it have to be set? It's not required, and
// removing it means we can send all notifications to
// e.g. Element's Push gateway in one go.
for _, mem := range members {
if err := s.notifyLocal(ctx, event, pos, mem, roomSize, roomName); err != nil {
log.WithFields(log.Fields{
"localpart": mem.Localpart,
}).WithError(err).Debugf("Unable to push to local user")
continue
}
}
return nil
}
type localMembership struct {
gomatrixserverlib.MemberContent
UserID string
Localpart string
Domain gomatrixserverlib.ServerName
}
func newLocalMembership(event *gomatrixserverlib.ClientEvent) (*localMembership, error) {
if event.StateKey == nil {
return nil, fmt.Errorf("missing state_key")
}
var member localMembership
if err := json.Unmarshal(event.Content, &member.MemberContent); err != nil {
return nil, err
}
localpart, domain, err := gomatrixserverlib.SplitID('@', *event.StateKey)
if err != nil {
return nil, err
}
member.UserID = *event.StateKey
member.Localpart = localpart
member.Domain = domain
return &member, nil
}
// localRoomMembers fetches the current local members of a room, and
// the total number of members.
func (s *OutputStreamEventConsumer) localRoomMembers(ctx context.Context, roomID string) ([]*localMembership, int, error) {
req := &rsapi.QueryMembershipsForRoomRequest{
RoomID: roomID,
JoinedOnly: true,
}
var res rsapi.QueryMembershipsForRoomResponse
// XXX: This could potentially race if the state for the event is not known yet
// e.g. the event came over federation but we do not have the full state persisted.
if err := s.rsAPI.QueryMembershipsForRoom(ctx, req, &res); err != nil {
return nil, 0, err
}
var members []*localMembership
var ntotal int
for _, event := range res.JoinEvents {
member, err := newLocalMembership(&event)
if err != nil {
log.WithError(err).Errorf("Parsing MemberContent")
continue
}
if member.Membership != gomatrixserverlib.Join {
continue
}
if member.Domain != s.cfg.Matrix.ServerName {
continue
}
ntotal++
members = append(members, member)
}
return members, ntotal, nil
}
// roomName returns the name in the event (if type==m.room.name), or
// looks it up in roomserver. If there is no name,
// m.room.canonical_alias is consulted. Returns an empty string if the
// room has no name.
func (s *OutputStreamEventConsumer) roomName(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) (string, error) {
if event.Type() == gomatrixserverlib.MRoomName {
name, err := unmarshalRoomName(event)
if err != nil {
return "", err
}
if name != "" {
return name, nil
}
}
req := &rsapi.QueryCurrentStateRequest{
RoomID: event.RoomID(),
StateTuples: []gomatrixserverlib.StateKeyTuple{roomNameTuple, canonicalAliasTuple},
}
var res rsapi.QueryCurrentStateResponse
if err := s.rsAPI.QueryCurrentState(ctx, req, &res); err != nil {
return "", nil
}
if eventS := res.StateEvents[roomNameTuple]; eventS != nil {
return unmarshalRoomName(eventS)
}
if event.Type() == gomatrixserverlib.MRoomCanonicalAlias {
alias, err := unmarshalCanonicalAlias(event)
if err != nil {
return "", err
}
if alias != "" {
return alias, nil
}
}
if event = res.StateEvents[canonicalAliasTuple]; event != nil {
return unmarshalCanonicalAlias(event)
}
return "", nil
}
var (
canonicalAliasTuple = gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomCanonicalAlias}
roomNameTuple = gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomName}
)
func unmarshalRoomName(event *gomatrixserverlib.HeaderedEvent) (string, error) {
var nc eventutil.NameContent
if err := json.Unmarshal(event.Content(), &nc); err != nil {
return "", fmt.Errorf("unmarshaling NameContent: %w", err)
}
return nc.Name, nil
}
func unmarshalCanonicalAlias(event *gomatrixserverlib.HeaderedEvent) (string, error) {
var cac eventutil.CanonicalAliasContent
if err := json.Unmarshal(event.Content(), &cac); err != nil {
return "", fmt.Errorf("unmarshaling CanonicalAliasContent: %w", err)
}
return cac.Alias, nil
}
// notifyLocal finds the right push actions for a local user, given an event.
func (s *OutputStreamEventConsumer) notifyLocal(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, pos int64, mem *localMembership, roomSize int, roomName string) error {
actions, err := s.evaluatePushRules(ctx, event, mem, roomSize)
if err != nil {
return err
}
a, tweaks, err := pushrules.ActionsToTweaks(actions)
if err != nil {
return err
}
// TODO: support coalescing.
if a != pushrules.NotifyAction && a != pushrules.CoalesceAction {
log.WithFields(log.Fields{
"event_id": event.EventID(),
"room_id": event.RoomID(),
"localpart": mem.Localpart,
}).Tracef("Push rule evaluation rejected the event")
return nil
}
devicesByURLAndFormat, profileTag, err := s.localPushDevices(ctx, mem.Localpart, tweaks)
if err != nil {
return err
}
n := &api.Notification{
Actions: actions,
// UNSPEC: the spec doesn't say this is a ClientEvent, but the
// fields seem to match. room_id should be missing, which
// matches the behaviour of FormatSync.
Event: gomatrixserverlib.HeaderedToClientEvent(event, gomatrixserverlib.FormatSync),
// TODO: this is per-device, but it's not part of the primary
// key. So inserting one notification per profile tag doesn't
// make sense. What is this supposed to be? Sytests require it
// to "work", but they only use a single device.
ProfileTag: profileTag,
RoomID: event.RoomID(),
TS: gomatrixserverlib.AsTimestamp(time.Now()),
}
if err = s.db.InsertNotification(ctx, mem.Localpart, event.EventID(), pos, tweaks, n); err != nil {
return err
}
if err = s.syncProducer.GetAndSendNotificationData(ctx, mem.UserID, event.RoomID()); err != nil {
return err
}
// We do this after InsertNotification. Thus, this should always return >=1.
userNumUnreadNotifs, err := s.db.GetNotificationCount(ctx, mem.Localpart, tables.AllNotifications)
if err != nil {
return err
}
log.WithFields(log.Fields{
"event_id": event.EventID(),
"room_id": event.RoomID(),
"localpart": mem.Localpart,
"num_urls": len(devicesByURLAndFormat),
"num_unread": userNumUnreadNotifs,
}).Tracef("Notifying single member")
// Push gateways are out of our control, and we cannot risk
// looking up the server on a misbehaving push gateway. Each user
// receives a goroutine now that all internal API calls have been
// made.
//
// TODO: think about bounding this to one per user, and what
// ordering guarantees we must provide.
go func() {
// This background processing cannot be tied to a request.
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
var rejected []*pushgateway.Device
for url, fmts := range devicesByURLAndFormat {
for format, devices := range fmts {
// TODO: support "email".
if !strings.HasPrefix(url, "http") {
continue
}
// UNSPEC: the specification suggests there can be
// more than one device per request. There is at least
// one Sytest that expects one HTTP request per
// device, rather than per URL. For now, we must
// notify each one separately.
for _, dev := range devices {
rej, err := s.notifyHTTP(ctx, event, url, format, []*pushgateway.Device{dev}, mem.Localpart, roomName, int(userNumUnreadNotifs))
if err != nil {
log.WithFields(log.Fields{
"event_id": event.EventID(),
"localpart": mem.Localpart,
}).WithError(err).Errorf("Unable to notify HTTP pusher")
continue
}
rejected = append(rejected, rej...)
}
}
}
if len(rejected) > 0 {
s.deleteRejectedPushers(ctx, rejected, mem.Localpart)
}
}()
return nil
}
// evaluatePushRules fetches and evaluates the push rules of a local
// user. Returns actions (including dont_notify).
func (s *OutputStreamEventConsumer) evaluatePushRules(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) {
if event.Sender() == mem.UserID {
// SPEC: Homeservers MUST NOT notify the Push Gateway for
// events that the user has sent themselves.
return nil, nil
}
var res api.QueryPushRulesResponse
if err := s.userAPI.QueryPushRules(ctx, &api.QueryPushRulesRequest{UserID: mem.UserID}, &res); err != nil {
return nil, err
}
ec := &ruleSetEvalContext{
ctx: ctx,
rsAPI: s.rsAPI,
mem: mem,
roomID: event.RoomID(),
roomSize: roomSize,
}
eval := pushrules.NewRuleSetEvaluator(ec, &res.RuleSets.Global)
rule, err := eval.MatchEvent(event.Event)
if err != nil {
return nil, err
}
if rule == nil {
// SPEC: If no rules match an event, the homeserver MUST NOT
// notify the Push Gateway for that event.
return nil, err
}
log.WithFields(log.Fields{
"event_id": event.EventID(),
"room_id": event.RoomID(),
"localpart": mem.Localpart,
"rule_id": rule.RuleID,
}).Tracef("Matched a push rule")
return rule.Actions, nil
}
type ruleSetEvalContext struct {
ctx context.Context
rsAPI rsapi.RoomserverInternalAPI
mem *localMembership
roomID string
roomSize int
}
func (rse *ruleSetEvalContext) UserDisplayName() string { return rse.mem.DisplayName }
func (rse *ruleSetEvalContext) RoomMemberCount() (int, error) { return rse.roomSize, nil }
func (rse *ruleSetEvalContext) HasPowerLevel(userID, levelKey string) (bool, error) {
req := &rsapi.QueryLatestEventsAndStateRequest{
RoomID: rse.roomID,
StateToFetch: []gomatrixserverlib.StateKeyTuple{
{EventType: gomatrixserverlib.MRoomPowerLevels},
},
}
var res rsapi.QueryLatestEventsAndStateResponse
if err := rse.rsAPI.QueryLatestEventsAndState(rse.ctx, req, &res); err != nil {
return false, err
}
for _, ev := range res.StateEvents {
if ev.Type() != gomatrixserverlib.MRoomPowerLevels {
continue
}
plc, err := gomatrixserverlib.NewPowerLevelContentFromEvent(ev.Event)
if err != nil {
return false, err
}
return plc.UserLevel(userID) >= plc.NotificationLevel(levelKey), nil
}
return true, nil
}
// localPushDevices pushes to the configured devices of a local
// user. The map keys are [url][format].
func (s *OutputStreamEventConsumer) localPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) {
pusherDevices, err := util.GetPushDevices(ctx, localpart, tweaks, s.db)
if err != nil {
return nil, "", err
}
var profileTag string
devicesByURL := make(map[string]map[string][]*pushgateway.Device, len(pusherDevices))
for _, pusherDevice := range pusherDevices {
if profileTag == "" {
profileTag = pusherDevice.Pusher.ProfileTag
}
url := pusherDevice.URL
if devicesByURL[url] == nil {
devicesByURL[url] = make(map[string][]*pushgateway.Device, 2)
}
devicesByURL[url][pusherDevice.Format] = append(devicesByURL[url][pusherDevice.Format], &pusherDevice.Device)
}
return devicesByURL, profileTag, nil
}
// notifyHTTP performs a notificatation to a Push Gateway.
func (s *OutputStreamEventConsumer) notifyHTTP(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, url, format string, devices []*pushgateway.Device, localpart, roomName string, userNumUnreadNotifs int) ([]*pushgateway.Device, error) {
logger := log.WithFields(log.Fields{
"event_id": event.EventID(),
"url": url,
"localpart": localpart,
"num_devices": len(devices),
})
var req pushgateway.NotifyRequest
switch format {
case "event_id_only":
req = pushgateway.NotifyRequest{
Notification: pushgateway.Notification{
Counts: &pushgateway.Counts{},
Devices: devices,
EventID: event.EventID(),
RoomID: event.RoomID(),
},
}
default:
req = pushgateway.NotifyRequest{
Notification: pushgateway.Notification{
Content: event.Content(),
Counts: &pushgateway.Counts{
Unread: userNumUnreadNotifs,
},
Devices: devices,
EventID: event.EventID(),
ID: event.EventID(),
RoomID: event.RoomID(),
RoomName: roomName,
Sender: event.Sender(),
Type: event.Type(),
},
}
if mem, err := event.Membership(); err == nil {
req.Notification.Membership = mem
}
if event.StateKey() != nil && *event.StateKey() == fmt.Sprintf("@%s:%s", localpart, s.cfg.Matrix.ServerName) {
req.Notification.UserIsTarget = true
}
}
logger.Debugf("Notifying push gateway %s", url)
var res pushgateway.NotifyResponse
if err := s.pgClient.Notify(ctx, url, &req, &res); err != nil {
logger.WithError(err).Errorf("Failed to notify push gateway %s", url)
return nil, err
}
logger.WithField("num_rejected", len(res.Rejected)).Tracef("Push gateway result")
if len(res.Rejected) == 0 {
return nil, nil
}
devMap := make(map[string]*pushgateway.Device, len(devices))
for _, d := range devices {
devMap[d.PushKey] = d
}
rejected := make([]*pushgateway.Device, 0, len(res.Rejected))
for _, pushKey := range res.Rejected {
d := devMap[pushKey]
if d != nil {
rejected = append(rejected, d)
}
}
return rejected, nil
}
// deleteRejectedPushers deletes the pushers associated with the given devices.
func (s *OutputStreamEventConsumer) deleteRejectedPushers(ctx context.Context, devices []*pushgateway.Device, localpart string) {
log.WithFields(log.Fields{
"localpart": localpart,
"app_id0": devices[0].AppID,
"num_devices": len(devices),
}).Warnf("Deleting pushers rejected by the HTTP push gateway")
for _, d := range devices {
if err := s.db.RemovePusher(ctx, d.AppID, d.PushKey, localpart); err != nil {
log.WithFields(log.Fields{
"localpart": localpart,
}).WithError(err).Errorf("Unable to delete rejected pusher")
}
}
}

View file

@ -20,6 +20,8 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"strconv"
"time"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
@ -27,15 +29,21 @@ import (
"github.com/matrix-org/dendrite/appservice/types" "github.com/matrix-org/dendrite/appservice/types"
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/internal/pushrules"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
keyapi "github.com/matrix-org/dendrite/keyserver/api" keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/producers"
"github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/dendrite/userapi/storage/tables"
) )
type UserInternalAPI struct { type UserInternalAPI struct {
DB storage.Database DB storage.Database
SyncProducer *producers.SyncAPI
DisableTLSValidation bool
ServerName gomatrixserverlib.ServerName ServerName gomatrixserverlib.ServerName
// AppServices is the list of all registered AS // AppServices is the list of all registered AS
AppServices []config.ApplicationService AppServices []config.ApplicationService
@ -595,3 +603,162 @@ func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyB
} }
res.Keys = result res.Keys = result
} }
func (a *UserInternalAPI) QueryNotifications(ctx context.Context, req *api.QueryNotificationsRequest, res *api.QueryNotificationsResponse) error {
if req.Limit == 0 || req.Limit > 1000 {
req.Limit = 1000
}
var fromID int64
var err error
if req.From != "" {
fromID, err = strconv.ParseInt(req.From, 10, 64)
if err != nil {
return fmt.Errorf("QueryNotifications: parsing 'from': %w", err)
}
}
var filter tables.NotificationFilter = tables.AllNotifications
if req.Only == "highlight" {
filter = tables.HighlightNotifications
}
notifs, lastID, err := a.DB.GetNotifications(ctx, req.Localpart, fromID, req.Limit, filter)
if err != nil {
return err
}
if notifs == nil {
// This ensures empty is JSON-encoded as [] instead of null.
notifs = []*api.Notification{}
}
res.Notifications = notifs
if lastID >= 0 {
res.NextToken = strconv.FormatInt(lastID+1, 10)
}
return nil
}
func (a *UserInternalAPI) PerformPusherSet(ctx context.Context, req *api.PerformPusherSetRequest, res *struct{}) error {
util.GetLogger(ctx).WithFields(logrus.Fields{
"localpart": req.Localpart,
"pushkey": req.Pusher.PushKey,
"display_name": req.Pusher.AppDisplayName,
}).Info("PerformPusherCreation")
if !req.Append {
err := a.DB.RemovePushers(ctx, req.Pusher.AppID, req.Pusher.PushKey)
if err != nil {
return err
}
}
if req.Pusher.Kind == "" {
return a.DB.RemovePusher(ctx, req.Pusher.AppID, req.Pusher.PushKey, req.Localpart)
}
if req.Pusher.PushKeyTS == 0 {
req.Pusher.PushKeyTS = gomatrixserverlib.AsTimestamp(time.Now())
}
return a.DB.UpsertPusher(ctx, req.Pusher, req.Localpart)
}
func (a *UserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.PerformPusherDeletionRequest, res *struct{}) error {
pushers, err := a.DB.GetPushers(ctx, req.Localpart)
if err != nil {
return err
}
for i := range pushers {
logrus.Warnf("pusher session: %d, req session: %d", pushers[i].SessionID, req.SessionID)
if pushers[i].SessionID != req.SessionID {
err := a.DB.RemovePusher(ctx, pushers[i].AppID, pushers[i].PushKey, req.Localpart)
if err != nil {
return err
}
}
}
return nil
}
func (a *UserInternalAPI) QueryPushers(ctx context.Context, req *api.QueryPushersRequest, res *api.QueryPushersResponse) error {
var err error
res.Pushers, err = a.DB.GetPushers(ctx, req.Localpart)
return err
}
func (a *UserInternalAPI) PerformPushRulesPut(
ctx context.Context,
req *api.PerformPushRulesPutRequest,
_ *struct{},
) error {
bs, err := json.Marshal(&req.RuleSets)
if err != nil {
return err
}
userReq := api.InputAccountDataRequest{
UserID: req.UserID,
DataType: pushRulesAccountDataType,
AccountData: json.RawMessage(bs),
}
var userRes api.InputAccountDataResponse // empty
if err := a.InputAccountData(ctx, &userReq, &userRes); err != nil {
return err
}
if err := a.SyncProducer.SendAccountData(req.UserID, "" /* roomID */, pushRulesAccountDataType); err != nil {
util.GetLogger(ctx).WithError(err).Errorf("syncProducer.SendData failed")
}
return nil
}
func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPushRulesRequest, res *api.QueryPushRulesResponse) error {
userReq := api.QueryAccountDataRequest{
UserID: req.UserID,
DataType: pushRulesAccountDataType,
}
var userRes api.QueryAccountDataResponse
if err := a.QueryAccountData(ctx, &userReq, &userRes); err != nil {
return err
}
bs, ok := userRes.GlobalAccountData[pushRulesAccountDataType]
if ok {
// Legacy Dendrite users will have completely empty push rules, so we should
// detect that situation and set some defaults.
var rules struct {
G struct {
Content []json.RawMessage `json:"content"`
Override []json.RawMessage `json:"override"`
Room []json.RawMessage `json:"room"`
Sender []json.RawMessage `json:"sender"`
Underride []json.RawMessage `json:"underride"`
} `json:"global"`
}
if err := json.Unmarshal([]byte(bs), &rules); err == nil {
count := len(rules.G.Content) + len(rules.G.Override) +
len(rules.G.Room) + len(rules.G.Sender) + len(rules.G.Underride)
ok = count > 0
}
}
if !ok {
// If we didn't find any default push rules then we should just generate some
// fresh ones.
localpart, _, err := gomatrixserverlib.SplitID('@', req.UserID)
if err != nil {
return fmt.Errorf("failed to split user ID %q for push rules", req.UserID)
}
pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, a.ServerName)
prbs, err := json.Marshal(pushRuleSets)
if err != nil {
return fmt.Errorf("failed to marshal default push rules: %w", err)
}
if err := a.DB.SaveAccountData(ctx, localpart, "", pushRulesAccountDataType, json.RawMessage(prbs)); err != nil {
return fmt.Errorf("failed to save default push rules: %w", err)
}
res.RuleSets = pushRuleSets
return nil
}
var data pushrules.AccountRuleSets
if err := json.Unmarshal([]byte(bs), &data); err != nil {
util.GetLogger(ctx).WithError(err).Error("json.Unmarshal of push rules failed")
return err
}
res.RuleSets = &data
return nil
}
const pushRulesAccountDataType = "m.push_rules"

View file

@ -37,6 +37,9 @@ const (
PerformAccountDeactivationPath = "/userapi/performAccountDeactivation" PerformAccountDeactivationPath = "/userapi/performAccountDeactivation"
PerformOpenIDTokenCreationPath = "/userapi/performOpenIDTokenCreation" PerformOpenIDTokenCreationPath = "/userapi/performOpenIDTokenCreation"
PerformKeyBackupPath = "/userapi/performKeyBackup" PerformKeyBackupPath = "/userapi/performKeyBackup"
PerformPusherSetPath = "/pushserver/performPusherSet"
PerformPusherDeletionPath = "/pushserver/performPusherDeletion"
PerformPushRulesPutPath = "/pushserver/performPushRulesPut"
QueryKeyBackupPath = "/userapi/queryKeyBackup" QueryKeyBackupPath = "/userapi/queryKeyBackup"
QueryProfilePath = "/userapi/queryProfile" QueryProfilePath = "/userapi/queryProfile"
@ -46,6 +49,9 @@ const (
QueryDeviceInfosPath = "/userapi/queryDeviceInfos" QueryDeviceInfosPath = "/userapi/queryDeviceInfos"
QuerySearchProfilesPath = "/userapi/querySearchProfiles" QuerySearchProfilesPath = "/userapi/querySearchProfiles"
QueryOpenIDTokenPath = "/userapi/queryOpenIDToken" QueryOpenIDTokenPath = "/userapi/queryOpenIDToken"
QueryPushersPath = "/pushserver/queryPushers"
QueryPushRulesPath = "/pushserver/queryPushRules"
QueryNotificationsPath = "/pushserver/queryNotifications"
) )
// NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API. // NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API.
@ -249,3 +255,58 @@ func (h *httpUserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.Query
res.Error = err.Error() res.Error = err.Error()
} }
} }
func (h *httpUserInternalAPI) QueryNotifications(ctx context.Context, req *api.QueryNotificationsRequest, res *api.QueryNotificationsResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryNotifications")
defer span.Finish()
return httputil.PostJSON(ctx, span, h.httpClient, h.apiURL+QueryNotificationsPath, req, res)
}
func (h *httpUserInternalAPI) PerformPusherSet(
ctx context.Context,
request *api.PerformPusherSetRequest,
response *struct{},
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPusherSet")
defer span.Finish()
apiURL := h.apiURL + PerformPusherSetPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
}
func (h *httpUserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.PerformPusherDeletionRequest, res *struct{}) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPusherDeletion")
defer span.Finish()
apiURL := h.apiURL + PerformPusherDeletionPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}
func (h *httpUserInternalAPI) QueryPushers(ctx context.Context, req *api.QueryPushersRequest, res *api.QueryPushersResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryPushers")
defer span.Finish()
apiURL := h.apiURL + QueryPushersPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}
func (h *httpUserInternalAPI) PerformPushRulesPut(
ctx context.Context,
request *api.PerformPushRulesPutRequest,
response *struct{},
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPushRulesPut")
defer span.Finish()
apiURL := h.apiURL + PerformPushRulesPutPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
}
func (h *httpUserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPushRulesRequest, res *api.QueryPushRulesResponse) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryPushRules")
defer span.Finish()
apiURL := h.apiURL + QueryPushRulesPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}

View file

@ -265,4 +265,86 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response} return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}), }),
) )
internalAPIMux.Handle(QueryNotificationsPath,
httputil.MakeInternalAPI("queryNotifications", func(req *http.Request) util.JSONResponse {
var request api.QueryNotificationsRequest
var response api.QueryNotificationsResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.QueryNotifications(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(PerformPusherSetPath,
httputil.MakeInternalAPI("performPusherSet", func(req *http.Request) util.JSONResponse {
request := api.PerformPusherSetRequest{}
response := struct{}{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformPusherSet(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(PerformPusherDeletionPath,
httputil.MakeInternalAPI("performPusherDeletion", func(req *http.Request) util.JSONResponse {
request := api.PerformPusherDeletionRequest{}
response := struct{}{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformPusherDeletion(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(QueryPushersPath,
httputil.MakeInternalAPI("queryPushers", func(req *http.Request) util.JSONResponse {
request := api.QueryPushersRequest{}
response := api.QueryPushersResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.QueryPushers(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(PerformPushRulesPutPath,
httputil.MakeInternalAPI("performPushRulesPut", func(req *http.Request) util.JSONResponse {
request := api.PerformPushRulesPutRequest{}
response := struct{}{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformPushRulesPut(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(QueryPushRulesPath,
httputil.MakeInternalAPI("queryPushRules", func(req *http.Request) util.JSONResponse {
request := api.QueryPushRulesRequest{}
response := api.QueryPushRulesResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.QueryPushRules(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
} }

View file

@ -0,0 +1,104 @@
package producers
import (
"context"
"encoding/json"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/gomatrixserverlib"
"github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus"
)
type JetStreamPublisher interface {
PublishMsg(*nats.Msg, ...nats.PubOpt) (*nats.PubAck, error)
}
// SyncAPI produces messages for the Sync API server to consume.
type SyncAPI struct {
db storage.Database
producer JetStreamPublisher
clientDataTopic string
notificationDataTopic string
}
func NewSyncAPI(db storage.Database, js JetStreamPublisher, clientDataTopic string, notificationDataTopic string) *SyncAPI {
return &SyncAPI{
db: db,
producer: js,
clientDataTopic: clientDataTopic,
notificationDataTopic: notificationDataTopic,
}
}
// SendAccountData sends account data to the Sync API server.
func (p *SyncAPI) SendAccountData(userID string, roomID string, dataType string) error {
m := &nats.Msg{
Subject: p.clientDataTopic,
Header: nats.Header{},
}
m.Header.Set(jetstream.UserID, userID)
var err error
m.Data, err = json.Marshal(eventutil.AccountData{
RoomID: roomID,
Type: dataType,
})
if err != nil {
return err
}
log.WithFields(log.Fields{
"user_id": userID,
"room_id": roomID,
"data_type": dataType,
}).Tracef("Producing to topic '%s'", p.clientDataTopic)
_, err = p.producer.PublishMsg(m)
return err
}
// GetAndSendNotificationData reads the database and sends data about unread
// notifications to the Sync API server.
func (p *SyncAPI) GetAndSendNotificationData(ctx context.Context, userID, roomID string) error {
localpart, _, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
return err
}
ntotal, nhighlight, err := p.db.GetRoomNotificationCounts(ctx, localpart, roomID)
if err != nil {
return err
}
return p.sendNotificationData(userID, &eventutil.NotificationData{
RoomID: roomID,
UnreadHighlightCount: int(nhighlight),
UnreadNotificationCount: int(ntotal),
})
}
// sendNotificationData sends data about unread notifications to the Sync API server.
func (p *SyncAPI) sendNotificationData(userID string, data *eventutil.NotificationData) error {
m := &nats.Msg{
Subject: p.notificationDataTopic,
Header: nats.Header{},
}
m.Header.Set(jetstream.UserID, userID)
var err error
m.Data, err = json.Marshal(data)
if err != nil {
return err
}
log.WithFields(log.Fields{
"user_id": userID,
"room_id": data.RoomID,
}).Tracef("Producing to topic '%s'", p.clientDataTopic)
_, err = p.producer.PublishMsg(m)
return err
}

View file

@ -21,6 +21,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
) )
type Database interface { type Database interface {
@ -89,6 +90,18 @@ type Database interface {
// GetLoginTokenDataByToken returns the data associated with the given token. // GetLoginTokenDataByToken returns the data associated with the given token.
// May return sql.ErrNoRows. // May return sql.ErrNoRows.
GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error)
InsertNotification(ctx context.Context, localpart, eventID string, pos int64, tweaks map[string]interface{}, n *api.Notification) error
DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos int64) (affected bool, err error)
SetNotificationsRead(ctx context.Context, localpart, roomID string, pos int64, b bool) (affected bool, err error)
GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error)
GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error)
GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error)
UpsertPusher(ctx context.Context, p api.Pusher, localpart string) error
GetPushers(ctx context.Context, localpart string) ([]api.Pusher, error)
RemovePusher(ctx context.Context, appid, pushkey, localpart string) error
RemovePushers(ctx context.Context, appid, pushkey string) error
} }
// Err3PIDInUse is the error returned when trying to save an association involving // Err3PIDInUse is the error returned when trying to save an association involving

View file

@ -0,0 +1,219 @@
// Copyright 2021 Dan Peleg <dan@globekeeper.com>
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"encoding/json"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus"
)
type notificationsStatements struct {
insertStmt *sql.Stmt
deleteUpToStmt *sql.Stmt
updateReadStmt *sql.Stmt
selectStmt *sql.Stmt
selectCountStmt *sql.Stmt
selectRoomCountsStmt *sql.Stmt
}
const notificationSchema = `
CREATE TABLE IF NOT EXISTS userapi_notifications (
id BIGSERIAL PRIMARY KEY,
localpart TEXT NOT NULL,
room_id TEXT NOT NULL,
event_id TEXT NOT NULL,
stream_pos BIGINT NOT NULL,
ts_ms BIGINT NOT NULL,
highlight BOOLEAN NOT NULL,
notification_json TEXT NOT NULL,
read BOOLEAN NOT NULL DEFAULT FALSE
);
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, room_id, event_id);
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, room_id, id);
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, id);
`
const insertNotificationSQL = "" +
"INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7)"
const deleteNotificationsUpToSQL = "" +
"DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND stream_pos <= $3"
const updateNotificationReadSQL = "" +
"UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND stream_pos <= $4 AND read <> $1"
const selectNotificationSQL = "" +
"SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" +
"(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" +
") AND NOT read ORDER BY localpart, id LIMIT $4"
const selectNotificationCountSQL = "" +
"SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND (" +
"(($2 & 1) <> 0 AND highlight) OR (($2 & 2) <> 0 AND NOT highlight)" +
") AND NOT read"
const selectRoomNotificationCountsSQL = "" +
"SELECT COUNT(*), COUNT(*) FILTER (WHERE highlight) FROM userapi_notifications " +
"WHERE localpart = $1 AND room_id = $2 AND NOT read"
func NewPostgresNotificationTable(db *sql.DB) (tables.NotificationTable, error) {
s := &notificationsStatements{}
_, err := db.Exec(notificationSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.insertStmt, insertNotificationSQL},
{&s.deleteUpToStmt, deleteNotificationsUpToSQL},
{&s.updateReadStmt, updateNotificationReadSQL},
{&s.selectStmt, selectNotificationSQL},
{&s.selectCountStmt, selectNotificationCountSQL},
{&s.selectRoomCountsStmt, selectRoomNotificationCountsSQL},
}.Prepare(db)
}
// Insert inserts a notification into the database.
func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error {
roomID, tsMS := n.RoomID, n.TS
nn := *n
// Clears out fields that have their own columns to (1) shrink the
// data and (2) avoid difficult-to-debug inconsistency bugs.
nn.RoomID = ""
nn.TS, nn.Read = 0, false
bs, err := json.Marshal(nn)
if err != nil {
return err
}
_, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, pos, tsMS, highlight, string(bs))
return err
}
// DeleteUpTo deletes all previous notifications, up to and including the event.
func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error) {
res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos)
if err != nil {
return false, err
}
nrows, err := res.RowsAffected()
if err != nil {
return true, err
}
log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("DeleteUpTo: %d rows affected", nrows)
return nrows > 0, nil
}
// UpdateRead updates the "read" value for an event.
func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error) {
res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos)
if err != nil {
return false, err
}
nrows, err := res.RowsAffected()
if err != nil {
return true, err
}
log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("UpdateRead: %d rows affected", nrows)
return nrows > 0, nil
}
func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, fromID, uint32(filter), limit)
if err != nil {
return nil, 0, err
}
defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
var maxID int64 = -1
var notifs []*api.Notification
for rows.Next() {
var id int64
var roomID string
var ts gomatrixserverlib.Timestamp
var read bool
var jsonStr string
err = rows.Scan(
&id,
&roomID,
&ts,
&read,
&jsonStr)
if err != nil {
return nil, 0, err
}
var n api.Notification
err := json.Unmarshal([]byte(jsonStr), &n)
if err != nil {
return nil, 0, err
}
n.RoomID = roomID
n.TS = ts
n.Read = read
notifs = append(notifs, &n)
if maxID < id {
maxID = id
}
}
return notifs, maxID, rows.Err()
}
func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (int64, error) {
rows, err := sqlutil.TxStmt(txn, s.selectCountStmt).QueryContext(ctx, localpart, uint32(filter))
if err != nil {
return 0, err
}
defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
if rows.Next() {
var count int64
if err := rows.Scan(&count); err != nil {
return 0, err
}
return count, nil
}
return 0, rows.Err()
}
func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) {
rows, err := sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryContext(ctx, localpart, roomID)
if err != nil {
return 0, 0, err
}
defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
if rows.Next() {
var total, highlight int64
if err := rows.Scan(&total, &highlight); err != nil {
return 0, 0, err
}
return total, highlight, nil
}
return 0, 0, rows.Err()
}

View file

@ -0,0 +1,157 @@
// Copyright 2021 Dan Peleg <dan@globekeeper.com>
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"encoding/json"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
)
// See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers
const pushersSchema = `
CREATE TABLE IF NOT EXISTS userapi_pushers (
id BIGSERIAL PRIMARY KEY,
-- The Matrix user ID localpart for this pusher
localpart TEXT NOT NULL,
session_id BIGINT DEFAULT NULL,
profile_tag TEXT,
kind TEXT NOT NULL,
app_id TEXT NOT NULL,
app_display_name TEXT NOT NULL,
device_display_name TEXT NOT NULL,
pushkey TEXT NOT NULL,
pushkey_ts_ms BIGINT NOT NULL DEFAULT 0,
lang TEXT NOT NULL,
data TEXT NOT NULL
);
-- For faster deleting by app_id, pushkey pair.
CREATE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey);
-- For faster retrieving by localpart.
CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart);
-- Pushkey must be unique for a given user and app.
CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart);
`
const insertPusherSQL = "" +
"INSERT INTO userapi_pushers (localpart, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" +
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" +
"ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11"
const selectPushersSQL = "" +
"SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1"
const deletePusherSQL = "" +
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3"
const deletePushersByAppIdAndPushKeySQL = "" +
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2"
func NewPostgresPusherTable(db *sql.DB) (tables.PusherTable, error) {
s := &pushersStatements{}
_, err := db.Exec(pushersSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.insertPusherStmt, insertPusherSQL},
{&s.selectPushersStmt, selectPushersSQL},
{&s.deletePusherStmt, deletePusherSQL},
{&s.deletePushersByAppIdAndPushKeyStmt, deletePushersByAppIdAndPushKeySQL},
}.Prepare(db)
}
type pushersStatements struct {
insertPusherStmt *sql.Stmt
selectPushersStmt *sql.Stmt
deletePusherStmt *sql.Stmt
deletePushersByAppIdAndPushKeyStmt *sql.Stmt
}
// insertPusher creates a new pusher.
// Returns an error if the user already has a pusher with the given pusher pushkey.
// Returns nil error success.
func (s *pushersStatements) InsertPusher(
ctx context.Context, txn *sql.Tx, session_id int64,
pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string,
) error {
_, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
logrus.Debugf("Created pusher %d", session_id)
return err
}
func (s *pushersStatements) SelectPushers(
ctx context.Context, txn *sql.Tx, localpart string,
) ([]api.Pusher, error) {
pushers := []api.Pusher{}
rows, err := sqlutil.TxStmt(txn, s.selectPushersStmt).QueryContext(ctx, localpart)
if err != nil {
return pushers, err
}
defer internal.CloseAndLogIfError(ctx, rows, "SelectPushers: rows.close() failed")
for rows.Next() {
var pusher api.Pusher
var data []byte
err = rows.Scan(
&pusher.SessionID,
&pusher.PushKey,
&pusher.PushKeyTS,
&pusher.Kind,
&pusher.AppID,
&pusher.AppDisplayName,
&pusher.DeviceDisplayName,
&pusher.ProfileTag,
&pusher.Language,
&data)
if err != nil {
return pushers, err
}
err := json.Unmarshal(data, &pusher.Data)
if err != nil {
return pushers, err
}
pushers = append(pushers, pusher)
}
logrus.Debugf("Database returned %d pushers", len(pushers))
return pushers, rows.Err()
}
// deletePusher removes a single pusher by pushkey and user localpart.
func (s *pushersStatements) DeletePusher(
ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string,
) error {
_, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart)
return err
}
func (s *pushersStatements) DeletePushers(
ctx context.Context, txn *sql.Tx, appid, pushkey string,
) error {
_, err := sqlutil.TxStmt(txn, s.deletePushersByAppIdAndPushKeyStmt).ExecContext(ctx, appid, pushkey)
return err
}

View file

@ -85,6 +85,14 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err != nil { if err != nil {
return nil, fmt.Errorf("NewPostgresThreePIDTable: %w", err) return nil, fmt.Errorf("NewPostgresThreePIDTable: %w", err)
} }
pusherTable, err := NewPostgresPusherTable(db)
if err != nil {
return nil, fmt.Errorf("NewPostgresPusherTable: %w", err)
}
notificationsTable, err := NewPostgresNotificationTable(db)
if err != nil {
return nil, fmt.Errorf("NewPostgresNotificationTable: %w", err)
}
return &shared.Database{ return &shared.Database{
AccountDatas: accountDataTable, AccountDatas: accountDataTable,
Accounts: accountsTable, Accounts: accountsTable,
@ -95,6 +103,8 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
OpenIDTokens: openIDTable, OpenIDTokens: openIDTable,
Profiles: profilesTable, Profiles: profilesTable,
ThreePIDs: threePIDTable, ThreePIDs: threePIDTable,
Pushers: pusherTable,
Notifications: notificationsTable,
ServerName: serverName, ServerName: serverName,
DB: db, DB: db,
Writer: sqlutil.NewDummyWriter(), Writer: sqlutil.NewDummyWriter(),

View file

@ -29,6 +29,7 @@ import (
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal/pushrules"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables" "github.com/matrix-org/dendrite/userapi/storage/tables"
@ -47,6 +48,8 @@ type Database struct {
KeyBackupVersions tables.KeyBackupVersionTable KeyBackupVersions tables.KeyBackupVersionTable
Devices tables.DevicesTable Devices tables.DevicesTable
LoginTokens tables.LoginTokenTable LoginTokens tables.LoginTokenTable
Notifications tables.NotificationTable
Pushers tables.PusherTable
LoginTokenLifetime time.Duration LoginTokenLifetime time.Duration
ServerName gomatrixserverlib.ServerName ServerName gomatrixserverlib.ServerName
BcryptCost int BcryptCost int
@ -160,15 +163,12 @@ func (d *Database) createAccount(
if err = d.Profiles.InsertProfile(ctx, txn, localpart); err != nil { if err = d.Profiles.InsertProfile(ctx, txn, localpart); err != nil {
return nil, err return nil, err
} }
if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{ pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, d.ServerName)
"global": { prbs, err := json.Marshal(pushRuleSets)
"content": [], if err != nil {
"override": [], return nil, err
"room": [],
"sender": [],
"underride": []
} }
}`)); err != nil { if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(prbs)); err != nil {
return nil, err return nil, err
} }
return account, nil return account, nil
@ -670,3 +670,94 @@ func (d *Database) RemoveLoginToken(ctx context.Context, token string) error {
func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
return d.LoginTokens.SelectLoginToken(ctx, token) return d.LoginTokens.SelectLoginToken(ctx, token)
} }
func (d *Database) InsertNotification(ctx context.Context, localpart, eventID string, pos int64, tweaks map[string]interface{}, n *api.Notification) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Notifications.Insert(ctx, txn, localpart, eventID, pos, pushrules.BoolTweakOr(tweaks, pushrules.HighlightTweak, false), n)
})
}
func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos int64) (affected bool, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
affected, err = d.Notifications.DeleteUpTo(ctx, txn, localpart, roomID, pos)
return err
})
return
}
func (d *Database) SetNotificationsRead(ctx context.Context, localpart, roomID string, pos int64, b bool) (affected bool, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
affected, err = d.Notifications.UpdateRead(ctx, txn, localpart, roomID, pos, b)
return err
})
return
}
func (d *Database) GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
return d.Notifications.Select(ctx, nil, localpart, fromID, limit, filter)
}
func (d *Database) GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error) {
return d.Notifications.SelectCount(ctx, nil, localpart, filter)
}
func (d *Database) GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error) {
return d.Notifications.SelectRoomCounts(ctx, nil, localpart, roomID)
}
func (d *Database) UpsertPusher(
ctx context.Context, p api.Pusher, localpart string,
) error {
data, err := json.Marshal(p.Data)
if err != nil {
return err
}
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Pushers.InsertPusher(
ctx, txn,
p.SessionID,
p.PushKey,
p.PushKeyTS,
p.Kind,
p.AppID,
p.AppDisplayName,
p.DeviceDisplayName,
p.ProfileTag,
p.Language,
string(data),
localpart)
})
}
// GetPushers returns the pushers matching the given localpart.
func (d *Database) GetPushers(
ctx context.Context, localpart string,
) ([]api.Pusher, error) {
return d.Pushers.SelectPushers(ctx, nil, localpart)
}
// RemovePusher deletes one pusher
// Invoked when `append` is true and `kind` is null in
// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-pushers-set
func (d *Database) RemovePusher(
ctx context.Context, appid, pushkey, localpart string,
) error {
return d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
err := d.Pushers.DeletePusher(ctx, txn, appid, pushkey, localpart)
if err == sql.ErrNoRows {
return nil
}
return err
})
}
// RemovePushers deletes all pushers that match given App Id and Push Key pair.
// Invoked when `append` parameter is false in
// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-pushers-set
func (d *Database) RemovePushers(
ctx context.Context, appid, pushkey string,
) error {
return d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
return d.Pushers.DeletePushers(ctx, txn, appid, pushkey)
})
}

View file

@ -0,0 +1,219 @@
// Copyright 2021 Dan Peleg <dan@globekeeper.com>
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"encoding/json"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus"
)
type notificationsStatements struct {
insertStmt *sql.Stmt
deleteUpToStmt *sql.Stmt
updateReadStmt *sql.Stmt
selectStmt *sql.Stmt
selectCountStmt *sql.Stmt
selectRoomCountsStmt *sql.Stmt
}
const notificationSchema = `
CREATE TABLE IF NOT EXISTS userapi_notifications (
id INTEGER PRIMARY KEY AUTOINCREMENT,
localpart TEXT NOT NULL,
room_id TEXT NOT NULL,
event_id TEXT NOT NULL,
stream_pos BIGINT NOT NULL,
ts_ms BIGINT NOT NULL,
highlight BOOLEAN NOT NULL,
notification_json TEXT NOT NULL,
read BOOLEAN NOT NULL DEFAULT FALSE
);
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, room_id, event_id);
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, room_id, id);
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, id);
`
const insertNotificationSQL = "" +
"INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7)"
const deleteNotificationsUpToSQL = "" +
"DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND stream_pos <= $3"
const updateNotificationReadSQL = "" +
"UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND stream_pos <= $4 AND read <> $1"
const selectNotificationSQL = "" +
"SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" +
"(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" +
") AND NOT read ORDER BY localpart, id LIMIT $4"
const selectNotificationCountSQL = "" +
"SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND (" +
"(($2 & 1) <> 0 AND highlight) OR (($2 & 2) <> 0 AND NOT highlight)" +
") AND NOT read"
const selectRoomNotificationCountsSQL = "" +
"SELECT COUNT(*), COUNT(*) FILTER (WHERE highlight) FROM userapi_notifications " +
"WHERE localpart = $1 AND room_id = $2 AND NOT read"
func NewSQLiteNotificationTable(db *sql.DB) (tables.NotificationTable, error) {
s := &notificationsStatements{}
_, err := db.Exec(notificationSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.insertStmt, insertNotificationSQL},
{&s.deleteUpToStmt, deleteNotificationsUpToSQL},
{&s.updateReadStmt, updateNotificationReadSQL},
{&s.selectStmt, selectNotificationSQL},
{&s.selectCountStmt, selectNotificationCountSQL},
{&s.selectRoomCountsStmt, selectRoomNotificationCountsSQL},
}.Prepare(db)
}
// Insert inserts a notification into the database.
func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error {
roomID, tsMS := n.RoomID, n.TS
nn := *n
// Clears out fields that have their own columns to (1) shrink the
// data and (2) avoid difficult-to-debug inconsistency bugs.
nn.RoomID = ""
nn.TS, nn.Read = 0, false
bs, err := json.Marshal(nn)
if err != nil {
return err
}
_, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, pos, tsMS, highlight, string(bs))
return err
}
// DeleteUpTo deletes all previous notifications, up to and including the event.
func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error) {
res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos)
if err != nil {
return false, err
}
nrows, err := res.RowsAffected()
if err != nil {
return true, err
}
log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("DeleteUpTo: %d rows affected", nrows)
return nrows > 0, nil
}
// UpdateRead updates the "read" value for an event.
func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error) {
res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos)
if err != nil {
return false, err
}
nrows, err := res.RowsAffected()
if err != nil {
return true, err
}
log.WithFields(log.Fields{"localpart": localpart, "room_id": roomID, "stream_pos": pos}).Tracef("UpdateRead: %d rows affected", nrows)
return nrows > 0, nil
}
func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, fromID, uint32(filter), limit)
if err != nil {
return nil, 0, err
}
defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
var maxID int64 = -1
var notifs []*api.Notification
for rows.Next() {
var id int64
var roomID string
var ts gomatrixserverlib.Timestamp
var read bool
var jsonStr string
err = rows.Scan(
&id,
&roomID,
&ts,
&read,
&jsonStr)
if err != nil {
return nil, 0, err
}
var n api.Notification
err := json.Unmarshal([]byte(jsonStr), &n)
if err != nil {
return nil, 0, err
}
n.RoomID = roomID
n.TS = ts
n.Read = read
notifs = append(notifs, &n)
if maxID < id {
maxID = id
}
}
return notifs, maxID, rows.Err()
}
func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (int64, error) {
rows, err := sqlutil.TxStmt(txn, s.selectCountStmt).QueryContext(ctx, localpart, uint32(filter))
if err != nil {
return 0, err
}
defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
if rows.Next() {
var count int64
if err := rows.Scan(&count); err != nil {
return 0, err
}
return count, nil
}
return 0, rows.Err()
}
func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) {
rows, err := sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryContext(ctx, localpart, roomID)
if err != nil {
return 0, 0, err
}
defer internal.CloseAndLogIfError(ctx, rows, "notifications.Select: rows.Close() failed")
if rows.Next() {
var total, highlight int64
if err := rows.Scan(&total, &highlight); err != nil {
return 0, 0, err
}
return total, highlight, nil
}
return 0, 0, rows.Err()
}

View file

@ -0,0 +1,157 @@
// Copyright 2021 Dan Peleg <dan@globekeeper.com>
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"encoding/json"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
)
// See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers
const pushersSchema = `
CREATE TABLE IF NOT EXISTS userapi_pushers (
id INTEGER PRIMARY KEY AUTOINCREMENT,
-- The Matrix user ID localpart for this pusher
localpart TEXT NOT NULL,
session_id BIGINT DEFAULT NULL,
profile_tag TEXT,
kind TEXT NOT NULL,
app_id TEXT NOT NULL,
app_display_name TEXT NOT NULL,
device_display_name TEXT NOT NULL,
pushkey TEXT NOT NULL,
pushkey_ts_ms BIGINT NOT NULL DEFAULT 0,
lang TEXT NOT NULL,
data TEXT NOT NULL
);
-- For faster deleting by app_id, pushkey pair.
CREATE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey);
-- For faster retrieving by localpart.
CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart);
-- Pushkey must be unique for a given user and app.
CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart);
`
const insertPusherSQL = "" +
"INSERT INTO userapi_pushers (localpart, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" +
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" +
"ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11"
const selectPushersSQL = "" +
"SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1"
const deletePusherSQL = "" +
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3"
const deletePushersByAppIdAndPushKeySQL = "" +
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2"
func NewSQLitePusherTable(db *sql.DB) (tables.PusherTable, error) {
s := &pushersStatements{}
_, err := db.Exec(pushersSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.insertPusherStmt, insertPusherSQL},
{&s.selectPushersStmt, selectPushersSQL},
{&s.deletePusherStmt, deletePusherSQL},
{&s.deletePushersByAppIdAndPushKeyStmt, deletePushersByAppIdAndPushKeySQL},
}.Prepare(db)
}
type pushersStatements struct {
insertPusherStmt *sql.Stmt
selectPushersStmt *sql.Stmt
deletePusherStmt *sql.Stmt
deletePushersByAppIdAndPushKeyStmt *sql.Stmt
}
// insertPusher creates a new pusher.
// Returns an error if the user already has a pusher with the given pusher pushkey.
// Returns nil error success.
func (s *pushersStatements) InsertPusher(
ctx context.Context, txn *sql.Tx, session_id int64,
pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string,
) error {
_, err := s.insertPusherStmt.ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
logrus.Debugf("Created pusher %d", session_id)
return err
}
func (s *pushersStatements) SelectPushers(
ctx context.Context, txn *sql.Tx, localpart string,
) ([]api.Pusher, error) {
pushers := []api.Pusher{}
rows, err := s.selectPushersStmt.QueryContext(ctx, localpart)
if err != nil {
return pushers, err
}
defer internal.CloseAndLogIfError(ctx, rows, "SelectPushers: rows.close() failed")
for rows.Next() {
var pusher api.Pusher
var data []byte
err = rows.Scan(
&pusher.SessionID,
&pusher.PushKey,
&pusher.PushKeyTS,
&pusher.Kind,
&pusher.AppID,
&pusher.AppDisplayName,
&pusher.DeviceDisplayName,
&pusher.ProfileTag,
&pusher.Language,
&data)
if err != nil {
return pushers, err
}
err := json.Unmarshal(data, &pusher.Data)
if err != nil {
return pushers, err
}
pushers = append(pushers, pusher)
}
logrus.Debugf("Database returned %d pushers", len(pushers))
return pushers, rows.Err()
}
// deletePusher removes a single pusher by pushkey and user localpart.
func (s *pushersStatements) DeletePusher(
ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string,
) error {
_, err := s.deletePusherStmt.ExecContext(ctx, appid, pushkey, localpart)
return err
}
func (s *pushersStatements) DeletePushers(
ctx context.Context, txn *sql.Tx, appid, pushkey string,
) error {
_, err := s.deletePushersByAppIdAndPushKeyStmt.ExecContext(ctx, appid, pushkey)
return err
}

View file

@ -86,6 +86,14 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err != nil { if err != nil {
return nil, fmt.Errorf("NewSQLiteThreePIDTable: %w", err) return nil, fmt.Errorf("NewSQLiteThreePIDTable: %w", err)
} }
pusherTable, err := NewSQLitePusherTable(db)
if err != nil {
return nil, fmt.Errorf("NewPostgresPusherTable: %w", err)
}
notificationsTable, err := NewSQLiteNotificationTable(db)
if err != nil {
return nil, fmt.Errorf("NewPostgresNotificationTable: %w", err)
}
return &shared.Database{ return &shared.Database{
AccountDatas: accountDataTable, AccountDatas: accountDataTable,
Accounts: accountsTable, Accounts: accountsTable,
@ -96,6 +104,8 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
OpenIDTokens: openIDTable, OpenIDTokens: openIDTable,
Profiles: profilesTable, Profiles: profilesTable,
ThreePIDs: threePIDTable, ThreePIDs: threePIDTable,
Pushers: pusherTable,
Notifications: notificationsTable,
ServerName: serverName, ServerName: serverName,
DB: db, DB: db,
Writer: sqlutil.NewExclusiveWriter(), Writer: sqlutil.NewExclusiveWriter(),

View file

@ -21,6 +21,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
) )
type AccountDataTable interface { type AccountDataTable interface {
@ -93,3 +94,42 @@ type ThreePIDTable interface {
InsertThreePID(ctx context.Context, txn *sql.Tx, threepid, medium, localpart string) (err error) InsertThreePID(ctx context.Context, txn *sql.Tx, threepid, medium, localpart string) (err error)
DeleteThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error) DeleteThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error)
} }
type PusherTable interface {
InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS gomatrixserverlib.Timestamp, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string) error
SelectPushers(ctx context.Context, txn *sql.Tx, localpart string) ([]api.Pusher, error)
DeletePusher(ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string) error
DeletePushers(ctx context.Context, txn *sql.Tx, appid, pushkey string) error
}
type NotificationTable interface {
Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos int64, highlight bool, n *api.Notification) error
DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64) (affected bool, _ error)
UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos int64, v bool) (affected bool, _ error)
Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter NotificationFilter) ([]*api.Notification, int64, error)
SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter NotificationFilter) (int64, error)
SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error)
}
type NotificationFilter uint32
const (
// HighlightNotifications returns notifications that had a
// "highlight" tweak assigned to them from evaluating push rules.
HighlightNotifications NotificationFilter = 1 << iota
// NonHighlightNotifications returns notifications that don't
// match HighlightNotifications.
NonHighlightNotifications
// NoNotifications is a filter to exclude all types of
// notifications. It's useful as a zero value, but isn't likely to
// be used in a call to Notifications.Select*.
NoNotifications NotificationFilter = 0
// AllNotifications is a filter to include all types of
// notifications in Notifications.Select*. Note that PostgreSQL
// balks if this doesn't fit in INTEGER, even though we use
// uint32.
AllNotifications NotificationFilter = (1 << 31) - 1
)

View file

@ -18,11 +18,17 @@ import (
"time" "time"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/matrix-org/dendrite/internal/pushgateway"
keyapi "github.com/matrix-org/dendrite/keyserver/api" keyapi "github.com/matrix-org/dendrite/keyserver/api"
rsapi "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/consumers"
"github.com/matrix-org/dendrite/userapi/internal" "github.com/matrix-org/dendrite/userapi/internal"
"github.com/matrix-org/dendrite/userapi/inthttp" "github.com/matrix-org/dendrite/userapi/inthttp"
"github.com/matrix-org/dendrite/userapi/producers"
"github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/dendrite/userapi/storage"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@ -36,26 +42,49 @@ func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) {
// NewInternalAPI returns a concerete implementation of the internal API. Callers // NewInternalAPI returns a concerete implementation of the internal API. Callers
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
func NewInternalAPI( func NewInternalAPI(
accountDB storage.Database, cfg *config.UserAPI, appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI, base *base.BaseDendrite, db storage.Database, cfg *config.UserAPI,
appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI,
rsAPI rsapi.RoomserverInternalAPI, pgClient pushgateway.Client,
) api.UserInternalAPI { ) api.UserInternalAPI {
db, err := storage.NewDatabase(&cfg.AccountDatabase, cfg.Matrix.ServerName, cfg.BCryptCost, int64(api.DefaultLoginTokenLifetime*time.Millisecond), api.DefaultLoginTokenLifetime) db, err := storage.NewDatabase(&cfg.AccountDatabase, cfg.Matrix.ServerName, cfg.BCryptCost, int64(api.DefaultLoginTokenLifetime*time.Millisecond), api.DefaultLoginTokenLifetime)
if err != nil { if err != nil {
logrus.WithError(err).Panicf("failed to connect to device db") logrus.WithError(err).Panicf("failed to connect to device db")
} }
return newInternalAPI(db, cfg, appServices, keyAPI) js := jetstream.Prepare(&cfg.Matrix.JetStream)
}
func newInternalAPI( syncProducer := producers.NewSyncAPI(
db storage.Database, db, js,
cfg *config.UserAPI, // TODO: user API should handle syncs for account data. Right now,
appServices []config.ApplicationService, // it's handled by clientapi, and hence uses its topic. When user
keyAPI keyapi.KeyInternalAPI, // API handles it for all account data, we can remove it from
) api.UserInternalAPI { // here.
return &internal.UserInternalAPI{ cfg.Matrix.JetStream.TopicFor(jetstream.OutputClientData),
cfg.Matrix.JetStream.TopicFor(jetstream.OutputNotificationData),
)
userAPI := &internal.UserInternalAPI{
DB: db, DB: db,
SyncProducer: syncProducer,
ServerName: cfg.Matrix.ServerName, ServerName: cfg.Matrix.ServerName,
AppServices: appServices, AppServices: appServices,
KeyAPI: keyAPI, KeyAPI: keyAPI,
DisableTLSValidation: cfg.PushGatewayDisableTLSValidation,
} }
readConsumer := consumers.NewOutputReadUpdateConsumer(
base.ProcessContext, cfg, js, db, pgClient, userAPI, syncProducer,
)
if err := readConsumer.Start(); err != nil {
logrus.WithError(err).Panic("failed to start user API read update consumer")
}
eventConsumer := consumers.NewOutputStreamEventConsumer(
base.ProcessContext, cfg, js, db, pgClient, userAPI, rsAPI, syncProducer,
)
if err := eventConsumer.Start(); err != nil {
logrus.WithError(err).Panic("failed to start user API streamed event consumer")
}
return userAPI
} }

View file

@ -30,6 +30,7 @@ import (
"github.com/matrix-org/dendrite/internal/test" "github.com/matrix-org/dendrite/internal/test"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/internal"
"github.com/matrix-org/dendrite/userapi/inthttp" "github.com/matrix-org/dendrite/userapi/inthttp"
"github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/dendrite/userapi/storage"
) )
@ -62,7 +63,10 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts) (api.UserInternalAPI, s
}, },
} }
return newInternalAPI(accountDB, cfg, nil, nil), accountDB return &internal.UserInternalAPI{
DB: accountDB,
ServerName: cfg.Matrix.ServerName,
}, accountDB
} }
func TestQueryProfile(t *testing.T) { func TestQueryProfile(t *testing.T) {

100
userapi/util/devices.go Normal file
View file

@ -0,0 +1,100 @@
package util
import (
"context"
"github.com/matrix-org/dendrite/internal/pushgateway"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage"
log "github.com/sirupsen/logrus"
)
type PusherDevice struct {
Device pushgateway.Device
Pusher *api.Pusher
URL string
Format string
}
// GetPushDevices pushes to the configured devices of a local user.
func GetPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}, db storage.Database) ([]*PusherDevice, error) {
pushers, err := db.GetPushers(ctx, localpart)
if err != nil {
return nil, err
}
devices := make([]*PusherDevice, 0, len(pushers))
for _, pusher := range pushers {
var url, format string
data := pusher.Data
switch pusher.Kind {
case api.EmailKind:
url = "mailto:"
case api.HTTPKind:
// TODO: The spec says only event_id_only is supported,
// but Sytests assume "" means "full notification".
fmtIface := pusher.Data["format"]
var ok bool
format, ok = fmtIface.(string)
if ok && format != "event_id_only" {
log.WithFields(log.Fields{
"localpart": localpart,
"app_id": pusher.AppID,
}).Errorf("Only data.format event_id_only or empty is supported")
continue
}
urlIface := pusher.Data["url"]
url, ok = urlIface.(string)
if !ok {
log.WithFields(log.Fields{
"localpart": localpart,
"app_id": pusher.AppID,
}).Errorf("No data.url configured for HTTP Pusher")
continue
}
data = mapWithout(data, "url")
default:
log.WithFields(log.Fields{
"localpart": localpart,
"app_id": pusher.AppID,
"kind": pusher.Kind,
}).Errorf("Unhandled pusher kind")
continue
}
devices = append(devices, &PusherDevice{
Device: pushgateway.Device{
AppID: pusher.AppID,
Data: data,
PushKey: pusher.PushKey,
PushKeyTS: pusher.PushKeyTS,
Tweaks: tweaks,
},
Pusher: &pusher,
URL: url,
Format: format,
})
}
return devices, nil
}
// mapWithout returns a shallow copy of the map, without the given
// key. Returns nil if the resulting map is empty.
func mapWithout(m map[string]interface{}, key string) map[string]interface{} {
ret := make(map[string]interface{}, len(m))
for k, v := range m {
// The specification says we do not send "url".
if k == key {
continue
}
ret[k] = v
}
if len(ret) == 0 {
return nil
}
return ret
}

76
userapi/util/notify.go Normal file
View file

@ -0,0 +1,76 @@
package util
import (
"context"
"strings"
"time"
"github.com/matrix-org/dendrite/internal/pushgateway"
"github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/dendrite/userapi/storage/tables"
log "github.com/sirupsen/logrus"
)
// NotifyUserCountsAsync sends notifications to a local user's
// notification destinations. Database lookups run synchronously, but
// a single goroutine is started when talking to the Push
// gateways. There is no way to know when the background goroutine has
// finished.
func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, db storage.Database) error {
pusherDevices, err := GetPushDevices(ctx, localpart, nil, db)
if err != nil {
return err
}
if len(pusherDevices) == 0 {
return nil
}
userNumUnreadNotifs, err := db.GetNotificationCount(ctx, localpart, tables.AllNotifications)
if err != nil {
return err
}
log.WithFields(log.Fields{
"localpart": localpart,
"app_id0": pusherDevices[0].Device.AppID,
"pushkey": pusherDevices[0].Device.PushKey,
}).Tracef("Notifying HTTP push gateway about notification counts")
// TODO: think about bounding this to one per user, and what
// ordering guarantees we must provide.
go func() {
// This background processing cannot be tied to a request.
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
// TODO: we could batch all devices with the same URL, but
// Sytest requires consumers/roomserver.go to do it
// one-by-one, so we do the same here.
for _, pusherDevice := range pusherDevices {
// TODO: support "email".
if !strings.HasPrefix(pusherDevice.URL, "http") {
continue
}
req := pushgateway.NotifyRequest{
Notification: pushgateway.Notification{
Counts: &pushgateway.Counts{
Unread: int(userNumUnreadNotifs),
},
Devices: []*pushgateway.Device{&pusherDevice.Device},
},
}
if err := pgClient.Notify(ctx, pusherDevice.URL, &req, &pushgateway.NotifyResponse{}); err != nil {
log.WithFields(log.Fields{
"localpart": localpart,
"app_id0": pusherDevice.Device.AppID,
"pushkey": pusherDevice.Device.PushKey,
}).WithError(err).Error("HTTP push gateway request failed")
return
}
}
}()
return nil
}