PDU Sender split (#3100)

Initial cut of splitting PDU Sender into SenderID & looking up UserID where required.
This commit is contained in:
devonh 2023-06-06 20:55:18 +00:00 committed by GitHub
parent 725ff5567d
commit 7a1fd7f512
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
66 changed files with 580 additions and 189 deletions

View file

@ -108,7 +108,7 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Ms
}
if s.cfg.Matrix.ReportStats.Enabled {
go s.storeMessageStats(ctx, event.Type(), event.Sender(), event.RoomID())
go s.storeMessageStats(ctx, event.Type(), event.SenderID(), event.RoomID())
}
log.WithFields(log.Fields{
@ -301,7 +301,12 @@ func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *rst
switch {
case event.Type() == spec.MRoomMember:
cevent := synctypes.ToClientEvent(event, synctypes.FormatAll)
sender := spec.UserID{}
userID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
if queryErr == nil && userID != nil {
sender = *userID
}
cevent := synctypes.ToClientEvent(event, synctypes.FormatAll, sender)
var member *localMembership
member, err = newLocalMembership(&cevent)
if err != nil {
@ -529,12 +534,17 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype
return fmt.Errorf("s.localPushDevices: %w", err)
}
sender := spec.UserID{}
userID, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
if err == nil && userID != nil {
sender = *userID
}
n := &api.Notification{
Actions: actions,
// UNSPEC: the spec doesn't say this is a ClientEvent, but the
// fields seem to match. room_id should be missing, which
// matches the behaviour of FormatSync.
Event: synctypes.ToClientEvent(event, synctypes.FormatSync),
Event: synctypes.ToClientEvent(event, synctypes.FormatSync, sender),
// TODO: this is per-device, but it's not part of the primary
// key. So inserting one notification per profile tag doesn't
// make sense. What is this supposed to be? Sytests require it
@ -615,7 +625,12 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype
// evaluatePushRules fetches and evaluates the push rules of a local
// user. Returns actions (including dont_notify).
func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *rstypes.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) {
if event.Sender() == mem.UserID {
user := ""
sender, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
if err == nil {
user = sender.String()
}
if user == mem.UserID {
// SPEC: Homeservers MUST NOT notify the Push Gateway for
// events that the user has sent themselves.
return nil, nil
@ -632,9 +647,8 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *
if err != nil {
return nil, err
}
sender := event.Sender()
if _, ok := ignored.List[sender]; ok {
return nil, fmt.Errorf("user %s is ignored", sender)
if _, ok := ignored.List[sender.String()]; ok {
return nil, fmt.Errorf("user %s is ignored", sender.String())
}
}
ruleSets, err := s.db.QueryPushRules(ctx, mem.Localpart, mem.Domain)
@ -650,7 +664,9 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *
roomSize: roomSize,
}
eval := pushrules.NewRuleSetEvaluator(ec, &ruleSets.Global)
rule, err := eval.MatchEvent(event.PDU)
rule, err := eval.MatchEvent(event.PDU, func(roomID, senderID string) (*spec.UserID, error) {
return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
})
if err != nil {
return nil, err
}
@ -682,7 +698,7 @@ func (rse *ruleSetEvalContext) UserDisplayName() string { return rse.mem.Display
func (rse *ruleSetEvalContext) RoomMemberCount() (int, error) { return rse.roomSize, nil }
func (rse *ruleSetEvalContext) HasPowerLevel(userID, levelKey string) (bool, error) {
func (rse *ruleSetEvalContext) HasPowerLevel(senderID, levelKey string) (bool, error) {
req := &rsapi.QueryLatestEventsAndStateRequest{
RoomID: rse.roomID,
StateToFetch: []gomatrixserverlib.StateKeyTuple{
@ -702,7 +718,7 @@ func (rse *ruleSetEvalContext) HasPowerLevel(userID, levelKey string) (bool, err
if err != nil {
return false, err
}
return plc.UserLevel(userID) >= plc.NotificationLevel(levelKey), nil
return plc.UserLevel(senderID) >= plc.NotificationLevel(levelKey), nil
}
return true, nil
}
@ -756,6 +772,11 @@ func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *rstypes
}
default:
sender, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
if err != nil {
logger.WithError(err).Errorf("Failed to get userID for sender %s", event.SenderID())
return nil, err
}
req = pushgateway.NotifyRequest{
Notification: pushgateway.Notification{
Content: event.Content(),
@ -767,7 +788,7 @@ func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *rstypes
ID: event.EventID(),
RoomID: event.RoomID(),
RoomName: roomName,
Sender: event.Sender(),
Sender: sender.String(),
Type: event.Type(),
},
}

View file

@ -14,6 +14,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/matrix-org/dendrite/internal/pushrules"
rsapi "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/userapi/storage"
@ -44,13 +45,19 @@ func mustCreateEvent(t *testing.T, content string) *types.HeaderedEvent {
return &types.HeaderedEvent{PDU: ev}
}
type FakeUserRoomserverAPI struct{ rsapi.UserRoomserverAPI }
func (f *FakeUserRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) {
return spec.NewUserID(senderID, true)
}
func Test_evaluatePushRules(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType)
defer close()
consumer := OutputRoomEventConsumer{db: db}
consumer := OutputRoomEventConsumer{db: db, rsAPI: &FakeUserRoomserverAPI{}}
testCases := []struct {
name string
@ -86,7 +93,7 @@ func Test_evaluatePushRules(t *testing.T) {
},
{
name: "m.room.message highlights",
eventContent: `{"type":"m.room.message", "content": {"body": "test"} }`,
eventContent: `{"type":"m.room.message", "content": {"body": "test"}}`,
wantNotify: true,
wantAction: pushrules.NotifyAction,
wantActions: []*pushrules.Action{

View file

@ -11,6 +11,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/synctypes"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/util"
"golang.org/x/crypto/bcrypt"
@ -87,7 +88,7 @@ func TestNotifyUserCountsAsync(t *testing.T) {
}
// Prepare pusher with our test server URL
if err := db.UpsertPusher(ctx, api.Pusher{
if err = db.UpsertPusher(ctx, api.Pusher{
Kind: api.HTTPKind,
AppID: appID,
PushKey: pushKey,
@ -99,8 +100,12 @@ func TestNotifyUserCountsAsync(t *testing.T) {
}
// Insert a dummy event
sender, err := spec.NewUserID(alice.ID, true)
if err != nil {
t.Error(err)
}
if err := db.InsertNotification(ctx, aliceLocalpart, serverName, dummyEvent.EventID(), 0, nil, &api.Notification{
Event: synctypes.ToClientEvent(dummyEvent, synctypes.FormatAll),
Event: synctypes.ToClientEvent(dummyEvent, synctypes.FormatAll, *sender),
}); err != nil {
t.Error(err)
}