From bf00a8ef7f463b37756fee78b8dc84dbe9d91f4f Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 15 Jul 2021 15:48:03 +0100 Subject: [PATCH] Wire in new code --- federationapi/routing/state.go | 53 ++++++++++++++++++++++++++++++---- 1 file changed, 47 insertions(+), 6 deletions(-) diff --git a/federationapi/routing/state.go b/federationapi/routing/state.go index 128df618..38d82e8f 100644 --- a/federationapi/routing/state.go +++ b/federationapi/routing/state.go @@ -55,17 +55,14 @@ func GetStateIDs( return *err } - state, err := getState(ctx, request, rsAPI, roomID, eventID) + state, err := getStateIDs(ctx, request, rsAPI, roomID, eventID) if err != nil { return *err } - stateEventIDs := getIDsFromEvent(state.StateEvents) - authEventIDs := getIDsFromEvent(state.AuthEvents) - return util.JSONResponse{Code: http.StatusOK, JSON: gomatrixserverlib.RespStateIDs{ - StateEventIDs: stateEventIDs, - AuthEventIDs: authEventIDs, + StateEventIDs: state.StateEventIDs, + AuthEventIDs: state.AuthEventIDs, }, } } @@ -136,6 +133,50 @@ func getState( }, nil } +func getStateIDs( + ctx context.Context, + request *gomatrixserverlib.FederationRequest, + rsAPI api.RoomserverInternalAPI, + roomID string, + eventID string, +) (*gomatrixserverlib.RespStateIDs, *util.JSONResponse) { + event, resErr := fetchEvent(ctx, rsAPI, eventID) + if resErr != nil { + return nil, resErr + } + + if event.RoomID() != roomID { + return nil, &util.JSONResponse{Code: http.StatusNotFound, JSON: jsonerror.NotFound("event does not belong to this room")} + } + resErr = allowedToSeeEvent(ctx, request.Origin(), rsAPI, eventID) + if resErr != nil { + return nil, resErr + } + + var response api.QueryStateAndAuthChainIDsResponse + err := rsAPI.QueryStateAndAuthChainIDs( + ctx, + &api.QueryStateAndAuthChainIDsRequest{ + RoomID: roomID, + PrevEventIDs: []string{eventID}, + }, + &response, + ) + if err != nil { + resErr := util.ErrorResponse(err) + return nil, &resErr + } + + if !response.RoomExists { + return nil, &util.JSONResponse{Code: http.StatusNotFound, JSON: nil} + } + + return &gomatrixserverlib.RespStateIDs{ + StateEventIDs: response.StateEvents, + AuthEventIDs: response.AuthChainEvents, + }, nil +} + func getIDsFromEvent(events []*gomatrixserverlib.Event) []string { IDs := make([]string, len(events)) for i := range events {