mirror of
https://github.com/hoernschen/dendrite.git
synced 2025-07-31 13:22:46 +00:00
[pseudoID] More pseudo ID fixes (#3167)
Signed-off-by: `Sam Wedgwood <sam@wedgwood.dev>`
This commit is contained in:
parent
fa6c7ba456
commit
9a12420428
24 changed files with 472 additions and 237 deletions
|
@ -161,12 +161,12 @@ func (r *Admin) PerformAdminEvacuateUser(
|
|||
return nil, fmt.Errorf("can only evacuate local users using this endpoint")
|
||||
}
|
||||
|
||||
roomIDs, err := r.DB.GetRoomsByMembership(ctx, userID, spec.Join)
|
||||
roomIDs, err := r.DB.GetRoomsByMembership(ctx, *fullUserID, spec.Join)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
inviteRoomIDs, err := r.DB.GetRoomsByMembership(ctx, userID, spec.Invite)
|
||||
inviteRoomIDs, err := r.DB.GetRoomsByMembership(ctx, *fullUserID, spec.Invite)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -230,6 +230,33 @@ func (r *Queryer) QueryMembershipForSenderID(
|
|||
senderID spec.SenderID,
|
||||
response *api.QueryMembershipForUserResponse,
|
||||
) error {
|
||||
return r.queryMembershipForOptionalSenderID(ctx, roomID, &senderID, response)
|
||||
}
|
||||
|
||||
// QueryMembershipForUser implements api.RoomserverInternalAPI
|
||||
func (r *Queryer) QueryMembershipForUser(
|
||||
ctx context.Context,
|
||||
request *api.QueryMembershipForUserRequest,
|
||||
response *api.QueryMembershipForUserResponse,
|
||||
) error {
|
||||
roomID, err := spec.NewRoomID(request.RoomID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
senderID, err := r.QuerySenderIDForUser(ctx, *roomID, request.UserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return r.queryMembershipForOptionalSenderID(ctx, *roomID, senderID, response)
|
||||
}
|
||||
|
||||
// Query membership information for provided sender ID and room ID
|
||||
//
|
||||
// If sender ID is nil, then act as if the provided sender is not a member of the room.
|
||||
func (r *Queryer) queryMembershipForOptionalSenderID(ctx context.Context, roomID spec.RoomID, senderID *spec.SenderID, response *api.QueryMembershipForUserResponse) error {
|
||||
response.SenderID = senderID
|
||||
|
||||
info, err := r.DB.RoomInfo(ctx, roomID.String())
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -240,7 +267,11 @@ func (r *Queryer) QueryMembershipForSenderID(
|
|||
}
|
||||
response.RoomExists = true
|
||||
|
||||
membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, senderID)
|
||||
if senderID == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, *senderID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -268,70 +299,55 @@ func (r *Queryer) QueryMembershipForSenderID(
|
|||
return err
|
||||
}
|
||||
|
||||
// QueryMembershipForUser implements api.RoomserverInternalAPI
|
||||
func (r *Queryer) QueryMembershipForUser(
|
||||
ctx context.Context,
|
||||
request *api.QueryMembershipForUserRequest,
|
||||
response *api.QueryMembershipForUserResponse,
|
||||
) error {
|
||||
roomID, err := spec.NewRoomID(request.RoomID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
senderID, err := r.QuerySenderIDForUser(ctx, *roomID, request.UserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return r.QueryMembershipForSenderID(ctx, *roomID, *senderID, response)
|
||||
}
|
||||
|
||||
// QueryMembershipAtEvent returns the known memberships at a given event.
|
||||
// If the state before an event is not known, an empty list will be returned
|
||||
// for that event instead.
|
||||
//
|
||||
// Returned map from eventID to membership event. Events that
|
||||
// do not have known state will return a nil event, resulting in a "leave" membership
|
||||
// when calculating history visibility.
|
||||
func (r *Queryer) QueryMembershipAtEvent(
|
||||
ctx context.Context,
|
||||
request *api.QueryMembershipAtEventRequest,
|
||||
response *api.QueryMembershipAtEventResponse,
|
||||
) error {
|
||||
response.Membership = make(map[string]*types.HeaderedEvent)
|
||||
|
||||
info, err := r.DB.RoomInfo(ctx, request.RoomID)
|
||||
roomID spec.RoomID,
|
||||
eventIDs []string,
|
||||
senderID spec.SenderID,
|
||||
) (map[string]*types.HeaderedEvent, error) {
|
||||
info, err := r.DB.RoomInfo(ctx, roomID.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to get roomInfo: %w", err)
|
||||
return nil, fmt.Errorf("unable to get roomInfo: %w", err)
|
||||
}
|
||||
if info == nil {
|
||||
return fmt.Errorf("no roomInfo found")
|
||||
return nil, fmt.Errorf("no roomInfo found")
|
||||
}
|
||||
|
||||
// get the users stateKeyNID
|
||||
stateKeyNIDs, err := r.DB.EventStateKeyNIDs(ctx, []string{request.UserID})
|
||||
stateKeyNIDs, err := r.DB.EventStateKeyNIDs(ctx, []string{string(senderID)})
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to get stateKeyNIDs for %s: %w", request.UserID, err)
|
||||
return nil, fmt.Errorf("unable to get stateKeyNIDs for %s: %w", senderID, err)
|
||||
}
|
||||
if _, ok := stateKeyNIDs[request.UserID]; !ok {
|
||||
return fmt.Errorf("requested stateKeyNID for %s was not found", request.UserID)
|
||||
if _, ok := stateKeyNIDs[string(senderID)]; !ok {
|
||||
return nil, fmt.Errorf("requested stateKeyNID for %s was not found", senderID)
|
||||
}
|
||||
|
||||
response.Membership, err = r.DB.GetMembershipForHistoryVisibility(ctx, stateKeyNIDs[request.UserID], info, request.EventIDs...)
|
||||
eventIDMembershipMap, err := r.DB.GetMembershipForHistoryVisibility(ctx, stateKeyNIDs[string(senderID)], info, eventIDs...)
|
||||
switch err {
|
||||
case nil:
|
||||
return nil
|
||||
return eventIDMembershipMap, nil
|
||||
case tables.OptimisationNotSupportedError: // fallthrough, slow way of getting the membership events for each event
|
||||
default:
|
||||
return err
|
||||
return eventIDMembershipMap, err
|
||||
}
|
||||
|
||||
response.Membership = make(map[string]*types.HeaderedEvent)
|
||||
stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, nil, request.EventIDs, stateKeyNIDs[request.UserID], r)
|
||||
eventIDMembershipMap = make(map[string]*types.HeaderedEvent)
|
||||
stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, nil, eventIDs, stateKeyNIDs[string(senderID)], r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to get state before event: %w", err)
|
||||
return eventIDMembershipMap, fmt.Errorf("unable to get state before event: %w", err)
|
||||
}
|
||||
|
||||
// If we only have one or less state entries, we can short circuit the below
|
||||
// loop and avoid hitting the database
|
||||
allStateEventNIDs := make(map[types.EventNID]types.StateEntry)
|
||||
for _, eventID := range request.EventIDs {
|
||||
for _, eventID := range eventIDs {
|
||||
stateEntry := stateEntries[eventID]
|
||||
for _, s := range stateEntry {
|
||||
allStateEventNIDs[s.EventNID] = s
|
||||
|
@ -344,10 +360,10 @@ func (r *Queryer) QueryMembershipAtEvent(
|
|||
}
|
||||
|
||||
var memberships []types.Event
|
||||
for _, eventID := range request.EventIDs {
|
||||
for _, eventID := range eventIDs {
|
||||
stateEntry, ok := stateEntries[eventID]
|
||||
if !ok || len(stateEntry) == 0 {
|
||||
response.Membership[eventID] = nil
|
||||
eventIDMembershipMap[eventID] = nil
|
||||
continue
|
||||
}
|
||||
|
||||
|
@ -361,7 +377,7 @@ func (r *Queryer) QueryMembershipAtEvent(
|
|||
memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info, stateEntry, false)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to get memberships at state: %w", err)
|
||||
return eventIDMembershipMap, fmt.Errorf("unable to get memberships at state: %w", err)
|
||||
}
|
||||
|
||||
// Iterate over all membership events we got. Given we only query the membership for
|
||||
|
@ -369,13 +385,13 @@ func (r *Queryer) QueryMembershipAtEvent(
|
|||
// a given event, overwrite any other existing membership events.
|
||||
for i := range memberships {
|
||||
ev := memberships[i]
|
||||
if ev.Type() == spec.MRoomMember && ev.StateKeyEquals(request.UserID) {
|
||||
response.Membership[eventID] = &types.HeaderedEvent{PDU: ev.PDU}
|
||||
if ev.Type() == spec.MRoomMember && ev.StateKeyEquals(string(senderID)) {
|
||||
eventIDMembershipMap[eventID] = &types.HeaderedEvent{PDU: ev.PDU}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return eventIDMembershipMap, nil
|
||||
}
|
||||
|
||||
// QueryMembershipsForRoom implements api.RoomserverInternalAPI
|
||||
|
@ -830,13 +846,20 @@ func (r *Queryer) QueryCurrentState(ctx context.Context, req *api.QueryCurrentSt
|
|||
return nil
|
||||
}
|
||||
|
||||
func (r *Queryer) QueryRoomsForUser(ctx context.Context, req *api.QueryRoomsForUserRequest, res *api.QueryRoomsForUserResponse) error {
|
||||
roomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, req.WantMembership)
|
||||
func (r *Queryer) QueryRoomsForUser(ctx context.Context, userID spec.UserID, desiredMembership string) ([]spec.RoomID, error) {
|
||||
roomIDStrs, err := r.DB.GetRoomsByMembership(ctx, userID, desiredMembership)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
res.RoomIDs = roomIDs
|
||||
return nil
|
||||
roomIDs := make([]spec.RoomID, len(roomIDStrs))
|
||||
for i, roomIDStr := range roomIDStrs {
|
||||
roomID, err := spec.NewRoomID(roomIDStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
roomIDs[i] = *roomID
|
||||
}
|
||||
return roomIDs, nil
|
||||
}
|
||||
|
||||
func (r *Queryer) QueryKnownUsers(ctx context.Context, req *api.QueryKnownUsersRequest, res *api.QueryKnownUsersResponse) error {
|
||||
|
@ -879,7 +902,12 @@ func (r *Queryer) QueryLeftUsers(ctx context.Context, req *api.QueryLeftUsersReq
|
|||
}
|
||||
|
||||
func (r *Queryer) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse) error {
|
||||
roomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, "join")
|
||||
parsedUserID, err := spec.NewUserID(req.UserID, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
roomIDs, err := r.DB.GetRoomsByMembership(ctx, *parsedUserID, "join")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue