[pseudoID] More pseudo ID fixes (#3167)

Signed-off-by: `Sam Wedgwood <sam@wedgwood.dev>`
This commit is contained in:
Sam Wedgwood 2023-08-15 12:37:04 +01:00 committed by GitHub
parent fa6c7ba456
commit 9a12420428
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
24 changed files with 472 additions and 237 deletions

View file

@ -161,12 +161,12 @@ func (r *Admin) PerformAdminEvacuateUser(
return nil, fmt.Errorf("can only evacuate local users using this endpoint")
}
roomIDs, err := r.DB.GetRoomsByMembership(ctx, userID, spec.Join)
roomIDs, err := r.DB.GetRoomsByMembership(ctx, *fullUserID, spec.Join)
if err != nil {
return nil, err
}
inviteRoomIDs, err := r.DB.GetRoomsByMembership(ctx, userID, spec.Invite)
inviteRoomIDs, err := r.DB.GetRoomsByMembership(ctx, *fullUserID, spec.Invite)
if err != nil && err != sql.ErrNoRows {
return nil, err
}

View file

@ -230,6 +230,33 @@ func (r *Queryer) QueryMembershipForSenderID(
senderID spec.SenderID,
response *api.QueryMembershipForUserResponse,
) error {
return r.queryMembershipForOptionalSenderID(ctx, roomID, &senderID, response)
}
// QueryMembershipForUser implements api.RoomserverInternalAPI
func (r *Queryer) QueryMembershipForUser(
ctx context.Context,
request *api.QueryMembershipForUserRequest,
response *api.QueryMembershipForUserResponse,
) error {
roomID, err := spec.NewRoomID(request.RoomID)
if err != nil {
return err
}
senderID, err := r.QuerySenderIDForUser(ctx, *roomID, request.UserID)
if err != nil {
return err
}
return r.queryMembershipForOptionalSenderID(ctx, *roomID, senderID, response)
}
// Query membership information for provided sender ID and room ID
//
// If sender ID is nil, then act as if the provided sender is not a member of the room.
func (r *Queryer) queryMembershipForOptionalSenderID(ctx context.Context, roomID spec.RoomID, senderID *spec.SenderID, response *api.QueryMembershipForUserResponse) error {
response.SenderID = senderID
info, err := r.DB.RoomInfo(ctx, roomID.String())
if err != nil {
return err
@ -240,7 +267,11 @@ func (r *Queryer) QueryMembershipForSenderID(
}
response.RoomExists = true
membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, senderID)
if senderID == nil {
return nil
}
membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, *senderID)
if err != nil {
return err
}
@ -268,70 +299,55 @@ func (r *Queryer) QueryMembershipForSenderID(
return err
}
// QueryMembershipForUser implements api.RoomserverInternalAPI
func (r *Queryer) QueryMembershipForUser(
ctx context.Context,
request *api.QueryMembershipForUserRequest,
response *api.QueryMembershipForUserResponse,
) error {
roomID, err := spec.NewRoomID(request.RoomID)
if err != nil {
return err
}
senderID, err := r.QuerySenderIDForUser(ctx, *roomID, request.UserID)
if err != nil {
return err
}
return r.QueryMembershipForSenderID(ctx, *roomID, *senderID, response)
}
// QueryMembershipAtEvent returns the known memberships at a given event.
// If the state before an event is not known, an empty list will be returned
// for that event instead.
//
// Returned map from eventID to membership event. Events that
// do not have known state will return a nil event, resulting in a "leave" membership
// when calculating history visibility.
func (r *Queryer) QueryMembershipAtEvent(
ctx context.Context,
request *api.QueryMembershipAtEventRequest,
response *api.QueryMembershipAtEventResponse,
) error {
response.Membership = make(map[string]*types.HeaderedEvent)
info, err := r.DB.RoomInfo(ctx, request.RoomID)
roomID spec.RoomID,
eventIDs []string,
senderID spec.SenderID,
) (map[string]*types.HeaderedEvent, error) {
info, err := r.DB.RoomInfo(ctx, roomID.String())
if err != nil {
return fmt.Errorf("unable to get roomInfo: %w", err)
return nil, fmt.Errorf("unable to get roomInfo: %w", err)
}
if info == nil {
return fmt.Errorf("no roomInfo found")
return nil, fmt.Errorf("no roomInfo found")
}
// get the users stateKeyNID
stateKeyNIDs, err := r.DB.EventStateKeyNIDs(ctx, []string{request.UserID})
stateKeyNIDs, err := r.DB.EventStateKeyNIDs(ctx, []string{string(senderID)})
if err != nil {
return fmt.Errorf("unable to get stateKeyNIDs for %s: %w", request.UserID, err)
return nil, fmt.Errorf("unable to get stateKeyNIDs for %s: %w", senderID, err)
}
if _, ok := stateKeyNIDs[request.UserID]; !ok {
return fmt.Errorf("requested stateKeyNID for %s was not found", request.UserID)
if _, ok := stateKeyNIDs[string(senderID)]; !ok {
return nil, fmt.Errorf("requested stateKeyNID for %s was not found", senderID)
}
response.Membership, err = r.DB.GetMembershipForHistoryVisibility(ctx, stateKeyNIDs[request.UserID], info, request.EventIDs...)
eventIDMembershipMap, err := r.DB.GetMembershipForHistoryVisibility(ctx, stateKeyNIDs[string(senderID)], info, eventIDs...)
switch err {
case nil:
return nil
return eventIDMembershipMap, nil
case tables.OptimisationNotSupportedError: // fallthrough, slow way of getting the membership events for each event
default:
return err
return eventIDMembershipMap, err
}
response.Membership = make(map[string]*types.HeaderedEvent)
stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, nil, request.EventIDs, stateKeyNIDs[request.UserID], r)
eventIDMembershipMap = make(map[string]*types.HeaderedEvent)
stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, nil, eventIDs, stateKeyNIDs[string(senderID)], r)
if err != nil {
return fmt.Errorf("unable to get state before event: %w", err)
return eventIDMembershipMap, fmt.Errorf("unable to get state before event: %w", err)
}
// If we only have one or less state entries, we can short circuit the below
// loop and avoid hitting the database
allStateEventNIDs := make(map[types.EventNID]types.StateEntry)
for _, eventID := range request.EventIDs {
for _, eventID := range eventIDs {
stateEntry := stateEntries[eventID]
for _, s := range stateEntry {
allStateEventNIDs[s.EventNID] = s
@ -344,10 +360,10 @@ func (r *Queryer) QueryMembershipAtEvent(
}
var memberships []types.Event
for _, eventID := range request.EventIDs {
for _, eventID := range eventIDs {
stateEntry, ok := stateEntries[eventID]
if !ok || len(stateEntry) == 0 {
response.Membership[eventID] = nil
eventIDMembershipMap[eventID] = nil
continue
}
@ -361,7 +377,7 @@ func (r *Queryer) QueryMembershipAtEvent(
memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info, stateEntry, false)
}
if err != nil {
return fmt.Errorf("unable to get memberships at state: %w", err)
return eventIDMembershipMap, fmt.Errorf("unable to get memberships at state: %w", err)
}
// Iterate over all membership events we got. Given we only query the membership for
@ -369,13 +385,13 @@ func (r *Queryer) QueryMembershipAtEvent(
// a given event, overwrite any other existing membership events.
for i := range memberships {
ev := memberships[i]
if ev.Type() == spec.MRoomMember && ev.StateKeyEquals(request.UserID) {
response.Membership[eventID] = &types.HeaderedEvent{PDU: ev.PDU}
if ev.Type() == spec.MRoomMember && ev.StateKeyEquals(string(senderID)) {
eventIDMembershipMap[eventID] = &types.HeaderedEvent{PDU: ev.PDU}
}
}
}
return nil
return eventIDMembershipMap, nil
}
// QueryMembershipsForRoom implements api.RoomserverInternalAPI
@ -830,13 +846,20 @@ func (r *Queryer) QueryCurrentState(ctx context.Context, req *api.QueryCurrentSt
return nil
}
func (r *Queryer) QueryRoomsForUser(ctx context.Context, req *api.QueryRoomsForUserRequest, res *api.QueryRoomsForUserResponse) error {
roomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, req.WantMembership)
func (r *Queryer) QueryRoomsForUser(ctx context.Context, userID spec.UserID, desiredMembership string) ([]spec.RoomID, error) {
roomIDStrs, err := r.DB.GetRoomsByMembership(ctx, userID, desiredMembership)
if err != nil {
return err
return nil, err
}
res.RoomIDs = roomIDs
return nil
roomIDs := make([]spec.RoomID, len(roomIDStrs))
for i, roomIDStr := range roomIDStrs {
roomID, err := spec.NewRoomID(roomIDStr)
if err != nil {
return nil, err
}
roomIDs[i] = *roomID
}
return roomIDs, nil
}
func (r *Queryer) QueryKnownUsers(ctx context.Context, req *api.QueryKnownUsersRequest, res *api.QueryKnownUsersResponse) error {
@ -879,7 +902,12 @@ func (r *Queryer) QueryLeftUsers(ctx context.Context, req *api.QueryLeftUsersReq
}
func (r *Queryer) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse) error {
roomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, "join")
parsedUserID, err := spec.NewUserID(req.UserID, true)
if err != nil {
return err
}
roomIDs, err := r.DB.GetRoomsByMembership(ctx, *parsedUserID, "join")
if err != nil {
return err
}