Refactor history visibility functions to be a bit neater

This commit is contained in:
Kegan Dougal 2020-09-11 15:24:58 +01:00
parent 913020e4b7
commit bb3400365b
2 changed files with 52 additions and 30 deletions

View file

@ -14,6 +14,7 @@ package auth
import ( import (
"encoding/json" "encoding/json"
"fmt"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -27,7 +28,17 @@ func IsServerAllowed(
serverCurrentlyInRoom bool, serverCurrentlyInRoom bool,
authEvents []gomatrixserverlib.Event, authEvents []gomatrixserverlib.Event,
) bool { ) bool {
historyVisibility := HistoryVisibilityForRoom(authEvents) var hisVisEvent *gomatrixserverlib.Event
for i, ae := range authEvents {
if ae.Type() == gomatrixserverlib.MRoomHistoryVisibility && ae.StateKeyEquals("") {
hisVisEvent = &authEvents[i]
break
}
}
historyVisibility, err := HistoryVisibilityForRoom(hisVisEvent)
if err != nil {
return false
}
// 1. If the history_visibility was set to world_readable, allow. // 1. If the history_visibility was set to world_readable, allow.
if historyVisibility == "world_readable" { if historyVisibility == "world_readable" {
@ -52,30 +63,31 @@ func IsServerAllowed(
return false return false
} }
func HistoryVisibilityForRoom(authEvents []gomatrixserverlib.Event) string { func HistoryVisibilityForRoom(hisVisEvent *gomatrixserverlib.Event) (string, error) {
// https://matrix.org/docs/spec/client_server/r0.6.0#id87 // https://matrix.org/docs/spec/client_server/r0.6.0#id87
// By default if no history_visibility is set, or if the value is not understood, the visibility is assumed to be shared. // By default if no history_visibility is set, or if the value is not understood, the visibility is assumed to be shared.
if hisVisEvent == nil {
return "shared", nil
}
visibility := "shared" visibility := "shared"
knownStates := []string{"invited", "joined", "shared", "world_readable"} knownStates := []string{"invited", "joined", "shared", "world_readable"}
for _, ev := range authEvents { if hisVisEvent.Type() != gomatrixserverlib.MRoomHistoryVisibility {
if ev.Type() != gomatrixserverlib.MRoomHistoryVisibility { return "", fmt.Errorf("HistoryVisibilityForRoom: passed a non history visibility event: %s", hisVisEvent.Type())
continue }
} // TODO: This should be HistoryVisibilityContent to match things like 'MemberContent'. Do this when moving to GMSL
// TODO: This should be HistoryVisibilityContent to match things like 'MemberContent'. Do this when moving to GMSL content := struct {
content := struct { HistoryVisibility string `json:"history_visibility"`
HistoryVisibility string `json:"history_visibility"` }{}
}{} if err := json.Unmarshal(hisVisEvent.Content(), &content); err != nil {
if err := json.Unmarshal(ev.Content(), &content); err != nil { return visibility, nil // value is not understood
break // value is not understood }
} for _, s := range knownStates {
for _, s := range knownStates { if s == content.HistoryVisibility {
if s == content.HistoryVisibility { visibility = s
visibility = s break
break
}
} }
} }
return visibility return visibility, nil
} }
func IsAnyUserOnServerWithMembership(serverName gomatrixserverlib.ServerName, authEvents []gomatrixserverlib.Event, wantMembership string) bool { func IsAnyUserOnServerWithMembership(serverName gomatrixserverlib.ServerName, authEvents []gomatrixserverlib.Event, wantMembership string) bool {

View file

@ -480,27 +480,37 @@ func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion,
// TODO: Long term we probably want a history_visibility table which stores eventNID | visibility_enum so we can just // TODO: Long term we probably want a history_visibility table which stores eventNID | visibility_enum so we can just
// pull all events and then filter by that table. // pull all events and then filter by that table.
func joinEventsFromHistoryVisibility( func joinEventsFromHistoryVisibility(
ctx context.Context, db storage.Database, roomID string, stateEntries []types.StateEntry) ([]types.Event, error) { ctx context.Context, db storage.Database, roomID string, stateEntries []types.StateEntry,
) ([]types.Event, error) {
var eventNIDs []types.EventNID // Extract the history visibility event
var historyVisibilityNID types.EventNID
for _, entry := range stateEntries { for _, entry := range stateEntries {
// Filter the events to retrieve to only keep the membership events
if entry.EventTypeNID == types.MRoomHistoryVisibilityNID && entry.EventStateKeyNID == types.EmptyStateKeyNID { if entry.EventTypeNID == types.MRoomHistoryVisibilityNID && entry.EventStateKeyNID == types.EmptyStateKeyNID {
eventNIDs = append(eventNIDs, entry.EventNID) historyVisibilityNID = entry.EventNID
break break
} }
} }
if historyVisibilityNID == 0 {
// Get all of the events in this state return nil, fmt.Errorf("no history visibility event for room %s", roomID)
stateEvents, err := db.Events(ctx, eventNIDs) }
stateEvents, err := db.Events(ctx, []types.EventNID{historyVisibilityNID})
if err != nil { if err != nil {
return nil, err return nil, err
} }
events := make([]gomatrixserverlib.Event, len(stateEvents)) if len(stateEvents) != 1 {
for i := range stateEvents { return nil, fmt.Errorf("failed to load history visibility event nid %d", historyVisibilityNID)
events[i] = stateEvents[i].Event }
var hisVisEvent *gomatrixserverlib.Event
for i := range stateEvents {
if stateEvents[i].Type() == gomatrixserverlib.MRoomHistoryVisibility && stateEvents[i].StateKeyEquals("") {
hisVisEvent = &stateEvents[i].Event
}
}
visibility, err := auth.HistoryVisibilityForRoom(hisVisEvent)
if err != nil {
return nil, err
} }
visibility := auth.HistoryVisibilityForRoom(events)
if visibility != "shared" { if visibility != "shared" {
logrus.Infof("ServersAtEvent history visibility not shared: %s", visibility) logrus.Infof("ServersAtEvent history visibility not shared: %s", visibility)
return nil, nil return nil, nil