From 498fbc277b503cca3bcebeddfed99d9caa7ccb3d Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 11 Nov 2021 12:09:49 +0000 Subject: [PATCH] Try to process local restricted joins --- roomserver/internal/input/input.go | 3 +- roomserver/internal/perform/perform_join.go | 95 ++++++++++++++++++++- 2 files changed, 95 insertions(+), 3 deletions(-) diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index de40e133..a4d43723 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -214,8 +214,7 @@ func (r *Inputer) InputRoomEvents( for _, task := range tasks { if task.err != nil { response.ErrMsg = task.err.Error() - _, rejected := task.err.(*gomatrixserverlib.NotAllowed) - response.NotAllowed = rejected + _, response.NotAllowed = task.err.(*gomatrixserverlib.NotAllowed) return } } diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go index 772c9d7d..7d25897a 100644 --- a/roomserver/internal/perform/perform_join.go +++ b/roomserver/internal/perform/perform_join.go @@ -16,6 +16,7 @@ package perform import ( "context" + "encoding/json" "errors" "fmt" "strings" @@ -149,7 +150,7 @@ func (r *Joiner) performJoinRoomByAlias( return r.performJoinRoomByID(ctx, req) } -// TODO: Break this function up a bit +// nolint:gocyclo func (r *Joiner) performJoinRoomByID( ctx context.Context, req *rsAPI.PerformJoinRequest, @@ -240,6 +241,23 @@ func (r *Joiner) performJoinRoomByID( return req.RoomIDOrAlias, joinedVia, err } + // Check if the room is a restricted room. If so, update the event + // builder content. + if restricted, roomIDs, rerr := r.checkIfRestrictedJoin(ctx, req); rerr != nil { + return "", "", fmt.Errorf("r.performRestrictedJoinChecks: %w", rerr) + } else if restricted { + success := false + for _, roomID := range roomIDs { + if err = r.attemptRestrictedJoinUsingRoomID(ctx, req, roomID, &eb); err != nil { + continue + } + success = true + } + if !success { + return "", "", fmt.Errorf("restricted join failed") + } + } + // Try to construct an actual join event from the template. // If this succeeds then it is a sign that the room already exists // locally on the homeserver. @@ -318,6 +336,81 @@ func (r *Joiner) performJoinRoomByID( return req.RoomIDOrAlias, r.Cfg.Matrix.ServerName, nil } +func (r *Joiner) checkIfRestrictedJoin( + ctx context.Context, + req *rsAPI.PerformJoinRequest, +) (bool, []string, error) { + // Look up the join rules event for the room, so we can check if it is a + // restricted room or not. + joinRuleEvent, err := r.DB.GetStateEvent(ctx, req.RoomIDOrAlias, gomatrixserverlib.MRoomJoinRules, "") + if err != nil { + return false, nil, fmt.Errorf("r.DB.GetStateEvent: %w", err) + } + joinRuleContent := &gomatrixserverlib.JoinRuleContent{ + JoinRule: gomatrixserverlib.Public, + } + if err = json.Unmarshal(joinRuleEvent.Content(), &joinRuleContent); err != nil { + return false, nil, fmt.Errorf("json.Unmarshal: %w", err) + } + roomIDs := make([]string, 0, len(joinRuleContent.Allow)) + for _, allowed := range joinRuleContent.Allow { + if allowed.Type != gomatrixserverlib.MRoomMembership { + continue + } + roomIDs = append(roomIDs, allowed.RoomID) + } + return joinRuleContent.JoinRule != gomatrixserverlib.Restricted, roomIDs, nil +} + +func (r *Joiner) attemptRestrictedJoinUsingRoomID( + ctx context.Context, + req *rsAPI.PerformJoinRequest, + roomID string, + eb *gomatrixserverlib.EventBuilder, +) error { + roomInfo, err := r.DB.RoomInfo(ctx, roomID) + if err != nil { + return fmt.Errorf("r.DB.RoomInfo: %w", err) + } + powerLevelEvent, err := r.DB.GetStateEvent(ctx, roomID, gomatrixserverlib.MRoomPowerLevels, "") + if err != nil { + return fmt.Errorf("r.DB.GetStateEvent: %w", err) + } + powerLevels, err := powerLevelEvent.PowerLevels() + if err != nil { + return fmt.Errorf("powerLevelEvent.PowerLevels: %w", err) + } + eventNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true) + if err != nil { + return fmt.Errorf("r.DB.GetMembershipEventNIDsForRoom: %w", err) + } + events, err := r.DB.Events(ctx, eventNIDs) + if err != nil { + return fmt.Errorf("r.DB.Events: %w", err) + } + for _, event := range events { + userID := *event.StateKey() + if userID == req.UserID { + continue + } + _, domain, err := gomatrixserverlib.SplitID('@', userID) + if err != nil || domain != r.ServerName { + continue + } + if powerLevels.UserLevel(userID) < powerLevels.Invite { + continue + } + if err := eb.SetContent(map[string]string{ + "membership": gomatrixserverlib.Join, + "join_authorised_via_users_server": userID, + }); err != nil { + return fmt.Errorf("eb.SetContent: %w", err) + } + return nil + } + return fmt.Errorf("no suitable users found in the room") +} + func (r *Joiner) performFederatedJoinRoomByID( ctx context.Context, req *rsAPI.PerformJoinRequest,