diff --git a/appservice/consumers/roomserver.go b/appservice/consumers/roomserver.go index c02d9040..06625ad7 100644 --- a/appservice/consumers/roomserver.go +++ b/appservice/consumers/roomserver.go @@ -181,7 +181,9 @@ func (s *OutputRoomEventConsumer) sendEvents( // Create the transaction body. transaction, err := json.Marshal( ApplicationServiceTransaction{ - Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatAll), + Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }), }, ) if err != nil { @@ -233,10 +235,16 @@ func (s *appserviceState) backoffAndPause(err error) error { // // TODO: This should be cached, see https://github.com/matrix-org/dendrite/issues/1682 func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Context, event *types.HeaderedEvent, appservice *config.ApplicationService) bool { + user := "" + userID, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) + if err == nil { + user = userID.String() + } + switch { case appservice.URL == "": return false - case appservice.IsInterestedInUserID(event.Sender()): + case appservice.IsInterestedInUserID(user): return true case appservice.IsInterestedInRoomID(event.RoomID()): return true diff --git a/clientapi/routing/directory.go b/clientapi/routing/directory.go index c786f8cc..0c842e6a 100644 --- a/clientapi/routing/directory.go +++ b/clientapi/routing/directory.go @@ -215,9 +215,35 @@ func RemoveLocalAlias( alias string, rsAPI roomserverAPI.ClientRoomserverAPI, ) util.JSONResponse { + userID, err := spec.NewUserID(device.UserID, true) + if err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{Err: "UserID for device is invalid"}, + } + } + + roomIDReq := roomserverAPI.GetRoomIDForAliasRequest{Alias: alias} + roomIDRes := roomserverAPI.GetRoomIDForAliasResponse{} + err = rsAPI.GetRoomIDForAlias(req.Context(), &roomIDReq, &roomIDRes) + if err != nil { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: spec.NotFound("The alias does not exist."), + } + } + + deviceSenderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomIDRes.RoomID, *userID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{Err: "Could not find SenderID for this device"}, + } + } + queryReq := roomserverAPI.RemoveRoomAliasRequest{ - Alias: alias, - UserID: device.UserID, + Alias: alias, + SenderID: deviceSenderID, } var queryRes roomserverAPI.RemoveRoomAliasResponse if err := rsAPI.RemoveRoomAlias(req.Context(), &queryReq, &queryRes); err != nil { diff --git a/clientapi/routing/redaction.go b/clientapi/routing/redaction.go index 88312642..e94c7748 100644 --- a/clientapi/routing/redaction.go +++ b/clientapi/routing/redaction.go @@ -76,7 +76,7 @@ func SendRedaction( // "Users may redact their own events, and any user with a power level greater than or equal // to the redact power level of the room may redact events there" // https://matrix.org/docs/spec/client_server/r0.6.1#put-matrix-client-r0-rooms-roomid-redact-eventid-txnid - allowedToRedact := ev.Sender() == device.UserID + allowedToRedact := ev.SenderID() == device.UserID // TODO: Should replace device.UserID with device...PerRoomKey if !allowedToRedact { plEvent := roomserverAPI.GetStateEvent(req.Context(), rsAPI, roomID, gomatrixserverlib.StateKeyTuple{ EventType: spec.MRoomPowerLevels, diff --git a/clientapi/routing/sendevent.go b/clientapi/routing/sendevent.go index 1a2e25c9..8b09f399 100644 --- a/clientapi/routing/sendevent.go +++ b/clientapi/routing/sendevent.go @@ -331,7 +331,9 @@ func generateSendEvent( stateEvents[i] = queryRes.StateEvents[i].PDU } provider := gomatrixserverlib.NewAuthEvents(gomatrixserverlib.ToPDUs(stateEvents)) - if err = gomatrixserverlib.Allowed(e.PDU, &provider); err != nil { + if err = gomatrixserverlib.Allowed(e.PDU, &provider, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }); err != nil { return nil, &util.JSONResponse{ Code: http.StatusForbidden, JSON: spec.Forbidden(err.Error()), // TODO: Is this error string comprehensible to the client? diff --git a/clientapi/routing/state.go b/clientapi/routing/state.go index 319f4eba..13f30899 100644 --- a/clientapi/routing/state.go +++ b/clientapi/routing/state.go @@ -140,9 +140,14 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a // use the result of the previous QueryLatestEventsAndState response // to find the state event, if provided. for _, ev := range stateRes.StateEvents { + sender := spec.UserID{} + userID, err := rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), ev.SenderID()) + if err == nil && userID != nil { + sender = *userID + } stateEvents = append( stateEvents, - synctypes.ToClientEvent(ev, synctypes.FormatAll), + synctypes.ToClientEvent(ev, synctypes.FormatAll, sender), ) } } else { @@ -162,9 +167,14 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a } } for _, ev := range stateAfterRes.StateEvents { + sender := spec.UserID{} + userID, err := rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), ev.SenderID()) + if err == nil && userID != nil { + sender = *userID + } stateEvents = append( stateEvents, - synctypes.ToClientEvent(ev, synctypes.FormatAll), + synctypes.ToClientEvent(ev, synctypes.FormatAll, sender), ) } } @@ -334,8 +344,13 @@ func OnIncomingStateTypeRequest( } } + sender := spec.UserID{} + userID, err := rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) + if err == nil && userID != nil { + sender = *userID + } stateEvent := stateEventInStateResp{ - ClientEvent: synctypes.ToClientEvent(event, synctypes.FormatAll), + ClientEvent: synctypes.ToClientEvent(event, synctypes.FormatAll, sender), } var res interface{} diff --git a/cmd/resolve-state/main.go b/cmd/resolve-state/main.go index 3a4255ba..36040309 100644 --- a/cmd/resolve-state/main.go +++ b/cmd/resolve-state/main.go @@ -18,6 +18,7 @@ import ( "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/process" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) // This is a utility for inspecting state snapshots and running state resolution @@ -182,7 +183,9 @@ func main() { fmt.Println("Resolving state") var resolved Events resolved, err = gomatrixserverlib.ResolveConflicts( - gomatrixserverlib.RoomVersion(*roomVersion), events, authEvents, + gomatrixserverlib.RoomVersion(*roomVersion), events, authEvents, func(roomID, senderID string) (*spec.UserID, error) { + return roomserverDB.GetUserIDForSender(ctx, roomID, senderID) + }, ) if err != nil { panic(err) diff --git a/federationapi/federationapi_test.go b/federationapi/federationapi_test.go index beb648a4..a97bcdea 100644 --- a/federationapi/federationapi_test.go +++ b/federationapi/federationapi_test.go @@ -36,6 +36,10 @@ type fedRoomserverAPI struct { queryRoomsForUser func(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error } +func (f *fedRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { + return spec.NewUserID(senderID, true) +} + // PerformJoin will call this function func (f *fedRoomserverAPI) InputRoomEvents(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) { if f.inputRoomEvents == nil { diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go index ed800d03..2d59d0f9 100644 --- a/federationapi/internal/perform.go +++ b/federationapi/internal/perform.go @@ -156,15 +156,20 @@ func (r *FederationInternalAPI) performJoinUsingServer( } joinInput := gomatrixserverlib.PerformJoinInput{ - UserID: user, - RoomID: room, - ServerName: serverName, - Content: content, - Unsigned: unsigned, - PrivateKey: r.cfg.Matrix.PrivateKey, - KeyID: r.cfg.Matrix.KeyID, - KeyRing: r.keyRing, - EventProvider: federatedEventProvider(ctx, r.federation, r.keyRing, user.Domain(), serverName), + UserID: user, + RoomID: room, + ServerName: serverName, + Content: content, + Unsigned: unsigned, + PrivateKey: r.cfg.Matrix.PrivateKey, + KeyID: r.cfg.Matrix.KeyID, + KeyRing: r.keyRing, + EventProvider: federatedEventProvider(ctx, r.federation, r.keyRing, user.Domain(), serverName, func(roomID, senderID string) (*spec.UserID, error) { + return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }), + UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) { + return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }, } response, joinErr := gomatrixserverlib.PerformJoin(ctx, r, joinInput) @@ -358,8 +363,11 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer( // authenticate the state returned (check its auth events etc) // the equivalent of CheckSendJoinResponse() + userIDProvider := func(roomID, senderID string) (*spec.UserID, error) { + return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + } authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse( - ctx, &respPeek, respPeek.RoomVersion, r.keyRing, federatedEventProvider(ctx, r.federation, r.keyRing, r.cfg.Matrix.ServerName, serverName), + ctx, &respPeek, respPeek.RoomVersion, r.keyRing, federatedEventProvider(ctx, r.federation, r.keyRing, r.cfg.Matrix.ServerName, serverName, userIDProvider), userIDProvider, ) if err != nil { return fmt.Errorf("error checking state returned from peeking: %w", err) @@ -509,7 +517,7 @@ func (r *FederationInternalAPI) SendInvite( event gomatrixserverlib.PDU, strippedState []gomatrixserverlib.InviteStrippedState, ) (gomatrixserverlib.PDU, error) { - _, origin, err := r.cfg.Matrix.SplitLocalID('@', event.Sender()) + inviter, err := r.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) if err != nil { return nil, err } @@ -542,7 +550,7 @@ func (r *FederationInternalAPI) SendInvite( return nil, fmt.Errorf("gomatrixserverlib.NewInviteV2Request: %w", err) } - inviteRes, err := r.federation.SendInviteV2(ctx, origin, destination, inviteReq) + inviteRes, err := r.federation.SendInviteV2(ctx, inviter.Domain(), destination, inviteReq) if err != nil { return nil, fmt.Errorf("r.federation.SendInviteV2: failed to send invite: %w", err) } @@ -635,6 +643,7 @@ func checkEventsContainCreateEvent(events []gomatrixserverlib.PDU) error { func federatedEventProvider( ctx context.Context, federation fclient.FederationClient, keyRing gomatrixserverlib.JSONVerifier, origin, server spec.ServerName, + userIDForSender spec.UserIDForSender, ) gomatrixserverlib.EventProvider { // A list of events that we have retried, if they were not included in // the auth events supplied in the send_join. @@ -684,7 +693,7 @@ func federatedEventProvider( } // Check the signatures of the event. - if err := gomatrixserverlib.VerifyEventSignatures(ctx, ev, keyRing); err != nil { + if err := gomatrixserverlib.VerifyEventSignatures(ctx, ev, keyRing, userIDForSender); err != nil { return nil, fmt.Errorf("missingAuth VerifyEventSignatures: %w", err) } diff --git a/federationapi/routing/invite.go b/federationapi/routing/invite.go index 78a09d94..d792335b 100644 --- a/federationapi/routing/invite.go +++ b/federationapi/routing/invite.go @@ -95,6 +95,9 @@ func InviteV2( StateQuerier: rsAPI.StateQuerier(), InviteEvent: inviteReq.Event(), StrippedState: inviteReq.InviteRoomState(), + UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) + }, } event, jsonErr := handleInvite(httpReq.Context(), input, rsAPI) if jsonErr != nil { @@ -185,6 +188,9 @@ func InviteV1( StateQuerier: rsAPI.StateQuerier(), InviteEvent: event, StrippedState: strippedState, + UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) + }, } event, jsonErr := handleInvite(httpReq.Context(), input, rsAPI) if jsonErr != nil { diff --git a/federationapi/routing/join.go b/federationapi/routing/join.go index 2980c2af..9da05918 100644 --- a/federationapi/routing/join.go +++ b/federationapi/routing/join.go @@ -99,15 +99,18 @@ func MakeJoin( } input := gomatrixserverlib.HandleMakeJoinInput{ - Context: httpReq.Context(), - UserID: userID, - RoomID: roomID, - RoomVersion: roomVersion, - RemoteVersions: remoteVersions, - RequestOrigin: request.Origin(), - LocalServerName: cfg.Matrix.ServerName, - LocalServerInRoom: res.RoomExists && res.IsInRoom, - RoomQuerier: &roomQuerier, + Context: httpReq.Context(), + UserID: userID, + RoomID: roomID, + RoomVersion: roomVersion, + RemoteVersions: remoteVersions, + RequestOrigin: request.Origin(), + LocalServerName: cfg.Matrix.ServerName, + LocalServerInRoom: res.RoomExists && res.IsInRoom, + RoomQuerier: &roomQuerier, + UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) + }, BuildEventTemplate: createJoinTemplate, } response, internalErr := gomatrixserverlib.HandleMakeJoin(input) @@ -202,6 +205,9 @@ func SendJoin( PrivateKey: cfg.Matrix.PrivateKey, Verifier: keys, MembershipQuerier: &api.MembershipQuerier{Roomserver: rsAPI}, + UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) + }, } response, joinErr := gomatrixserverlib.HandleSendJoin(input) switch e := joinErr.(type) { diff --git a/federationapi/routing/leave.go b/federationapi/routing/leave.go index d7d5b599..30e99c4f 100644 --- a/federationapi/routing/leave.go +++ b/federationapi/routing/leave.go @@ -95,6 +95,9 @@ func MakeLeave( LocalServerName: cfg.Matrix.ServerName, LocalServerInRoom: res.RoomExists && res.IsInRoom, BuildEventTemplate: createLeaveTemplate, + UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID) + }, } response, internalErr := gomatrixserverlib.HandleMakeLeave(input) @@ -213,7 +216,7 @@ func SendLeave( JSON: spec.BadJSON("No state key was provided in the leave event."), } } - if !event.StateKeyEquals(event.Sender()) { + if !event.StateKeyEquals(event.SenderID()) { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: spec.BadJSON("Event state key must match the event sender."), @@ -223,13 +226,13 @@ func SendLeave( // Check that the sender belongs to the server that is sending us // the request. By this point we've already asserted that the sender // and the state key are equal so we don't need to check both. - var serverName spec.ServerName - if _, serverName, err = gomatrixserverlib.SplitID('@', event.Sender()); err != nil { + sender, err := rsAPI.QueryUserIDForSender(httpReq.Context(), event.RoomID(), event.SenderID()) + if err != nil { return util.JSONResponse{ Code: http.StatusForbidden, JSON: spec.Forbidden("The sender of the join is invalid"), } - } else if serverName != request.Origin() { + } else if sender.Domain() != request.Origin() { return util.JSONResponse{ Code: http.StatusForbidden, JSON: spec.Forbidden("The sender does not match the server that originated the request"), @@ -291,7 +294,7 @@ func SendLeave( } } verifyRequests := []gomatrixserverlib.VerifyJSONRequest{{ - ServerName: serverName, + ServerName: sender.Domain(), Message: redacted, AtTS: event.OriginServerTS(), ValidityCheckingFunc: gomatrixserverlib.StrictValiditySignatureCheck, diff --git a/go.mod b/go.mod index a49dfa0c..10551f70 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20230606112941-1c41e92ddf9e + github.com/matrix-org/gomatrixserverlib v0.0.0-20230606202811-a644d5d8fb66 github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 github.com/mattn/go-sqlite3 v1.14.16 @@ -34,7 +34,7 @@ require ( github.com/patrickmn/go-cache v2.1.0+incompatible github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.13.0 - github.com/sirupsen/logrus v1.9.2 + github.com/sirupsen/logrus v1.9.3 github.com/stretchr/testify v1.8.2 github.com/tidwall/gjson v1.14.4 github.com/tidwall/sjson v1.2.5 diff --git a/go.sum b/go.sum index 79154624..3ec1c115 100644 --- a/go.sum +++ b/go.sum @@ -323,8 +323,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230606112941-1c41e92ddf9e h1:I3Sfr8gZvVtLHOeI8lgc62kgLuzpMhBZ6EQOMyexXEA= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230606112941-1c41e92ddf9e/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230606202811-a644d5d8fb66 h1:6SixhMmB5Ir10xUJ6zh3A4NBxSaZCSz2s5U63Wg0eEU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230606202811-a644d5d8fb66/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A= github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ= github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y= @@ -444,8 +444,8 @@ github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPx github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= -github.com/sirupsen/logrus v1.9.2 h1:oxx1eChJGI6Uks2ZC4W1zpLlVgqB8ner4EuQwV4Ik1Y= -github.com/sirupsen/logrus v1.9.2/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/goconvey v0.0.0-20181108003508-044398e4856c/go.mod h1:XDJAKZRPZ1CvBcN2aX5YOUTYGHki24fSF0Iv48Ibg0s= github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= diff --git a/internal/pushrules/evaluate.go b/internal/pushrules/evaluate.go index 7c98efd3..da33d386 100644 --- a/internal/pushrules/evaluate.go +++ b/internal/pushrules/evaluate.go @@ -6,6 +6,7 @@ import ( "strings" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) // A RuleSetEvaluator encapsulates context to evaluate an event @@ -53,7 +54,7 @@ func NewRuleSetEvaluator(ec EvaluationContext, ruleSet *RuleSet) *RuleSetEvaluat // MatchEvent returns the first matching rule. Returns nil if there // was no match rule. -func (rse *RuleSetEvaluator) MatchEvent(event gomatrixserverlib.PDU) (*Rule, error) { +func (rse *RuleSetEvaluator) MatchEvent(event gomatrixserverlib.PDU, userIDForSender spec.UserIDForSender) (*Rule, error) { // TODO: server-default rules have lower priority than user rules, // but they are stored together with the user rules. It's a bit // unclear what the specification (11.14.1.4 Predefined rules) @@ -68,7 +69,7 @@ func (rse *RuleSetEvaluator) MatchEvent(event gomatrixserverlib.PDU) (*Rule, err if rule.Default != defRules { continue } - ok, err := ruleMatches(rule, rsat.Kind, event, rse.ec) + ok, err := ruleMatches(rule, rsat.Kind, event, rse.ec, userIDForSender) if err != nil { return nil, err } @@ -83,7 +84,7 @@ func (rse *RuleSetEvaluator) MatchEvent(event gomatrixserverlib.PDU) (*Rule, err return nil, nil } -func ruleMatches(rule *Rule, kind Kind, event gomatrixserverlib.PDU, ec EvaluationContext) (bool, error) { +func ruleMatches(rule *Rule, kind Kind, event gomatrixserverlib.PDU, ec EvaluationContext, userIDForSender spec.UserIDForSender) (bool, error) { if !rule.Enabled { return false, nil } @@ -113,7 +114,12 @@ func ruleMatches(rule *Rule, kind Kind, event gomatrixserverlib.PDU, ec Evaluati return rule.RuleID == event.RoomID(), nil case SenderKind: - return rule.RuleID == event.Sender(), nil + userID := "" + sender, err := userIDForSender(event.RoomID(), event.SenderID()) + if err == nil { + userID = sender.String() + } + return rule.RuleID == userID, nil default: return false, nil @@ -143,7 +149,7 @@ func conditionMatches(cond *Condition, event gomatrixserverlib.PDU, ec Evaluatio return cmp(n), nil case SenderNotificationPermissionCondition: - return ec.HasPowerLevel(event.Sender(), cond.Key) + return ec.HasPowerLevel(event.SenderID(), cond.Key) default: return false, nil diff --git a/internal/pushrules/evaluate_test.go b/internal/pushrules/evaluate_test.go index 5045a864..34c1436f 100644 --- a/internal/pushrules/evaluate_test.go +++ b/internal/pushrules/evaluate_test.go @@ -5,8 +5,13 @@ import ( "github.com/google/go-cmp/cmp" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) +func UserIDForSender(roomID string, senderID string) (*spec.UserID, error) { + return spec.NewUserID(senderID, true) +} + func TestRuleSetEvaluatorMatchEvent(t *testing.T) { ev := mustEventFromJSON(t, `{}`) defaultEnabled := &Rule{ @@ -45,7 +50,7 @@ func TestRuleSetEvaluatorMatchEvent(t *testing.T) { for _, tst := range tsts { t.Run(tst.Name, func(t *testing.T) { rse := NewRuleSetEvaluator(fakeEvaluationContext{3}, &tst.RuleSet) - got, err := rse.MatchEvent(tst.Event) + got, err := rse.MatchEvent(tst.Event, UserIDForSender) if err != nil { t.Fatalf("MatchEvent failed: %v", err) } @@ -82,15 +87,15 @@ func TestRuleMatches(t *testing.T) { {"contentMatch", ContentKind, Rule{Enabled: true, Pattern: pointer("b")}, `{"content":{"body":"abc"}}`, true}, {"contentNoMatch", ContentKind, Rule{Enabled: true, Pattern: pointer("d")}, `{"content":{"body":"abc"}}`, false}, - {"roomMatch", RoomKind, Rule{Enabled: true, RuleID: "!room@example.com"}, `{"room_id":"!room@example.com"}`, true}, - {"roomNoMatch", RoomKind, Rule{Enabled: true, RuleID: "!room@example.com"}, `{"room_id":"!otherroom@example.com"}`, false}, + {"roomMatch", RoomKind, Rule{Enabled: true, RuleID: "!room:example.com"}, `{"room_id":"!room:example.com"}`, true}, + {"roomNoMatch", RoomKind, Rule{Enabled: true, RuleID: "!room:example.com"}, `{"room_id":"!otherroom:example.com"}`, false}, - {"senderMatch", SenderKind, Rule{Enabled: true, RuleID: "@user@example.com"}, `{"sender":"@user@example.com"}`, true}, - {"senderNoMatch", SenderKind, Rule{Enabled: true, RuleID: "@user@example.com"}, `{"sender":"@otheruser@example.com"}`, false}, + {"senderMatch", SenderKind, Rule{Enabled: true, RuleID: "@user:example.com"}, `{"sender":"@user:example.com"}`, true}, + {"senderNoMatch", SenderKind, Rule{Enabled: true, RuleID: "@user:example.com"}, `{"sender":"@otheruser:example.com"}`, false}, } for _, tst := range tsts { t.Run(tst.Name, func(t *testing.T) { - got, err := ruleMatches(&tst.Rule, tst.Kind, mustEventFromJSON(t, tst.EventJSON), nil) + got, err := ruleMatches(&tst.Rule, tst.Kind, mustEventFromJSON(t, tst.EventJSON), nil, UserIDForSender) if err != nil { t.Fatalf("ruleMatches failed: %v", err) } diff --git a/internal/transactionrequest.go b/internal/transactionrequest.go index c9d321f2..0bbe0720 100644 --- a/internal/transactionrequest.go +++ b/internal/transactionrequest.go @@ -167,7 +167,9 @@ func (t *TxnReq) ProcessTransaction(ctx context.Context) (*fclient.RespSend, *ut } continue } - if err = gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys); err != nil { + if err = gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID, senderID string) (*spec.UserID, error) { + return t.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }); err != nil { util.GetLogger(ctx).WithError(err).Debugf("Transaction: Couldn't validate signature of event %q", event.EventID()) results[event.EventID()] = fclient.PDUResult{ Error: err.Error(), diff --git a/internal/transactionrequest_test.go b/internal/transactionrequest_test.go index fb30d410..6f3ce0b3 100644 --- a/internal/transactionrequest_test.go +++ b/internal/transactionrequest_test.go @@ -70,6 +70,10 @@ type FakeRsAPI struct { bannedFromRoom bool } +func (r *FakeRsAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { + return spec.NewUserID(senderID, true) +} + func (r *FakeRsAPI) QueryRoomVersionForRoom( ctx context.Context, roomID string, @@ -638,6 +642,10 @@ type testRoomserverAPI struct { queryLatestEventsAndState func(*rsAPI.QueryLatestEventsAndStateRequest) rsAPI.QueryLatestEventsAndStateResponse } +func (t *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { + return spec.NewUserID(senderID, true) +} + func (t *testRoomserverAPI) InputRoomEvents( ctx context.Context, request *rsAPI.InputRoomEventsRequest, diff --git a/roomserver/api/alias.go b/roomserver/api/alias.go index 37892a44..1b947540 100644 --- a/roomserver/api/alias.go +++ b/roomserver/api/alias.go @@ -62,7 +62,7 @@ type GetAliasesForRoomIDResponse struct { // RemoveRoomAliasRequest is a request to RemoveRoomAlias type RemoveRoomAliasRequest struct { // ID of the user removing the alias - UserID string `json:"user_id"` + SenderID string `json:"user_id"` // The room alias to remove Alias string `json:"alias"` } diff --git a/roomserver/api/api.go b/roomserver/api/api.go index a37ade3a..d61a0553 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -49,6 +49,7 @@ type RoomserverInternalAPI interface { ClientRoomserverAPI UserRoomserverAPI FederationRoomserverAPI + QuerySenderIDAPI // needed to avoid chicken and egg scenario when setting up the // interdependencies between the roomserver and other input APIs @@ -75,6 +76,11 @@ type InputRoomEventsAPI interface { ) } +type QuerySenderIDAPI interface { + QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (string, error) + QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) +} + // Query the latest events and state for a room from the room server. type QueryLatestEventsAndStateAPI interface { QueryLatestEventsAndState(ctx context.Context, req *QueryLatestEventsAndStateRequest, res *QueryLatestEventsAndStateResponse) error @@ -102,6 +108,7 @@ type QueryEventsAPI interface { type SyncRoomserverAPI interface { QueryLatestEventsAndStateAPI QueryBulkStateContentAPI + QuerySenderIDAPI // QuerySharedUsers returns a list of users who share at least 1 room in common with the given user. QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error // QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine @@ -142,6 +149,7 @@ type SyncRoomserverAPI interface { } type AppserviceRoomserverAPI interface { + QuerySenderIDAPI // QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine // which room to use by querying the first events roomID. QueryEventsByID( @@ -168,6 +176,7 @@ type ClientRoomserverAPI interface { QueryLatestEventsAndStateAPI QueryBulkStateContentAPI QueryEventsAPI + QuerySenderIDAPI QueryMembershipForUser(ctx context.Context, req *QueryMembershipForUserRequest, res *QueryMembershipForUserResponse) error QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error @@ -200,6 +209,7 @@ type ClientRoomserverAPI interface { } type UserRoomserverAPI interface { + QuerySenderIDAPI QueryLatestEventsAndStateAPI KeyserverRoomserverAPI QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error @@ -213,6 +223,8 @@ type FederationRoomserverAPI interface { InputRoomEventsAPI QueryLatestEventsAndStateAPI QueryBulkStateContentAPI + QuerySenderIDAPI + // QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs. QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error QueryMembershipForUser(ctx context.Context, req *QueryMembershipForUserRequest, res *QueryMembershipForUserResponse) error diff --git a/roomserver/api/query.go b/roomserver/api/query.go index e741c140..d79dcebb 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -491,10 +491,10 @@ type MembershipQuerier struct { Roomserver FederationRoomserverAPI } -func (mq *MembershipQuerier) CurrentMembership(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (string, error) { +func (mq *MembershipQuerier) CurrentMembership(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) { req := QueryMembershipForUserRequest{ RoomID: roomID.String(), - UserID: userID.String(), + UserID: string(senderID), } res := QueryMembershipForUserResponse{} err := mq.Roomserver.QueryMembershipForUser(ctx, &req, &res) diff --git a/roomserver/internal/alias.go b/roomserver/internal/alias.go index 52b90cf4..dcfb26b8 100644 --- a/roomserver/internal/alias.go +++ b/roomserver/internal/alias.go @@ -119,11 +119,6 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias( request *api.RemoveRoomAliasRequest, response *api.RemoveRoomAliasResponse, ) error { - _, virtualHost, err := r.Cfg.Global.SplitLocalID('@', request.UserID) - if err != nil { - return err - } - roomID, err := r.DB.GetRoomIDForAlias(ctx, request.Alias) if err != nil { return fmt.Errorf("r.DB.GetRoomIDForAlias: %w", err) @@ -134,13 +129,19 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias( return nil } + sender, err := r.QueryUserIDForSender(ctx, roomID, request.SenderID) + if err != nil { + return fmt.Errorf("r.QueryUserIDForSender: %w", err) + } + virtualHost := sender.Domain() + response.Found = true creatorID, err := r.DB.GetCreatorIDForAlias(ctx, request.Alias) if err != nil { return fmt.Errorf("r.DB.GetCreatorIDForAlias: %w", err) } - if creatorID != request.UserID { + if creatorID != request.SenderID { var plEvent *types.HeaderedEvent var pls *gomatrixserverlib.PowerLevelContent @@ -154,7 +155,7 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias( return fmt.Errorf("plEvent.PowerLevels: %w", err) } - if pls.UserLevel(request.UserID) < pls.EventLevel(spec.MRoomCanonicalAlias, true) { + if pls.UserLevel(request.SenderID) < pls.EventLevel(spec.MRoomCanonicalAlias, true) { response.Removed = false return nil } @@ -172,9 +173,9 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias( return err } - sender := request.UserID - if request.UserID != ev.Sender() { - sender = ev.Sender() + sender := request.SenderID + if request.SenderID != ev.SenderID() { + sender = ev.SenderID() } _, senderDomain, err := r.Cfg.Global.SplitLocalID('@', sender) diff --git a/roomserver/internal/helpers/auth.go b/roomserver/internal/helpers/auth.go index 7ec0892e..932ce615 100644 --- a/roomserver/internal/helpers/auth.go +++ b/roomserver/internal/helpers/auth.go @@ -76,7 +76,9 @@ func CheckForSoftFail( } // Check if the event is allowed. - if err = gomatrixserverlib.Allowed(event.PDU, &authEvents); err != nil { + if err = gomatrixserverlib.Allowed(event.PDU, &authEvents, func(roomID, senderID string) (*spec.UserID, error) { + return db.GetUserIDForSender(ctx, roomID, senderID) + }); err != nil { // return true, nil return true, err } diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 386083f6..764bdfe2 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -128,9 +128,13 @@ func (r *Inputer) processRoomEvent( if roomInfo == nil && !isCreateEvent { return fmt.Errorf("room %s does not exist for event %s", event.RoomID(), event.EventID()) } - _, senderDomain, err := gomatrixserverlib.SplitID('@', event.Sender()) + sender, err := r.DB.GetUserIDForSender(ctx, event.RoomID(), event.SenderID()) if err != nil { - return fmt.Errorf("event has invalid sender %q", input.Event.Sender()) + return fmt.Errorf("failed getting userID for sender %q. %w", event.SenderID(), err) + } + senderDomain := spec.ServerName("") + if sender != nil { + senderDomain = sender.Domain() } // If we already know about this outlier and it hasn't been rejected @@ -193,7 +197,9 @@ func (r *Inputer) processRoomEvent( serverRes.ServerNames = append(serverRes.ServerNames, input.Origin) delete(servers, input.Origin) } - if senderDomain != input.Origin && senderDomain != r.Cfg.Matrix.ServerName { + // Only perform this check if the sender mxid_mapping can be resolved. + // Don't fail processing the event if we have no mxid_maping. + if sender != nil && senderDomain != input.Origin && senderDomain != r.Cfg.Matrix.ServerName { serverRes.ServerNames = append(serverRes.ServerNames, senderDomain) delete(servers, senderDomain) } @@ -276,7 +282,9 @@ func (r *Inputer) processRoomEvent( // Check if the event is allowed by its auth events. If it isn't then // we consider the event to be "rejected" — it will still be persisted. - if err = gomatrixserverlib.Allowed(event, &authEvents); err != nil { + if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) + }); err != nil { isRejected = true rejectionErr = err logger.WithError(rejectionErr).Warnf("Event %s not allowed by auth events", event.EventID()) @@ -493,7 +501,7 @@ func (r *Inputer) processRoomEvent( func (r *Inputer) handleRemoteRoomUpgrade(ctx context.Context, event gomatrixserverlib.PDU) error { oldRoomID := event.RoomID() newRoomID := gjson.GetBytes(event.Content(), "replacement_room").Str - return r.DB.UpgradeRoom(ctx, oldRoomID, newRoomID, event.Sender()) + return r.DB.UpgradeRoom(ctx, oldRoomID, newRoomID, event.SenderID()) } // processStateBefore works out what the state is before the event and @@ -579,7 +587,9 @@ func (r *Inputer) processStateBefore( stateBeforeAuth := gomatrixserverlib.NewAuthEvents( gomatrixserverlib.ToPDUs(stateBeforeEvent), ) - if rejectionErr = gomatrixserverlib.Allowed(event, &stateBeforeAuth); rejectionErr != nil { + if rejectionErr = gomatrixserverlib.Allowed(event, &stateBeforeAuth, func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) + }); rejectionErr != nil { rejectionErr = fmt.Errorf("Allowed() failed for stateBeforeEvent: %w", rejectionErr) return } @@ -690,7 +700,9 @@ nextAuthEvent: // Check the signatures of the event. If this fails then we'll simply // skip it, because gomatrixserverlib.Allowed() will notice a problem // if a critical event is missing anyway. - if err := gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.FSAPI.KeyRing()); err != nil { + if err := gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.FSAPI.KeyRing(), func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) + }); err != nil { continue nextAuthEvent } @@ -706,7 +718,9 @@ nextAuthEvent: } // Check if the auth event should be rejected. - err := gomatrixserverlib.Allowed(authEvent, auth) + err := gomatrixserverlib.Allowed(authEvent, auth, func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) + }) if isRejected = err != nil; isRejected { logger.WithError(err).Warnf("Auth event %s rejected", authEvent.EventID()) } @@ -828,11 +842,13 @@ func (r *Inputer) kickGuests(ctx context.Context, event gomatrixserverlib.PDU, r continue } + // TODO: pseudoIDs: get userID for room using state key (which is now senderID) localpart, senderDomain, err := gomatrixserverlib.SplitID('@', *memberEvent.StateKey()) if err != nil { continue } + // TODO: pseudoIDs: query account by state key (which is now senderID) accountRes := &userAPI.QueryAccountByLocalpartResponse{} if err = r.UserAPI.QueryAccountByLocalpart(ctx, &userAPI.QueryAccountByLocalpartRequest{ Localpart: localpart, diff --git a/roomserver/internal/input/input_events_test.go b/roomserver/internal/input/input_events_test.go index 56803813..0ba7d19f 100644 --- a/roomserver/internal/input/input_events_test.go +++ b/roomserver/internal/input/input_events_test.go @@ -58,7 +58,7 @@ func Test_EventAuth(t *testing.T) { } // Finally check that the event is NOT allowed - if err := gomatrixserverlib.Allowed(ev.PDU, &allower); err == nil { + if err := gomatrixserverlib.Allowed(ev.PDU, &allower, func(roomID, senderID string) (*spec.UserID, error) { return spec.NewUserID(senderID, true) }); err == nil { t.Fatalf("event should not be allowed, but it was") } } diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go index 10486138..ac0670fc 100644 --- a/roomserver/internal/input/input_missing.go +++ b/roomserver/internal/input/input_missing.go @@ -473,14 +473,18 @@ func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion stateEventList = append(stateEventList, state.StateEvents...) } resolvedStateEvents, err := gomatrixserverlib.ResolveConflicts( - roomVersion, gomatrixserverlib.ToPDUs(stateEventList), gomatrixserverlib.ToPDUs(authEventList), + roomVersion, gomatrixserverlib.ToPDUs(stateEventList), gomatrixserverlib.ToPDUs(authEventList), func(roomID, senderID string) (*spec.UserID, error) { + return t.db.GetUserIDForSender(ctx, roomID, senderID) + }, ) if err != nil { return nil, err } // apply the current event retryAllowedState: - if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents); err != nil { + if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents, func(roomID, senderID string) (*spec.UserID, error) { + return t.db.GetUserIDForSender(ctx, roomID, senderID) + }); err != nil { switch missing := err.(type) { case gomatrixserverlib.MissingAuthEventError: h, err2 := t.lookupEvent(ctx, roomVersion, backwardsExtremity.RoomID(), missing.AuthEventID, true) @@ -565,7 +569,9 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e gomatrixserver // will be added and duplicates will be removed. missingEvents := make([]gomatrixserverlib.PDU, 0, len(missingResp.Events)) for _, ev := range missingResp.Events.UntrustedEvents(roomVersion) { - if err = gomatrixserverlib.VerifyEventSignatures(ctx, ev, t.keys); err != nil { + if err = gomatrixserverlib.VerifyEventSignatures(ctx, ev, t.keys, func(roomID, senderID string) (*spec.UserID, error) { + return t.db.GetUserIDForSender(ctx, roomID, senderID) + }); err != nil { continue } missingEvents = append(missingEvents, t.cacheAndReturn(ev)) @@ -654,7 +660,9 @@ func (t *missingStateReq) lookupMissingStateViaState( authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse(ctx, &fclient.RespState{ StateEvents: state.GetStateEvents(), AuthEvents: state.GetAuthEvents(), - }, roomVersion, t.keys, nil) + }, roomVersion, t.keys, nil, func(roomID, senderID string) (*spec.UserID, error) { + return t.db.GetUserIDForSender(ctx, roomID, senderID) + }) if err != nil { return nil, err } @@ -889,14 +897,16 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs t.log.WithField("missing_event_id", missingEventID).Warnf("Failed to get missing /event for event ID from %d server(s)", len(t.servers)) return nil, fmt.Errorf("wasn't able to find event via %d server(s)", len(t.servers)) } - if err := gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys); err != nil { + if err := gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID, senderID string) (*spec.UserID, error) { + return t.db.GetUserIDForSender(ctx, roomID, senderID) + }); err != nil { t.log.WithError(err).Warnf("Couldn't validate signature of event %q from /event", event.EventID()) return nil, verifySigError{event.EventID(), err} } return t.cacheAndReturn(event), nil } -func checkAllowedByState(e gomatrixserverlib.PDU, stateEvents []gomatrixserverlib.PDU) error { +func checkAllowedByState(e gomatrixserverlib.PDU, stateEvents []gomatrixserverlib.PDU, userIDForSender spec.UserIDForSender) error { authUsingState := gomatrixserverlib.NewAuthEvents(nil) for i := range stateEvents { err := authUsingState.AddEvent(stateEvents[i]) @@ -904,7 +914,7 @@ func checkAllowedByState(e gomatrixserverlib.PDU, stateEvents []gomatrixserverli return err } } - return gomatrixserverlib.Allowed(e, &authUsingState) + return gomatrixserverlib.Allowed(e, &authUsingState, userIDForSender) } func (t *missingStateReq) hadEvent(eventID string) { diff --git a/roomserver/internal/perform/perform_admin.go b/roomserver/internal/perform/perform_admin.go index 575525e2..ca736cb6 100644 --- a/roomserver/internal/perform/perform_admin.go +++ b/roomserver/internal/perform/perform_admin.go @@ -262,13 +262,17 @@ func (r *Admin) PerformAdminDownloadState( return fmt.Errorf("r.Inputer.FSAPI.LookupState (%q): %s", fwdExtremity, err) } for _, authEvent := range state.GetAuthEvents().UntrustedEvents(roomInfo.RoomVersion) { - if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing); err != nil { + if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing, func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) + }); err != nil { continue } authEventMap[authEvent.EventID()] = authEvent } for _, stateEvent := range state.GetStateEvents().UntrustedEvents(roomInfo.RoomVersion) { - if err = gomatrixserverlib.VerifyEventSignatures(ctx, stateEvent, r.Inputer.KeyRing); err != nil { + if err = gomatrixserverlib.VerifyEventSignatures(ctx, stateEvent, r.Inputer.KeyRing, func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) + }); err != nil { continue } stateEventMap[stateEvent.EventID()] = stateEvent diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go index fb579f03..0f743f4e 100644 --- a/roomserver/internal/perform/perform_backfill.go +++ b/roomserver/internal/perform/perform_backfill.go @@ -121,7 +121,9 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform // Specifically the test "Outbound federation can backfill events" events, err := gomatrixserverlib.RequestBackfill( ctx, req.VirtualHost, requester, - r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100, + r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100, func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) + }, ) // Only return an error if we really couldn't get any events. if err != nil && len(events) == 0 { @@ -210,7 +212,9 @@ func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gom continue } loader := gomatrixserverlib.NewEventsLoader(roomVer, r.KeyRing, backfillRequester, backfillRequester.ProvideEvents, false) - result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents) + result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents, func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) + }) if err != nil { logger.WithError(err).Warn("failed to load and verify event") continue @@ -484,8 +488,8 @@ FindSuccessor: // Store the server names in a temporary map to avoid duplicates. serverSet := make(map[spec.ServerName]bool) for _, event := range memberEvents { - if _, senderDomain, err := gomatrixserverlib.SplitID('@', event.Sender()); err == nil { - serverSet[senderDomain] = true + if sender, err := b.db.GetUserIDForSender(ctx, event.RoomID(), event.SenderID()); err == nil { + serverSet[sender.Domain()] = true } } var servers []spec.ServerName diff --git a/roomserver/internal/perform/perform_create_room.go b/roomserver/internal/perform/perform_create_room.go index 41194832..897bd3a0 100644 --- a/roomserver/internal/perform/perform_create_room.go +++ b/roomserver/internal/perform/perform_create_room.go @@ -308,7 +308,9 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo } } - if err = gomatrixserverlib.Allowed(ev, &authEvents); err != nil { + if err = gomatrixserverlib.Allowed(ev, &authEvents, func(roomID, senderID string) (*spec.UserID, error) { + return c.DB.GetUserIDForSender(ctx, roomID, senderID) + }); err != nil { util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.Allowed failed") return "", &util.JSONResponse{ Code: http.StatusInternalServerError, diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index 1930b5ac..e8e20ede 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -97,11 +97,12 @@ func (r *Inviter) ProcessInviteMembership( ) ([]api.OutputEvent, error) { var outputUpdates []api.OutputEvent var updater *shared.MembershipUpdater - _, domain, err := gomatrixserverlib.SplitID('@', *inviteEvent.StateKey()) + + userID, err := r.RSAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), *inviteEvent.StateKey()) if err != nil { return nil, api.ErrInvalidID{Err: fmt.Errorf("the user ID %s is invalid", *inviteEvent.StateKey())} } - isTargetLocal := r.Cfg.Matrix.IsLocalServerName(domain) + isTargetLocal := r.Cfg.Matrix.IsLocalServerName(userID.Domain()) if updater, err = r.DB.MembershipUpdater(ctx, inviteEvent.RoomID(), *inviteEvent.StateKey(), isTargetLocal, inviteEvent.Version()); err != nil { return nil, fmt.Errorf("r.DB.MembershipUpdater: %w", err) } @@ -125,9 +126,9 @@ func (r *Inviter) PerformInvite( ) error { event := req.Event - sender, err := spec.NewUserID(event.Sender(), true) + sender, err := r.DB.GetUserIDForSender(ctx, event.RoomID(), event.SenderID()) if err != nil { - return spec.InvalidParam("The user ID is invalid") + return spec.InvalidParam("The sender user ID is invalid") } if !r.Cfg.Matrix.IsLocalServerName(sender.Domain()) { return api.ErrInvalidID{Err: fmt.Errorf("the invite must be from a local user")} @@ -155,6 +156,9 @@ func (r *Inviter) PerformInvite( StrippedState: req.InviteRoomState, MembershipQuerier: &api.MembershipQuerier{Roomserver: r.RSAPI}, StateQuerier: &QueryState{r.DB}, + UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) + }, } inviteEvent, err := gomatrixserverlib.PerformInvite(ctx, input, r.FSAPI) if err != nil { diff --git a/roomserver/internal/perform/perform_upgrade.go b/roomserver/internal/perform/perform_upgrade.go index ff4a6a1d..8c0df1c4 100644 --- a/roomserver/internal/perform/perform_upgrade.go +++ b/roomserver/internal/perform/perform_upgrade.go @@ -176,7 +176,7 @@ func moveLocalAliases(ctx context.Context, } for _, alias := range aliasRes.Aliases { - removeAliasReq := api.RemoveRoomAliasRequest{UserID: userID, Alias: alias} + removeAliasReq := api.RemoveRoomAliasRequest{SenderID: userID, Alias: alias} removeAliasRes := api.RemoveRoomAliasResponse{} if err = URSAPI.RemoveRoomAlias(ctx, &removeAliasReq, &removeAliasRes); err != nil { return fmt.Errorf("Failed to remove old room alias: %w", err) @@ -484,7 +484,9 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, user } - if err = gomatrixserverlib.Allowed(event, &authEvents); err != nil { + if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID, senderID string) (*spec.UserID, error) { + return r.URSAPI.QueryUserIDForSender(ctx, roomID, senderID) + }); err != nil { return fmt.Errorf("Failed to auth new %q event: %w", builder.Type, err) } @@ -567,7 +569,9 @@ func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, user stateEvents[i] = queryRes.StateEvents[i].PDU } provider := gomatrixserverlib.NewAuthEvents(stateEvents) - if err = gomatrixserverlib.Allowed(headeredEvent.PDU, &provider); err != nil { + if err = gomatrixserverlib.Allowed(headeredEvent.PDU, &provider, func(roomID, senderID string) (*spec.UserID, error) { + return r.URSAPI.QueryUserIDForSender(ctx, roomID, senderID) + }); err != nil { return nil, api.ErrNotAllowed{Err: fmt.Errorf("failed to auth new %q event: %w", proto.Type, err)} // TODO: Is this error string comprehensible to the client? } diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 6d898e8a..707e95b2 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -159,7 +159,9 @@ func (r *Queryer) QueryStateAfterEvents( } stateEvents, err = gomatrixserverlib.ResolveConflicts( - info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), + info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) + }, ) if err != nil { return fmt.Errorf("state.ResolveConflictsAdhoc: %w", err) @@ -386,7 +388,12 @@ func (r *Queryer) QueryMembershipsForRoom( return fmt.Errorf("r.DB.Events: %w", err) } for _, event := range events { - clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll) + sender := spec.UserID{} + userID, queryErr := r.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) + if queryErr == nil && userID != nil { + sender = *userID + } + clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender) response.JoinEvents = append(response.JoinEvents, clientEvent) } return nil @@ -435,7 +442,12 @@ func (r *Queryer) QueryMembershipsForRoom( } for _, event := range events { - clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll) + sender := spec.UserID{} + userID, err := r.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) + if err == nil && userID != nil { + sender = *userID + } + clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender) response.JoinEvents = append(response.JoinEvents, clientEvent) } @@ -625,7 +637,9 @@ func (r *Queryer) QueryStateAndAuthChain( if request.ResolveState { stateEvents, err = gomatrixserverlib.ResolveConflicts( - info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), + info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) + }, ) if err != nil { return err @@ -960,3 +974,11 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.Ro return verImpl.CheckRestrictedJoin(ctx, r.Cfg.Global.ServerName, &api.JoinRoomQuerier{Roomserver: r}, roomID, userID) } + +func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (string, error) { + return r.DB.GetSenderIDForUser(ctx, roomID, userID) +} + +func (r *Queryer) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) +} diff --git a/roomserver/producers/roomevent.go b/roomserver/producers/roomevent.go index febe8ddf..165304d4 100644 --- a/roomserver/producers/roomevent.go +++ b/roomserver/producers/roomevent.go @@ -60,7 +60,7 @@ func (r *RoomEventProducer) ProduceRoomEvents(roomID string, updates []api.Outpu "adds_state": len(update.NewRoomEvent.AddsStateEventIDs), "removes_state": len(update.NewRoomEvent.RemovesStateEventIDs), "send_as_server": update.NewRoomEvent.SendAsServer, - "sender": update.NewRoomEvent.Event.Sender(), + "sender": update.NewRoomEvent.Event.SenderID(), }) if update.NewRoomEvent.Event.StateKey() != nil { logger = logger.WithField("state_key", *update.NewRoomEvent.Event.StateKey()) diff --git a/roomserver/state/state.go b/roomserver/state/state.go index f38d8f96..3131cbff 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -24,6 +24,7 @@ import ( "time" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/prometheus/client_golang/prometheus" @@ -43,6 +44,7 @@ type StateResolutionStorage interface { AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) Events(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error) + GetUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) } type StateResolution struct { @@ -945,7 +947,9 @@ func (v *StateResolution) resolveConflictsV1( } // Resolve the conflicts. - resolvedEvents := gomatrixserverlib.ResolveStateConflicts(conflictedEvents, authEvents) + resolvedEvents := gomatrixserverlib.ResolveStateConflicts(conflictedEvents, authEvents, func(roomID, senderID string) (*spec.UserID, error) { + return v.db.GetUserIDForSender(ctx, roomID, senderID) + }) // Map from the full events back to numeric state entries. for _, resolvedEvent := range resolvedEvents { @@ -1057,6 +1061,9 @@ func (v *StateResolution) resolveConflictsV2( conflictedEvents, nonConflictedEvents, authEvents, + func(roomID, senderID string) (*spec.UserID, error) { + return v.db.GetUserIDForSender(ctx, roomID, senderID) + }, ) }() diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 7d22df00..2d007bed 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -166,6 +166,10 @@ type Database interface { GetServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName spec.ServerName) (bool, error) // GetKnownUsers searches all users that userID knows about. GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) + // GetKnownUsers tries to obtain the current mxid for a given user. + GetUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) + // GetKnownUsers tries to obtain the current senderID for a given user. + GetSenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (string, error) // GetKnownRooms returns a list of all rooms we know about. GetKnownRooms(ctx context.Context) ([]string, error) // ForgetRoom sets a flag in the membership table, that the user wishes to forget a specific room @@ -211,6 +215,7 @@ type RoomDatabase interface { GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*types.HeaderedEvent, error) + GetUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) } type EventDatabase interface { diff --git a/roomserver/storage/shared/membership_updater.go b/roomserver/storage/shared/membership_updater.go index f9c889cb..105e61df 100644 --- a/roomserver/storage/shared/membership_updater.go +++ b/roomserver/storage/shared/membership_updater.go @@ -101,7 +101,7 @@ func (u *MembershipUpdater) Update(newMembership tables.MembershipState, event * var inserted bool // Did the query result in a membership change? var retired []string // Did we retire any updates in the process? return inserted, retired, u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { - senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.Sender()) + senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.SenderID()) if err != nil { return fmt.Errorf("u.d.AssignStateKeyNID: %w", err) } diff --git a/roomserver/storage/shared/room_updater.go b/roomserver/storage/shared/room_updater.go index 70672a33..73500138 100644 --- a/roomserver/storage/shared/room_updater.go +++ b/roomserver/storage/shared/room_updater.go @@ -6,6 +6,7 @@ import ( "fmt" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -250,3 +251,7 @@ func (u *RoomUpdater) MarkEventAsSent(eventNID types.EventNID) error { func (u *RoomUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) { return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal) } + +func (u *RoomUpdater) GetUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { + return u.d.GetUserIDForSender(ctx, roomID, senderID) +} diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index cefa58a3..406d7cf1 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -988,8 +988,18 @@ func (d *EventDatabase) MaybeRedactEvent( return nil } - _, sender1, _ := gomatrixserverlib.SplitID('@', redactedEvent.Sender()) - _, sender2, _ := gomatrixserverlib.SplitID('@', redactionEvent.Sender()) + // TODO: Don't hack senderID into userID here (pseudoIDs) + sender1Domain := "" + sender1, err1 := spec.NewUserID(redactedEvent.SenderID(), true) + if err1 == nil { + sender1Domain = string(sender1.Domain()) + } + // TODO: Don't hack senderID into userID here (pseudoIDs) + sender2Domain := "" + sender2, err2 := spec.NewUserID(redactionEvent.SenderID(), true) + if err2 == nil { + sender2Domain = string(sender2.Domain()) + } var powerlevels *gomatrixserverlib.PowerLevelContent powerlevels, err = plResolver.Resolve(ctx, redactionEvent.EventID()) if err != nil { @@ -997,9 +1007,9 @@ func (d *EventDatabase) MaybeRedactEvent( } switch { - case powerlevels.UserLevel(redactionEvent.Sender()) >= powerlevels.Redact: + case powerlevels.UserLevel(redactionEvent.SenderID()) >= powerlevels.Redact: // 1. The power level of the redaction event’s sender is greater than or equal to the redact level. - case sender1 == sender2: + case sender1Domain == sender2Domain: // 2. The domain of the redaction event’s sender matches that of the original event’s sender. default: ignoreRedaction = true @@ -1514,6 +1524,16 @@ func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString strin return d.MembershipTable.SelectKnownUsers(ctx, nil, stateKeyNID, searchString, limit) } +func (d *Database) GetUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { + // TODO: Use real logic once DB for pseudoIDs is in place + return spec.NewUserID(senderID, true) +} + +func (d *Database) GetSenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (string, error) { + // TODO: Use real logic once DB for pseudoIDs is in place + return userID.String(), nil +} + // GetKnownRooms returns a list of all rooms we know about. func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) { return d.RoomsTable.SelectRoomIDsWithEvents(ctx, nil) diff --git a/setup/mscs/msc2836/msc2836.go b/setup/mscs/msc2836/msc2836.go index f468b048..5ce3b430 100644 --- a/setup/mscs/msc2836/msc2836.go +++ b/setup/mscs/msc2836/msc2836.go @@ -92,9 +92,11 @@ type MSC2836EventRelationshipsResponse struct { ParsedAuthChain []gomatrixserverlib.PDU } -func toClientResponse(res *MSC2836EventRelationshipsResponse) *EventRelationshipResponse { +func toClientResponse(ctx context.Context, res *MSC2836EventRelationshipsResponse, rsAPI roomserver.RoomserverInternalAPI) *EventRelationshipResponse { out := &EventRelationshipResponse{ - Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(res.ParsedEvents), synctypes.FormatAll), + Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(res.ParsedEvents), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }), Limited: res.Limited, NextBatch: res.NextBatch, } @@ -187,7 +189,7 @@ func eventRelationshipHandler(db Database, rsAPI roomserver.RoomserverInternalAP return util.JSONResponse{ Code: 200, - JSON: toClientResponse(res), + JSON: toClientResponse(req.Context(), res, rsAPI), } } } diff --git a/setup/mscs/msc2836/msc2836_test.go b/setup/mscs/msc2836/msc2836_test.go index 2c6f63d4..c463fd72 100644 --- a/setup/mscs/msc2836/msc2836_test.go +++ b/setup/mscs/msc2836/msc2836_test.go @@ -525,6 +525,10 @@ type testRoomserverAPI struct { events map[string]*types.HeaderedEvent } +func (r *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { + return spec.NewUserID(senderID, true) +} + func (r *testRoomserverAPI) QueryEventsByID(ctx context.Context, req *roomserver.QueryEventsByIDRequest, res *roomserver.QueryEventsByIDResponse) error { for _, eventID := range req.EventIDs { ev := r.events[eventID] diff --git a/setup/mscs/msc2946/msc2946.go b/setup/mscs/msc2946/msc2946.go index 291e0f3b..f380d3d4 100644 --- a/setup/mscs/msc2946/msc2946.go +++ b/setup/mscs/msc2946/msc2946.go @@ -730,7 +730,7 @@ func stripped(ev gomatrixserverlib.PDU) *fclient.MSC2946StrippedEvent { Type: ev.Type(), StateKey: *ev.StateKey(), Content: ev.Content(), - Sender: ev.Sender(), + Sender: ev.SenderID(), OriginServerTS: ev.OriginServerTS(), } } diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index 56285dbf..c0836465 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -523,7 +523,7 @@ func (s *OutputRoomEventConsumer) updateStateEvent(event *rstypes.HeaderedEvent) prev := types.PrevEventRef{ PrevContent: prevEvent.Content(), ReplacesState: prevEvent.EventID(), - PrevSender: prevEvent.Sender(), + PrevSender: prevEvent.SenderID(), } event.PDU, err = event.SetUnsigned(prev) diff --git a/syncapi/routing/context.go b/syncapi/routing/context.go index ac17d39d..27e99a35 100644 --- a/syncapi/routing/context.go +++ b/syncapi/routing/context.go @@ -193,14 +193,20 @@ func Context( } } - eventsBeforeClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBeforeFiltered), synctypes.FormatAll) - eventsAfterClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfterFiltered), synctypes.FormatAll) + eventsBeforeClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBeforeFiltered), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }) + eventsAfterClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfterFiltered), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }) newState := state if filter.LazyLoadMembers { allEvents := append(eventsBeforeFiltered, eventsAfterFiltered...) allEvents = append(allEvents, &requestedEvent) - evs := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(allEvents), synctypes.FormatAll) + evs := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(allEvents), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }) newState, err = applyLazyLoadMembers(ctx, device, snapshot, roomID, evs, lazyLoadCache) if err != nil { logrus.WithError(err).Error("unable to load membership events") @@ -211,12 +217,19 @@ func Context( } } - ev := synctypes.ToClientEvent(&requestedEvent, synctypes.FormatAll) + sender := spec.UserID{} + userID, err := rsAPI.QueryUserIDForSender(ctx, requestedEvent.RoomID(), requestedEvent.SenderID()) + if err == nil && userID != nil { + sender = *userID + } + ev := synctypes.ToClientEvent(&requestedEvent, synctypes.FormatAll, sender) response := ContextRespsonse{ Event: &ev, EventsAfter: eventsAfterClient, EventsBefore: eventsBeforeClient, - State: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(newState), synctypes.FormatAll), + State: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(newState), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }), } if len(response.State) > filter.Limit { diff --git a/syncapi/routing/getevent.go b/syncapi/routing/getevent.go index 0d3d412f..63df7e83 100644 --- a/syncapi/routing/getevent.go +++ b/syncapi/routing/getevent.go @@ -101,8 +101,13 @@ func GetEvent( } } + sender := spec.UserID{} + senderUserID, err := rsAPI.QueryUserIDForSender(req.Context(), roomID, events[0].SenderID()) + if err == nil && senderUserID != nil { + sender = *senderUserID + } return util.JSONResponse{ Code: http.StatusOK, - JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll), + JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll, sender), } } diff --git a/syncapi/routing/memberships.go b/syncapi/routing/memberships.go index 7d2e137d..9c2319dd 100644 --- a/syncapi/routing/memberships.go +++ b/syncapi/routing/memberships.go @@ -144,7 +144,7 @@ func GetMemberships( JSON: spec.InternalServerError{}, } } - res.Joined[ev.Sender()] = joinedMember(content) + res.Joined[ev.SenderID()] = joinedMember(content) } return util.JSONResponse{ Code: http.StatusOK, @@ -153,6 +153,8 @@ func GetMemberships( } return util.JSONResponse{ Code: http.StatusOK, - JSON: getMembershipResponse{synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(result), synctypes.FormatAll)}, + JSON: getMembershipResponse{synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(result), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) + })}, } } diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index aeaec699..879739d0 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -241,7 +241,7 @@ func OnIncomingMessagesRequest( device: device, } - clientEvents, start, end, err := mReq.retrieveEvents() + clientEvents, start, end, err := mReq.retrieveEvents(req.Context(), rsAPI) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("mreq.retrieveEvents failed") return util.JSONResponse{ @@ -273,7 +273,9 @@ func OnIncomingMessagesRequest( JSON: spec.InternalServerError{}, } } - res.State = append(res.State, synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(membershipEvents), synctypes.FormatAll)...) + res.State = append(res.State, synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(membershipEvents), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) + })...) } // If we didn't return any events, set the end to an empty string, so it will be omitted @@ -310,7 +312,7 @@ func getMembershipForUser(ctx context.Context, roomID, userID string, rsAPI api. // homeserver in the room for older events. // Returns an error if there was an issue talking to the database or with the // remote homeserver. -func (r *messagesReq) retrieveEvents() ( +func (r *messagesReq) retrieveEvents(ctx context.Context, rsAPI api.SyncRoomserverAPI) ( clientEvents []synctypes.ClientEvent, start, end types.TopologyToken, err error, ) { @@ -383,7 +385,9 @@ func (r *messagesReq) retrieveEvents() ( "events_before": len(events), "events_after": len(filteredEvents), }).Debug("applied history visibility (messages)") - return synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(filteredEvents), synctypes.FormatAll), start, end, err + return synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(filteredEvents), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }), start, end, err } func (r *messagesReq) getStartEnd(events []*rstypes.HeaderedEvent) (start, end types.TopologyToken, err error) { diff --git a/syncapi/routing/relations.go b/syncapi/routing/relations.go index 8374bf5b..f21c684c 100644 --- a/syncapi/routing/relations.go +++ b/syncapi/routing/relations.go @@ -114,9 +114,14 @@ func Relations( // type if it was specified. res.Chunk = make([]synctypes.ClientEvent, 0, len(filteredEvents)) for _, event := range filteredEvents { + sender := spec.UserID{} + userID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), event.SenderID()) + if err == nil && userID != nil { + sender = *userID + } res.Chunk = append( res.Chunk, - synctypes.ToClientEvent(event.PDU, synctypes.FormatAll), + synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender), ) } diff --git a/syncapi/routing/routing.go b/syncapi/routing/routing.go index 9ad0c047..8542c0b7 100644 --- a/syncapi/routing/routing.go +++ b/syncapi/routing/routing.go @@ -171,7 +171,7 @@ func Setup( nb := req.FormValue("next_batch") nextBatch = &nb } - return Search(req, device, syncDB, fts, nextBatch) + return Search(req, device, syncDB, fts, nextBatch, rsAPI) }), ).Methods(http.MethodPost, http.MethodOptions) diff --git a/syncapi/routing/search.go b/syncapi/routing/search.go index b7191873..9cf3eabe 100644 --- a/syncapi/routing/search.go +++ b/syncapi/routing/search.go @@ -31,6 +31,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/internal/fulltext" "github.com/matrix-org/dendrite/internal/sqlutil" + roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/synctypes" @@ -38,7 +39,7 @@ import ( ) // nolint:gocyclo -func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts fulltext.Indexer, from *string) util.JSONResponse { +func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts fulltext.Indexer, from *string, rsAPI roomserverAPI.SyncRoomserverAPI) util.JSONResponse { start := time.Now() var ( searchReq SearchRequest @@ -204,11 +205,17 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts profileInfos := make(map[string]ProfileInfoResponse) for _, ev := range append(eventsBefore, eventsAfter...) { - profile, ok := knownUsersProfiles[event.Sender()] + userID, queryErr := rsAPI.QueryUserIDForSender(req.Context(), ev.RoomID(), ev.SenderID()) + if queryErr != nil { + logrus.WithError(queryErr).WithField("sender_id", event.SenderID()).Warn("failed to query userprofile") + continue + } + + profile, ok := knownUsersProfiles[userID.String()] if !ok { - stateEvent, err := snapshot.GetStateEvent(ctx, ev.RoomID(), spec.MRoomMember, ev.Sender()) - if err != nil { - logrus.WithError(err).WithField("user_id", event.Sender()).Warn("failed to query userprofile") + stateEvent, stateErr := snapshot.GetStateEvent(ctx, ev.RoomID(), spec.MRoomMember, ev.SenderID()) + if stateErr != nil { + logrus.WithError(stateErr).WithField("sender_id", event.SenderID()).Warn("failed to query userprofile") continue } if stateEvent == nil { @@ -218,21 +225,30 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts AvatarURL: gjson.GetBytes(stateEvent.Content(), "avatar_url").Str, DisplayName: gjson.GetBytes(stateEvent.Content(), "displayname").Str, } - knownUsersProfiles[event.Sender()] = profile + knownUsersProfiles[userID.String()] = profile } - profileInfos[ev.Sender()] = profile + profileInfos[userID.String()] = profile } + sender := spec.UserID{} + userID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), event.SenderID()) + if err == nil && userID != nil { + sender = *userID + } results = append(results, Result{ Context: SearchContextResponse{ - Start: startToken.String(), - End: endToken.String(), - EventsAfter: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfter), synctypes.FormatSync), - EventsBefore: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBefore), synctypes.FormatSync), - ProfileInfo: profileInfos, + Start: startToken.String(), + End: endToken.String(), + EventsAfter: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfter), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) + }), + EventsBefore: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBefore), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) + }), + ProfileInfo: profileInfos, }, Rank: eventScore[event.EventID()].Score, - Result: synctypes.ToClientEvent(event, synctypes.FormatAll), + Result: synctypes.ToClientEvent(event, synctypes.FormatAll, sender), }) roomGroup := groups[event.RoomID()] roomGroup.Results = append(roomGroup.Results, event.EventID()) @@ -247,7 +263,9 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts JSON: spec.InternalServerError{}, } } - stateForRooms[event.RoomID()] = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(state), synctypes.FormatSync) + stateForRooms[event.RoomID()] = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(state), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) + }) } } diff --git a/syncapi/routing/search_test.go b/syncapi/routing/search_test.go index 1cc95a87..b36be823 100644 --- a/syncapi/routing/search_test.go +++ b/syncapi/routing/search_test.go @@ -2,6 +2,7 @@ package routing import ( "bytes" + "context" "encoding/json" "net/http" "net/http/httptest" @@ -9,6 +10,7 @@ import ( "github.com/matrix-org/dendrite/internal/fulltext" "github.com/matrix-org/dendrite/internal/sqlutil" + rsapi "github.com/matrix-org/dendrite/roomserver/api" rstypes "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/synctypes" @@ -21,6 +23,12 @@ import ( "github.com/stretchr/testify/assert" ) +type FakeSyncRoomserverAPI struct{ rsapi.SyncRoomserverAPI } + +func (f *FakeSyncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { + return spec.NewUserID(senderID, true) +} + func TestSearch(t *testing.T) { alice := test.NewUser(t) aliceDevice := userapi.Device{UserID: alice.ID} @@ -247,7 +255,7 @@ func TestSearch(t *testing.T) { assert.NoError(t, err) req := httptest.NewRequest(http.MethodPost, "/", reqBody) - res := Search(req, tc.device, db, fts, tc.from) + res := Search(req, tc.device, db, fts, tc.from, &FakeSyncRoomserverAPI{}) if !tc.wantOK && !res.Is2xx() { return } diff --git a/syncapi/storage/postgres/current_room_state_table.go b/syncapi/storage/postgres/current_room_state_table.go index 0cc96373..bfe5e9bd 100644 --- a/syncapi/storage/postgres/current_room_state_table.go +++ b/syncapi/storage/postgres/current_room_state_table.go @@ -343,7 +343,7 @@ func (s *currentRoomStateStatements) UpsertRoomState( event.RoomID(), event.EventID(), event.Type(), - event.Sender(), + event.SenderID(), containsURL, *event.StateKey(), headeredJSON, diff --git a/syncapi/storage/postgres/output_room_events_table.go b/syncapi/storage/postgres/output_room_events_table.go index 3aadbccf..e068afab 100644 --- a/syncapi/storage/postgres/output_room_events_table.go +++ b/syncapi/storage/postgres/output_room_events_table.go @@ -407,7 +407,7 @@ func (s *outputRoomEventsStatements) InsertEvent( event.EventID(), headeredJSON, event.Type(), - event.Sender(), + event.SenderID(), containsURL, pq.StringArray(addState), pq.StringArray(removeState), diff --git a/syncapi/storage/shared/storage_consumer.go b/syncapi/storage/shared/storage_consumer.go index ecfd418f..17a6a69c 100644 --- a/syncapi/storage/shared/storage_consumer.go +++ b/syncapi/storage/shared/storage_consumer.go @@ -195,7 +195,21 @@ func (d *Database) StreamEventsToEvents(device *userapi.Device, in []types.Strea for i := 0; i < len(in); i++ { out[i] = in[i].HeaderedEvent if device != nil && in[i].TransactionID != nil { - if device.UserID == in[i].Sender() && device.SessionID == in[i].TransactionID.SessionID { + userID, err := spec.NewUserID(device.UserID, true) + if err != nil { + logrus.WithFields(logrus.Fields{ + "event_id": out[i].EventID(), + }).WithError(err).Warnf("Failed to add transaction ID to event") + continue + } + deviceSenderID, err := d.getSenderIDForUser(in[i].RoomID(), *userID) + if err != nil { + logrus.WithFields(logrus.Fields{ + "event_id": out[i].EventID(), + }).WithError(err).Warnf("Failed to add transaction ID to event") + continue + } + if deviceSenderID == in[i].SenderID() && device.SessionID == in[i].TransactionID.SessionID { err := out[i].SetUnsignedField( "transaction_id", in[i].TransactionID.TransactionID, ) @@ -210,6 +224,11 @@ func (d *Database) StreamEventsToEvents(device *userapi.Device, in []types.Strea return out } +func (d *Database) getSenderIDForUser(roomID string, userID spec.UserID) (string, error) { // nolint + // TODO: Repalce with actual logic for pseudoIDs + return userID.String(), nil +} + // handleBackwardExtremities adds this event as a backwards extremity if and only if we do not have all of // the events listed in the event's 'prev_events'. This function also updates the backwards extremities table // to account for the fact that the given event is no longer a backwards extremity, but may be marked as such. diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go index 1b8632eb..e432e483 100644 --- a/syncapi/storage/sqlite3/current_room_state_table.go +++ b/syncapi/storage/sqlite3/current_room_state_table.go @@ -342,7 +342,7 @@ func (s *currentRoomStateStatements) UpsertRoomState( event.RoomID(), event.EventID(), event.Type(), - event.Sender(), + event.SenderID(), containsURL, *event.StateKey(), headeredJSON, diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index d63e7606..5a47aec4 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -348,7 +348,7 @@ func (s *outputRoomEventsStatements) InsertEvent( event.EventID(), headeredJSON, event.Type(), - event.Sender(), + event.SenderID(), containsURL, string(addStateJSON), string(removeStateJSON), diff --git a/syncapi/streams/stream_invite.go b/syncapi/streams/stream_invite.go index becd863a..a8b0a7b6 100644 --- a/syncapi/streams/stream_invite.go +++ b/syncapi/streams/stream_invite.go @@ -10,6 +10,7 @@ import ( "github.com/matrix-org/gomatrixserverlib/spec" + "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" @@ -17,6 +18,7 @@ import ( type InviteStreamProvider struct { DefaultStreamProvider + rsAPI api.SyncRoomserverAPI } func (p *InviteStreamProvider) Setup( @@ -62,11 +64,17 @@ func (p *InviteStreamProvider) IncrementalSync( } for roomID, inviteEvent := range invites { + user := spec.UserID{} + sender, err := p.rsAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), inviteEvent.SenderID()) + if err == nil && sender != nil { + user = *sender + } + // skip ignored user events - if _, ok := req.IgnoredUsers.List[inviteEvent.Sender()]; ok { + if _, ok := req.IgnoredUsers.List[user.String()]; ok { continue } - ir := types.NewInviteResponse(inviteEvent) + ir := types.NewInviteResponse(inviteEvent, user) req.Response.Rooms.Invite[roomID] = ir } diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index 0ea48a9d..8f83a089 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -376,20 +376,28 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( } } jr.Timeline.PrevBatch = &prevBatch - jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync) + jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }) // If we are limited by the filter AND the history visibility filter // didn't "remove" events, return that the response is limited. jr.Timeline.Limited = (limited && len(events) == len(recentEvents)) || delta.NewlyJoined - jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync) + jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }) req.Response.Rooms.Join[delta.RoomID] = jr case spec.Peek: jr := types.NewJoinResponse() jr.Timeline.PrevBatch = &prevBatch // TODO: Apply history visibility on peeked rooms - jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(recentEvents), synctypes.FormatSync) + jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(recentEvents), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }) jr.Timeline.Limited = limited - jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync) + jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }) req.Response.Rooms.Peek[delta.RoomID] = jr case spec.Leave: @@ -398,11 +406,15 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( case spec.Ban: lr := types.NewLeaveResponse() lr.Timeline.PrevBatch = &prevBatch - lr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync) + lr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }) // If we are limited by the filter AND the history visibility filter // didn't "remove" events, return that the response is limited. lr.Timeline.Limited = limited && len(events) == len(recentEvents) - lr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync) + lr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }) req.Response.Rooms.Leave[delta.RoomID] = lr } @@ -425,7 +437,7 @@ func applyHistoryVisibilityFilter( for _, ev := range recentEvents { if ev.StateKey() != nil { stateTypes = append(stateTypes, ev.Type()) - senders = append(senders, ev.Sender()) + senders = append(senders, ev.SenderID()) } } @@ -552,11 +564,15 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( } jr.Timeline.PrevBatch = prevBatch - jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync) + jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }) // If we are limited by the filter AND the history visibility filter // didn't "remove" events, return that the response is limited. jr.Timeline.Limited = limited && len(events) == len(recentEvents) - jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(stateEvents), synctypes.FormatSync) + jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(stateEvents), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }) return jr, nil } @@ -577,8 +593,8 @@ func (p *PDUStreamProvider) lazyLoadMembers( // Add all users the client doesn't know about yet to a list for _, event := range timelineEvents { // Membership is not yet cached, add it to the list - if _, ok := p.lazyLoadCache.IsLazyLoadedUserCached(device, roomID, event.Sender()); !ok { - timelineUsers[event.Sender()] = struct{}{} + if _, ok := p.lazyLoadCache.IsLazyLoadedUserCached(device, roomID, event.SenderID()); !ok { + timelineUsers[event.SenderID()] = struct{}{} } } // Preallocate with the same amount, even if it will end up with fewer values diff --git a/syncapi/streams/streams.go b/syncapi/streams/streams.go index a35491ac..f25bc978 100644 --- a/syncapi/streams/streams.go +++ b/syncapi/streams/streams.go @@ -45,6 +45,7 @@ func NewSyncStreamProviders( }, InviteStreamProvider: &InviteStreamProvider{ DefaultStreamProvider: DefaultStreamProvider{DB: d}, + rsAPI: rsAPI, }, SendToDeviceStreamProvider: &SendToDeviceStreamProvider{ DefaultStreamProvider: DefaultStreamProvider{DB: d}, diff --git a/syncapi/syncapi_test.go b/syncapi/syncapi_test.go index bc766e66..78c857ab 100644 --- a/syncapi/syncapi_test.go +++ b/syncapi/syncapi_test.go @@ -40,6 +40,10 @@ type syncRoomserverAPI struct { rooms []*test.Room } +func (s *syncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { + return spec.NewUserID(senderID, true) +} + func (s *syncRoomserverAPI) QueryLatestEventsAndState(ctx context.Context, req *rsapi.QueryLatestEventsAndStateRequest, res *rsapi.QueryLatestEventsAndStateResponse) error { var room *test.Room for _, r := range s.rooms { diff --git a/syncapi/synctypes/clientevent.go b/syncapi/synctypes/clientevent.go index c722fe60..66fb1d01 100644 --- a/syncapi/synctypes/clientevent.go +++ b/syncapi/synctypes/clientevent.go @@ -44,22 +44,27 @@ type ClientEvent struct { } // ToClientEvents converts server events to client events. -func ToClientEvents(serverEvs []gomatrixserverlib.PDU, format ClientEventFormat) []ClientEvent { +func ToClientEvents(serverEvs []gomatrixserverlib.PDU, format ClientEventFormat, userIDForSender spec.UserIDForSender) []ClientEvent { evs := make([]ClientEvent, 0, len(serverEvs)) for _, se := range serverEvs { if se == nil { continue // TODO: shouldn't happen? } - evs = append(evs, ToClientEvent(se, format)) + sender := spec.UserID{} + userID, err := userIDForSender(se.RoomID(), se.SenderID()) + if err == nil && userID != nil { + sender = *userID + } + evs = append(evs, ToClientEvent(se, format, sender)) } return evs } // ToClientEvent converts a single server event to a client event. -func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat) ClientEvent { +func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, sender spec.UserID) ClientEvent { ce := ClientEvent{ Content: spec.RawJSON(se.Content()), - Sender: se.Sender(), + Sender: sender.String(), Type: se.Type(), StateKey: se.StateKey(), Unsigned: spec.RawJSON(se.Unsigned()), diff --git a/syncapi/synctypes/clientevent_test.go b/syncapi/synctypes/clientevent_test.go index b914e64f..34179508 100644 --- a/syncapi/synctypes/clientevent_test.go +++ b/syncapi/synctypes/clientevent_test.go @@ -21,6 +21,7 @@ import ( "testing" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) func TestToClientEvent(t *testing.T) { // nolint: gocyclo @@ -43,7 +44,11 @@ func TestToClientEvent(t *testing.T) { // nolint: gocyclo if err != nil { t.Fatalf("failed to create Event: %s", err) } - ce := ToClientEvent(ev, FormatAll) + userID, err := spec.NewUserID("@test:localhost", true) + if err != nil { + t.Fatalf("failed to create userID: %s", err) + } + ce := ToClientEvent(ev, FormatAll, *userID) if ce.EventID != ev.EventID() { t.Errorf("ClientEvent.EventID: wanted %s, got %s", ev.EventID(), ce.EventID) } @@ -62,8 +67,8 @@ func TestToClientEvent(t *testing.T) { // nolint: gocyclo if !bytes.Equal(ce.Unsigned, ev.Unsigned()) { t.Errorf("ClientEvent.Unsigned: wanted %s, got %s", string(ev.Unsigned()), string(ce.Unsigned)) } - if ce.Sender != ev.Sender() { - t.Errorf("ClientEvent.Sender: wanted %s, got %s", ev.Sender(), ce.Sender) + if ce.Sender != userID.String() { + t.Errorf("ClientEvent.Sender: wanted %s, got %s", userID.String(), ce.Sender) } j, err := json.Marshal(ce) if err != nil { @@ -98,7 +103,11 @@ func TestToClientFormatSync(t *testing.T) { if err != nil { t.Fatalf("failed to create Event: %s", err) } - ce := ToClientEvent(ev, FormatSync) + userID, err := spec.NewUserID("@test:localhost", true) + if err != nil { + t.Fatalf("failed to create userID: %s", err) + } + ce := ToClientEvent(ev, FormatSync, *userID) if ce.RoomID != "" { t.Errorf("ClientEvent.RoomID: wanted '', got %s", ce.RoomID) } diff --git a/syncapi/types/types.go b/syncapi/types/types.go index 22c27fea..526a120d 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -539,7 +539,7 @@ type InviteResponse struct { } // NewInviteResponse creates an empty response with initialised arrays. -func NewInviteResponse(event *types.HeaderedEvent) *InviteResponse { +func NewInviteResponse(event *types.HeaderedEvent, userID spec.UserID) *InviteResponse { res := InviteResponse{} res.InviteState.Events = []json.RawMessage{} @@ -552,7 +552,7 @@ func NewInviteResponse(event *types.HeaderedEvent) *InviteResponse { // Then we'll see if we can create a partial of the invite event itself. // This is needed for clients to work out *who* sent the invite. - inviteEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatSync) + inviteEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatSync, userID) inviteEvent.Unsigned = nil if ev, err := json.Marshal(inviteEvent); err == nil { res.InviteState.Events = append(res.InviteState.Events, ev) diff --git a/syncapi/types/types_test.go b/syncapi/types/types_test.go index 8e0448fe..a79ce541 100644 --- a/syncapi/types/types_test.go +++ b/syncapi/types/types_test.go @@ -8,8 +8,13 @@ import ( "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" ) +func UserIDForSender(roomID string, senderID string) (*spec.UserID, error) { + return spec.NewUserID(senderID, true) +} + func TestSyncTokens(t *testing.T) { shouldPass := map[string]string{ "s4_0_0_0_0_0_0_0_3": StreamingToken{4, 0, 0, 0, 0, 0, 0, 0, 3}.String(), @@ -56,7 +61,12 @@ func TestNewInviteResponse(t *testing.T) { t.Fatal(err) } - res := NewInviteResponse(&types.HeaderedEvent{PDU: ev}) + sender, err := spec.NewUserID("@neilalexander:matrix.org", true) + if err != nil { + t.Fatal(err) + } + + res := NewInviteResponse(&types.HeaderedEvent{PDU: ev}, *sender) j, err := json.Marshal(res) if err != nil { t.Fatal(err) diff --git a/test/room.go b/test/room.go index 852e3153..4cdb73aa 100644 --- a/test/room.go +++ b/test/room.go @@ -39,6 +39,10 @@ var ( roomIDCounter = int64(0) ) +func UserIDForSender(roomID string, senderID string) (*spec.UserID, error) { + return spec.NewUserID(senderID, true) +} + type Room struct { ID string Version gomatrixserverlib.RoomVersion @@ -195,7 +199,7 @@ func (r *Room) CreateEvent(t *testing.T, creator *User, eventType string, conten if err != nil { t.Fatalf("CreateEvent[%s]: failed to build event: %s", eventType, err) } - if err = gomatrixserverlib.Allowed(ev, &r.authEvents); err != nil { + if err = gomatrixserverlib.Allowed(ev, &r.authEvents, UserIDForSender); err != nil { t.Fatalf("CreateEvent[%s]: failed to verify event was allowed: %s", eventType, err) } headeredEvent := &rstypes.HeaderedEvent{PDU: ev} diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go index 3cfdc0ce..c025deee 100644 --- a/userapi/consumers/roomserver.go +++ b/userapi/consumers/roomserver.go @@ -108,7 +108,7 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Ms } if s.cfg.Matrix.ReportStats.Enabled { - go s.storeMessageStats(ctx, event.Type(), event.Sender(), event.RoomID()) + go s.storeMessageStats(ctx, event.Type(), event.SenderID(), event.RoomID()) } log.WithFields(log.Fields{ @@ -301,7 +301,12 @@ func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *rst switch { case event.Type() == spec.MRoomMember: - cevent := synctypes.ToClientEvent(event, synctypes.FormatAll) + sender := spec.UserID{} + userID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) + if queryErr == nil && userID != nil { + sender = *userID + } + cevent := synctypes.ToClientEvent(event, synctypes.FormatAll, sender) var member *localMembership member, err = newLocalMembership(&cevent) if err != nil { @@ -529,12 +534,17 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype return fmt.Errorf("s.localPushDevices: %w", err) } + sender := spec.UserID{} + userID, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) + if err == nil && userID != nil { + sender = *userID + } n := &api.Notification{ Actions: actions, // UNSPEC: the spec doesn't say this is a ClientEvent, but the // fields seem to match. room_id should be missing, which // matches the behaviour of FormatSync. - Event: synctypes.ToClientEvent(event, synctypes.FormatSync), + Event: synctypes.ToClientEvent(event, synctypes.FormatSync, sender), // TODO: this is per-device, but it's not part of the primary // key. So inserting one notification per profile tag doesn't // make sense. What is this supposed to be? Sytests require it @@ -615,7 +625,12 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype // evaluatePushRules fetches and evaluates the push rules of a local // user. Returns actions (including dont_notify). func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *rstypes.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) { - if event.Sender() == mem.UserID { + user := "" + sender, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) + if err == nil { + user = sender.String() + } + if user == mem.UserID { // SPEC: Homeservers MUST NOT notify the Push Gateway for // events that the user has sent themselves. return nil, nil @@ -632,9 +647,8 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event * if err != nil { return nil, err } - sender := event.Sender() - if _, ok := ignored.List[sender]; ok { - return nil, fmt.Errorf("user %s is ignored", sender) + if _, ok := ignored.List[sender.String()]; ok { + return nil, fmt.Errorf("user %s is ignored", sender.String()) } } ruleSets, err := s.db.QueryPushRules(ctx, mem.Localpart, mem.Domain) @@ -650,7 +664,9 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event * roomSize: roomSize, } eval := pushrules.NewRuleSetEvaluator(ec, &ruleSets.Global) - rule, err := eval.MatchEvent(event.PDU) + rule, err := eval.MatchEvent(event.PDU, func(roomID, senderID string) (*spec.UserID, error) { + return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) + }) if err != nil { return nil, err } @@ -682,7 +698,7 @@ func (rse *ruleSetEvalContext) UserDisplayName() string { return rse.mem.Display func (rse *ruleSetEvalContext) RoomMemberCount() (int, error) { return rse.roomSize, nil } -func (rse *ruleSetEvalContext) HasPowerLevel(userID, levelKey string) (bool, error) { +func (rse *ruleSetEvalContext) HasPowerLevel(senderID, levelKey string) (bool, error) { req := &rsapi.QueryLatestEventsAndStateRequest{ RoomID: rse.roomID, StateToFetch: []gomatrixserverlib.StateKeyTuple{ @@ -702,7 +718,7 @@ func (rse *ruleSetEvalContext) HasPowerLevel(userID, levelKey string) (bool, err if err != nil { return false, err } - return plc.UserLevel(userID) >= plc.NotificationLevel(levelKey), nil + return plc.UserLevel(senderID) >= plc.NotificationLevel(levelKey), nil } return true, nil } @@ -756,6 +772,11 @@ func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *rstypes } default: + sender, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) + if err != nil { + logger.WithError(err).Errorf("Failed to get userID for sender %s", event.SenderID()) + return nil, err + } req = pushgateway.NotifyRequest{ Notification: pushgateway.Notification{ Content: event.Content(), @@ -767,7 +788,7 @@ func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *rstypes ID: event.EventID(), RoomID: event.RoomID(), RoomName: roomName, - Sender: event.Sender(), + Sender: sender.String(), Type: event.Type(), }, } diff --git a/userapi/consumers/roomserver_test.go b/userapi/consumers/roomserver_test.go index 53977206..899a5aaf 100644 --- a/userapi/consumers/roomserver_test.go +++ b/userapi/consumers/roomserver_test.go @@ -14,6 +14,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/matrix-org/dendrite/internal/pushrules" + rsapi "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/userapi/storage" @@ -44,13 +45,19 @@ func mustCreateEvent(t *testing.T, content string) *types.HeaderedEvent { return &types.HeaderedEvent{PDU: ev} } +type FakeUserRoomserverAPI struct{ rsapi.UserRoomserverAPI } + +func (f *FakeUserRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { + return spec.NewUserID(senderID, true) +} + func Test_evaluatePushRules(t *testing.T) { ctx := context.Background() test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := mustCreateDatabase(t, dbType) defer close() - consumer := OutputRoomEventConsumer{db: db} + consumer := OutputRoomEventConsumer{db: db, rsAPI: &FakeUserRoomserverAPI{}} testCases := []struct { name string @@ -86,7 +93,7 @@ func Test_evaluatePushRules(t *testing.T) { }, { name: "m.room.message highlights", - eventContent: `{"type":"m.room.message", "content": {"body": "test"} }`, + eventContent: `{"type":"m.room.message", "content": {"body": "test"}}`, wantNotify: true, wantAction: pushrules.NotifyAction, wantActions: []*pushrules.Action{ diff --git a/userapi/util/notify_test.go b/userapi/util/notify_test.go index e1c88d47..27dd373c 100644 --- a/userapi/util/notify_test.go +++ b/userapi/util/notify_test.go @@ -11,6 +11,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "golang.org/x/crypto/bcrypt" @@ -87,7 +88,7 @@ func TestNotifyUserCountsAsync(t *testing.T) { } // Prepare pusher with our test server URL - if err := db.UpsertPusher(ctx, api.Pusher{ + if err = db.UpsertPusher(ctx, api.Pusher{ Kind: api.HTTPKind, AppID: appID, PushKey: pushKey, @@ -99,8 +100,12 @@ func TestNotifyUserCountsAsync(t *testing.T) { } // Insert a dummy event + sender, err := spec.NewUserID(alice.ID, true) + if err != nil { + t.Error(err) + } if err := db.InsertNotification(ctx, aliceLocalpart, serverName, dummyEvent.EventID(), 0, nil, &api.Notification{ - Event: synctypes.ToClientEvent(dummyEvent, synctypes.FormatAll), + Event: synctypes.ToClientEvent(dummyEvent, synctypes.FormatAll, *sender), }); err != nil { t.Error(err) }