mirror of
https://github.com/hoernschen/dendrite.git
synced 2025-07-29 12:42:46 +00:00
Merge SenderID & Per Room User Key work (#3109)
This commit is contained in:
parent
7a2e325d10
commit
e4665979bf
75 changed files with 801 additions and 379 deletions
|
@ -74,6 +74,10 @@ func (r *Admin) PerformAdminEvacuateRoom(
|
|||
if err = r.Queryer.QueryLatestEventsAndState(ctx, latestReq, latestRes); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
validRoomID, err := spec.NewRoomID(roomID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
prevEvents := latestRes.LatestEvents
|
||||
var senderDomain spec.ServerName
|
||||
|
@ -100,7 +104,7 @@ func (r *Admin) PerformAdminEvacuateRoom(
|
|||
PrevEvents: prevEvents,
|
||||
}
|
||||
|
||||
userID, err := r.Queryer.QueryUserIDForSender(ctx, roomID, spec.SenderID(fledglingEvent.SenderID))
|
||||
userID, err := r.Queryer.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(fledglingEvent.SenderID))
|
||||
if err != nil || userID == nil {
|
||||
continue
|
||||
}
|
||||
|
@ -264,16 +268,16 @@ func (r *Admin) PerformAdminDownloadState(
|
|||
return fmt.Errorf("r.Inputer.FSAPI.LookupState (%q): %s", fwdExtremity, err)
|
||||
}
|
||||
for _, authEvent := range state.GetAuthEvents().UntrustedEvents(roomInfo.RoomVersion) {
|
||||
if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
return r.DB.GetUserIDForSender(ctx, roomID, senderID)
|
||||
if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
|
||||
}); err != nil {
|
||||
continue
|
||||
}
|
||||
authEventMap[authEvent.EventID()] = authEvent
|
||||
}
|
||||
for _, stateEvent := range state.GetStateEvents().UntrustedEvents(roomInfo.RoomVersion) {
|
||||
if err = gomatrixserverlib.VerifyEventSignatures(ctx, stateEvent, r.Inputer.KeyRing, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
return r.DB.GetUserIDForSender(ctx, roomID, senderID)
|
||||
if err = gomatrixserverlib.VerifyEventSignatures(ctx, stateEvent, r.Inputer.KeyRing, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
|
||||
}); err != nil {
|
||||
continue
|
||||
}
|
||||
|
@ -293,7 +297,11 @@ func (r *Admin) PerformAdminDownloadState(
|
|||
stateIDs = append(stateIDs, stateEvent.EventID())
|
||||
}
|
||||
|
||||
senderID, err := r.Queryer.QuerySenderIDForUser(ctx, roomID, *fullUserID)
|
||||
validRoomID, err := spec.NewRoomID(roomID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
senderID, err := r.Queryer.QuerySenderIDForUser(ctx, *validRoomID, *fullUserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -42,6 +42,7 @@ type Backfiller struct {
|
|||
DB storage.Database
|
||||
FSAPI federationAPI.RoomserverFederationAPI
|
||||
KeyRing gomatrixserverlib.JSONVerifier
|
||||
Querier api.QuerySenderIDAPI
|
||||
|
||||
// The servers which should be preferred above other servers when backfilling
|
||||
PreferServers []spec.ServerName
|
||||
|
@ -79,7 +80,7 @@ func (r *Backfiller) PerformBackfill(
|
|||
}
|
||||
|
||||
// Scan the event tree for events to send back.
|
||||
resultNIDs, redactEventIDs, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName)
|
||||
resultNIDs, redactEventIDs, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName, r.Querier)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -113,7 +114,7 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform
|
|||
if info == nil || info.IsStub() {
|
||||
return fmt.Errorf("backfillViaFederation: missing room info for room %s", req.RoomID)
|
||||
}
|
||||
requester := newBackfillRequester(r.DB, r.FSAPI, req.VirtualHost, r.IsLocalServerName, req.BackwardsExtremities, r.PreferServers)
|
||||
requester := newBackfillRequester(r.DB, r.FSAPI, r.Querier, req.VirtualHost, r.IsLocalServerName, req.BackwardsExtremities, r.PreferServers)
|
||||
// Request 100 items regardless of what the query asks for.
|
||||
// We don't want to go much higher than this.
|
||||
// We can't honour exactly the limit as some sytests rely on requesting more for tests to pass
|
||||
|
@ -121,8 +122,8 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform
|
|||
// Specifically the test "Outbound federation can backfill events"
|
||||
events, err := gomatrixserverlib.RequestBackfill(
|
||||
ctx, req.VirtualHost, requester,
|
||||
r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
return r.DB.GetUserIDForSender(ctx, roomID, senderID)
|
||||
r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
return r.Querier.QueryUserIDForSender(ctx, roomID, senderID)
|
||||
},
|
||||
)
|
||||
// Only return an error if we really couldn't get any events.
|
||||
|
@ -135,7 +136,7 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform
|
|||
logrus.WithError(err).WithField("room_id", req.RoomID).Infof("backfilled %d events", len(events))
|
||||
|
||||
// persist these new events - auth checks have already been done
|
||||
roomNID, backfilledEventMap := persistEvents(ctx, r.DB, events)
|
||||
roomNID, backfilledEventMap := persistEvents(ctx, r.DB, r.Querier, events)
|
||||
|
||||
for _, ev := range backfilledEventMap {
|
||||
// now add state for these events
|
||||
|
@ -212,8 +213,8 @@ func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gom
|
|||
continue
|
||||
}
|
||||
loader := gomatrixserverlib.NewEventsLoader(roomVer, r.KeyRing, backfillRequester, backfillRequester.ProvideEvents, false)
|
||||
result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
return r.DB.GetUserIDForSender(ctx, roomID, senderID)
|
||||
result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
return r.Querier.QueryUserIDForSender(ctx, roomID, senderID)
|
||||
})
|
||||
if err != nil {
|
||||
logger.WithError(err).Warn("failed to load and verify event")
|
||||
|
@ -246,13 +247,14 @@ func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gom
|
|||
}
|
||||
}
|
||||
util.GetLogger(ctx).Infof("Persisting %d new events", len(newEvents))
|
||||
persistEvents(ctx, r.DB, newEvents)
|
||||
persistEvents(ctx, r.DB, r.Querier, newEvents)
|
||||
}
|
||||
|
||||
// backfillRequester implements gomatrixserverlib.BackfillRequester
|
||||
type backfillRequester struct {
|
||||
db storage.Database
|
||||
fsAPI federationAPI.RoomserverFederationAPI
|
||||
querier api.QuerySenderIDAPI
|
||||
virtualHost spec.ServerName
|
||||
isLocalServerName func(spec.ServerName) bool
|
||||
preferServer map[spec.ServerName]bool
|
||||
|
@ -268,6 +270,7 @@ type backfillRequester struct {
|
|||
|
||||
func newBackfillRequester(
|
||||
db storage.Database, fsAPI federationAPI.RoomserverFederationAPI,
|
||||
querier api.QuerySenderIDAPI,
|
||||
virtualHost spec.ServerName,
|
||||
isLocalServerName func(spec.ServerName) bool,
|
||||
bwExtrems map[string][]string, preferServers []spec.ServerName,
|
||||
|
@ -279,6 +282,7 @@ func newBackfillRequester(
|
|||
return &backfillRequester{
|
||||
db: db,
|
||||
fsAPI: fsAPI,
|
||||
querier: querier,
|
||||
virtualHost: virtualHost,
|
||||
isLocalServerName: isLocalServerName,
|
||||
eventIDToBeforeStateIDs: make(map[string][]string),
|
||||
|
@ -460,14 +464,14 @@ FindSuccessor:
|
|||
return nil
|
||||
}
|
||||
|
||||
stateEntries, err := helpers.StateBeforeEvent(ctx, b.db, info, NIDs[eventID].EventNID)
|
||||
stateEntries, err := helpers.StateBeforeEvent(ctx, b.db, info, NIDs[eventID].EventNID, b.querier)
|
||||
if err != nil {
|
||||
logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to load state before event")
|
||||
return nil
|
||||
}
|
||||
|
||||
// possibly return all joined servers depending on history visiblity
|
||||
memberEventsFromVis, visibility, err := joinEventsFromHistoryVisibility(ctx, b.db, info, stateEntries, b.virtualHost)
|
||||
memberEventsFromVis, visibility, err := joinEventsFromHistoryVisibility(ctx, b.db, b.querier, info, stateEntries, b.virtualHost)
|
||||
b.historyVisiblity = visibility
|
||||
if err != nil {
|
||||
logrus.WithError(err).Error("ServersAtEvent: failed calculate servers from history visibility rules")
|
||||
|
@ -488,7 +492,11 @@ FindSuccessor:
|
|||
// Store the server names in a temporary map to avoid duplicates.
|
||||
serverSet := make(map[spec.ServerName]bool)
|
||||
for _, event := range memberEvents {
|
||||
if sender, err := b.db.GetUserIDForSender(ctx, event.RoomID(), event.SenderID()); err == nil {
|
||||
validRoomID, err := spec.NewRoomID(event.RoomID())
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if sender, err := b.querier.QueryUserIDForSender(ctx, *validRoomID, event.SenderID()); err == nil {
|
||||
serverSet[sender.Domain()] = true
|
||||
}
|
||||
}
|
||||
|
@ -554,7 +562,7 @@ func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion,
|
|||
// TODO: Long term we probably want a history_visibility table which stores eventNID | visibility_enum so we can just
|
||||
// pull all events and then filter by that table.
|
||||
func joinEventsFromHistoryVisibility(
|
||||
ctx context.Context, db storage.RoomDatabase, roomInfo *types.RoomInfo, stateEntries []types.StateEntry,
|
||||
ctx context.Context, db storage.RoomDatabase, querier api.QuerySenderIDAPI, roomInfo *types.RoomInfo, stateEntries []types.StateEntry,
|
||||
thisServer spec.ServerName) ([]types.Event, gomatrixserverlib.HistoryVisibility, error) {
|
||||
|
||||
var eventNIDs []types.EventNID
|
||||
|
@ -582,7 +590,7 @@ func joinEventsFromHistoryVisibility(
|
|||
}
|
||||
|
||||
// Can we see events in the room?
|
||||
canSeeEvents := auth.IsServerAllowed(ctx, db, thisServer, true, events)
|
||||
canSeeEvents := auth.IsServerAllowed(ctx, querier, thisServer, true, events)
|
||||
visibility := auth.HistoryVisibilityForRoom(events)
|
||||
if !canSeeEvents {
|
||||
logrus.Infof("ServersAtEvent history not visible to us: %s", visibility)
|
||||
|
@ -597,7 +605,7 @@ func joinEventsFromHistoryVisibility(
|
|||
return evs, visibility, err
|
||||
}
|
||||
|
||||
func persistEvents(ctx context.Context, db storage.Database, events []gomatrixserverlib.PDU) (types.RoomNID, map[string]types.Event) {
|
||||
func persistEvents(ctx context.Context, db storage.Database, querier api.QuerySenderIDAPI, events []gomatrixserverlib.PDU) (types.RoomNID, map[string]types.Event) {
|
||||
var roomNID types.RoomNID
|
||||
var eventNID types.EventNID
|
||||
backfilledEventMap := make(map[string]types.Event)
|
||||
|
@ -639,7 +647,7 @@ func persistEvents(ctx context.Context, db storage.Database, events []gomatrixse
|
|||
continue
|
||||
}
|
||||
|
||||
resolver := state.NewStateResolution(db, roomInfo)
|
||||
resolver := state.NewStateResolution(db, roomInfo, querier)
|
||||
|
||||
_, redactedEvent, err := db.MaybeRedactEvent(ctx, roomInfo, eventNID, ev, &resolver)
|
||||
if err != nil {
|
||||
|
|
|
@ -63,13 +63,20 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
|
|||
}
|
||||
}
|
||||
}
|
||||
senderID, err := c.DB.GetSenderIDForUser(ctx, roomID.String(), userID)
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("Failed getting senderID for user")
|
||||
return "", &util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
JSON: spec.InternalServerError{},
|
||||
var senderID spec.SenderID
|
||||
if createRequest.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs {
|
||||
// create user room key if needed
|
||||
key, keyErr := c.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, userID, roomID)
|
||||
if keyErr != nil {
|
||||
util.GetLogger(ctx).WithError(keyErr).Error("GetOrCreateUserRoomPrivateKey failed")
|
||||
return "", &util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
JSON: spec.InternalServerError{},
|
||||
}
|
||||
}
|
||||
senderID = spec.SenderID(spec.Base64Bytes(key).Encode())
|
||||
} else {
|
||||
senderID = spec.SenderID(userID.String())
|
||||
}
|
||||
createContent["creator"] = senderID
|
||||
createContent["room_version"] = createRequest.RoomVersion
|
||||
|
@ -323,8 +330,8 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
|
|||
}
|
||||
}
|
||||
|
||||
if err = gomatrixserverlib.Allowed(ev, &authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
return c.DB.GetUserIDForSender(ctx, roomID, senderID)
|
||||
if err = gomatrixserverlib.Allowed(ev, &authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
return c.RSAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
||||
}); err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.Allowed failed")
|
||||
return "", &util.JSONResponse{
|
||||
|
@ -364,18 +371,6 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
|
|||
}
|
||||
}
|
||||
|
||||
// create user room key if needed
|
||||
if createRequest.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs {
|
||||
_, err = c.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, userID, roomID)
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("GetOrCreateUserRoomPrivateKey failed")
|
||||
return "", &util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
JSON: spec.InternalServerError{},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// send the remaining events
|
||||
if err = api.SendInputRoomEvents(ctx, c.RSAPI, userID.Domain(), inputs[1:], false); err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed")
|
||||
|
@ -455,7 +450,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
|
|||
JSON: spec.InternalServerError{},
|
||||
}
|
||||
}
|
||||
inviteeSenderID, queryErr := c.RSAPI.QuerySenderIDForUser(ctx, roomID.String(), *inviteeUserID)
|
||||
inviteeSenderID, queryErr := c.RSAPI.QuerySenderIDForUser(ctx, roomID, *inviteeUserID)
|
||||
if queryErr != nil {
|
||||
util.GetLogger(ctx).WithError(queryErr).Error("rsapi.QuerySenderIDForUser failed")
|
||||
return "", &util.JSONResponse{
|
||||
|
|
|
@ -79,7 +79,7 @@ func (r *InboundPeeker) PerformInboundPeek(
|
|||
response.LatestEvent = &types.HeaderedEvent{PDU: sortedLatestEvents[0]}
|
||||
|
||||
// XXX: do we actually need to do a state resolution here?
|
||||
roomState := state.NewStateResolution(r.DB, info)
|
||||
roomState := state.NewStateResolution(r.DB, info, r.Inputer.Queryer)
|
||||
|
||||
var stateEntries []types.StateEntry
|
||||
stateEntries, err = roomState.LoadStateAtSnapshot(
|
||||
|
|
|
@ -34,6 +34,7 @@ import (
|
|||
|
||||
type QueryState struct {
|
||||
storage.Database
|
||||
querier api.QuerySenderIDAPI
|
||||
}
|
||||
|
||||
func (q *QueryState) GetAuthEvents(ctx context.Context, event gomatrixserverlib.PDU) (gomatrixserverlib.AuthEventProvider, error) {
|
||||
|
@ -46,7 +47,7 @@ func (q *QueryState) GetState(ctx context.Context, roomID spec.RoomID, stateWant
|
|||
return nil, fmt.Errorf("failed to load RoomInfo: %w", err)
|
||||
}
|
||||
if info != nil {
|
||||
roomState := state.NewStateResolution(q.Database, info)
|
||||
roomState := state.NewStateResolution(q.Database, info, q.querier)
|
||||
stateEntries, err := roomState.LoadStateAtSnapshotForStringTuples(
|
||||
ctx, info.StateSnapshotNID(), stateWanted,
|
||||
)
|
||||
|
@ -98,7 +99,11 @@ func (r *Inviter) ProcessInviteMembership(
|
|||
var outputUpdates []api.OutputEvent
|
||||
var updater *shared.MembershipUpdater
|
||||
|
||||
userID, err := r.RSAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), spec.SenderID(*inviteEvent.StateKey()))
|
||||
validRoomID, err := spec.NewRoomID(inviteEvent.RoomID())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userID, err := r.RSAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*inviteEvent.StateKey()))
|
||||
if err != nil {
|
||||
return nil, api.ErrInvalidID{Err: fmt.Errorf("the user ID %s is invalid", *inviteEvent.StateKey())}
|
||||
}
|
||||
|
@ -126,7 +131,12 @@ func (r *Inviter) PerformInvite(
|
|||
) error {
|
||||
event := req.Event
|
||||
|
||||
sender, err := r.DB.GetUserIDForSender(ctx, event.RoomID(), event.SenderID())
|
||||
validRoomID, err := spec.NewRoomID(event.RoomID())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sender, err := r.RSAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID())
|
||||
if err != nil {
|
||||
return spec.InvalidParam("The sender user ID is invalid")
|
||||
}
|
||||
|
@ -137,18 +147,13 @@ func (r *Inviter) PerformInvite(
|
|||
if event.StateKey() == nil || *event.StateKey() == "" {
|
||||
return fmt.Errorf("invite must be a state event")
|
||||
}
|
||||
invitedUser, err := r.RSAPI.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey()))
|
||||
invitedUser, err := r.RSAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*event.StateKey()))
|
||||
if err != nil || invitedUser == nil {
|
||||
return spec.InvalidParam("Could not find the matching senderID for this user")
|
||||
}
|
||||
isTargetLocal := r.Cfg.Matrix.IsLocalServerName(invitedUser.Domain())
|
||||
|
||||
validRoomID, err := spec.NewRoomID(event.RoomID())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
invitedSenderID, err := r.RSAPI.QuerySenderIDForUser(ctx, event.RoomID(), *invitedUser)
|
||||
invitedSenderID, err := r.RSAPI.QuerySenderIDForUser(ctx, *validRoomID, *invitedUser)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed looking up senderID for invited user")
|
||||
}
|
||||
|
@ -161,9 +166,9 @@ func (r *Inviter) PerformInvite(
|
|||
IsTargetLocal: isTargetLocal,
|
||||
StrippedState: req.InviteRoomState,
|
||||
MembershipQuerier: &api.MembershipQuerier{Roomserver: r.RSAPI},
|
||||
StateQuerier: &QueryState{r.DB},
|
||||
UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
return r.DB.GetUserIDForSender(ctx, roomID, senderID)
|
||||
StateQuerier: &QueryState{r.DB, r.RSAPI},
|
||||
UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
return r.RSAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
||||
},
|
||||
}
|
||||
inviteEvent, err := gomatrixserverlib.PerformInvite(ctx, input, r.FSAPI)
|
||||
|
|
|
@ -25,6 +25,7 @@ import (
|
|||
"github.com/getsentry/sentry-go"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||
"github.com/matrix-org/util"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
|
||||
|
@ -174,44 +175,6 @@ func (r *Joiner) performJoinRoomByID(
|
|||
req.ServerNames = append(req.ServerNames, roomID.Domain())
|
||||
}
|
||||
|
||||
// Prepare the template for the join event.
|
||||
userID, err := spec.NewUserID(req.UserID, true)
|
||||
if err != nil {
|
||||
return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("user ID %q is invalid: %w", req.UserID, err)}
|
||||
}
|
||||
senderID, err := r.RSAPI.QuerySenderIDForUser(ctx, req.RoomIDOrAlias, *userID)
|
||||
if err != nil {
|
||||
return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("user ID %q is invalid: %w", req.UserID, err)}
|
||||
}
|
||||
senderIDString := string(senderID)
|
||||
userDomain := userID.Domain()
|
||||
proto := gomatrixserverlib.ProtoEvent{
|
||||
Type: spec.MRoomMember,
|
||||
SenderID: senderIDString,
|
||||
StateKey: &senderIDString,
|
||||
RoomID: req.RoomIDOrAlias,
|
||||
Redacts: "",
|
||||
}
|
||||
if err = proto.SetUnsigned(struct{}{}); err != nil {
|
||||
return "", "", fmt.Errorf("eb.SetUnsigned: %w", err)
|
||||
}
|
||||
|
||||
// It is possible for the request to include some "content" for the
|
||||
// event. We'll always overwrite the "membership" key, but the rest,
|
||||
// like "display_name" or "avatar_url", will be kept if supplied.
|
||||
if req.Content == nil {
|
||||
req.Content = map[string]interface{}{}
|
||||
}
|
||||
req.Content["membership"] = spec.Join
|
||||
if authorisedVia, aerr := r.populateAuthorisedViaUserForRestrictedJoin(ctx, req, senderID); aerr != nil {
|
||||
return "", "", aerr
|
||||
} else if authorisedVia != "" {
|
||||
req.Content["join_authorised_via_users_server"] = authorisedVia
|
||||
}
|
||||
if err = proto.SetContent(req.Content); err != nil {
|
||||
return "", "", fmt.Errorf("eb.SetContent: %w", err)
|
||||
}
|
||||
|
||||
// Force a federated join if we aren't in the room and we've been
|
||||
// given some server names to try joining by.
|
||||
inRoomReq := &rsAPI.QueryServerJoinedToRoomRequest{
|
||||
|
@ -224,29 +187,63 @@ func (r *Joiner) performJoinRoomByID(
|
|||
serverInRoom := inRoomRes.IsInRoom
|
||||
forceFederatedJoin := len(req.ServerNames) > 0 && !serverInRoom
|
||||
|
||||
userID, err := spec.NewUserID(req.UserID, true)
|
||||
if err != nil {
|
||||
return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("user ID %q is invalid: %w", req.UserID, err)}
|
||||
}
|
||||
|
||||
// Look up the room NID for the supplied room ID.
|
||||
var senderID spec.SenderID
|
||||
checkInvitePending := false
|
||||
info, err := r.DB.RoomInfo(ctx, req.RoomIDOrAlias)
|
||||
if err == nil && info != nil {
|
||||
switch info.RoomVersion {
|
||||
case gomatrixserverlib.RoomVersionPseudoIDs:
|
||||
senderID, err = r.Queryer.QuerySenderIDForUser(ctx, *roomID, *userID)
|
||||
if err == nil {
|
||||
checkInvitePending = true
|
||||
} else {
|
||||
// create user room key if needed
|
||||
key, keyErr := r.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, *userID, *roomID)
|
||||
if keyErr != nil {
|
||||
util.GetLogger(ctx).WithError(keyErr).Error("GetOrCreateUserRoomPrivateKey failed")
|
||||
return "", "", fmt.Errorf("GetOrCreateUserRoomPrivateKey failed: %w", keyErr)
|
||||
}
|
||||
senderID = spec.SenderID(spec.Base64Bytes(key).Encode())
|
||||
}
|
||||
default:
|
||||
checkInvitePending = true
|
||||
senderID = spec.SenderID(userID.String())
|
||||
}
|
||||
}
|
||||
|
||||
userDomain := userID.Domain()
|
||||
|
||||
// Force a federated join if we're dealing with a pending invite
|
||||
// and we aren't in the room.
|
||||
isInvitePending, inviteSender, _, inviteEvent, err := helpers.IsInvitePending(ctx, r.DB, req.RoomIDOrAlias, senderID)
|
||||
if err == nil && !serverInRoom && isInvitePending {
|
||||
inviter, queryErr := r.RSAPI.QueryUserIDForSender(ctx, req.RoomIDOrAlias, inviteSender)
|
||||
if queryErr != nil {
|
||||
return "", "", fmt.Errorf("r.RSAPI.QueryUserIDForSender: %w", queryErr)
|
||||
}
|
||||
if checkInvitePending {
|
||||
isInvitePending, inviteSender, _, inviteEvent, inviteErr := helpers.IsInvitePending(ctx, r.DB, req.RoomIDOrAlias, senderID)
|
||||
if inviteErr == nil && !serverInRoom && isInvitePending {
|
||||
inviter, queryErr := r.RSAPI.QueryUserIDForSender(ctx, *roomID, inviteSender)
|
||||
if queryErr != nil {
|
||||
return "", "", fmt.Errorf("r.RSAPI.QueryUserIDForSender: %w", queryErr)
|
||||
}
|
||||
|
||||
// If we were invited by someone from another server then we can
|
||||
// assume they are in the room so we can join via them.
|
||||
if inviter != nil && !r.Cfg.Matrix.IsLocalServerName(inviter.Domain()) {
|
||||
req.ServerNames = append(req.ServerNames, inviter.Domain())
|
||||
forceFederatedJoin = true
|
||||
memberEvent := gjson.Parse(string(inviteEvent.JSON()))
|
||||
// only set unsigned if we've got a content.membership, which we _should_
|
||||
if memberEvent.Get("content.membership").Exists() {
|
||||
req.Unsigned = map[string]interface{}{
|
||||
"prev_sender": memberEvent.Get("sender").Str,
|
||||
"prev_content": map[string]interface{}{
|
||||
"is_direct": memberEvent.Get("content.is_direct").Bool(),
|
||||
"membership": memberEvent.Get("content.membership").Str,
|
||||
},
|
||||
// If we were invited by someone from another server then we can
|
||||
// assume they are in the room so we can join via them.
|
||||
if inviter != nil && !r.Cfg.Matrix.IsLocalServerName(inviter.Domain()) {
|
||||
req.ServerNames = append(req.ServerNames, inviter.Domain())
|
||||
forceFederatedJoin = true
|
||||
memberEvent := gjson.Parse(string(inviteEvent.JSON()))
|
||||
// only set unsigned if we've got a content.membership, which we _should_
|
||||
if memberEvent.Get("content.membership").Exists() {
|
||||
req.Unsigned = map[string]interface{}{
|
||||
"prev_sender": memberEvent.Get("sender").Str,
|
||||
"prev_content": map[string]interface{}{
|
||||
"is_direct": memberEvent.Get("content.is_direct").Bool(),
|
||||
"membership": memberEvent.Get("content.membership").Str,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -274,6 +271,7 @@ func (r *Joiner) performJoinRoomByID(
|
|||
// If we should do a forced federated join then do that.
|
||||
var joinedVia spec.ServerName
|
||||
if forceFederatedJoin {
|
||||
// TODO : pseudoIDs - pass through userID here since we don't know what the senderID should be yet
|
||||
joinedVia, err = r.performFederatedJoinRoomByID(ctx, req)
|
||||
return req.RoomIDOrAlias, joinedVia, err
|
||||
}
|
||||
|
@ -289,19 +287,40 @@ func (r *Joiner) performJoinRoomByID(
|
|||
if err != nil {
|
||||
return "", "", fmt.Errorf("error joining local room: %q", err)
|
||||
}
|
||||
|
||||
senderIDString := string(senderID)
|
||||
|
||||
// Prepare the template for the join event.
|
||||
proto := gomatrixserverlib.ProtoEvent{
|
||||
Type: spec.MRoomMember,
|
||||
SenderID: senderIDString,
|
||||
StateKey: &senderIDString,
|
||||
RoomID: req.RoomIDOrAlias,
|
||||
Redacts: "",
|
||||
}
|
||||
if err = proto.SetUnsigned(struct{}{}); err != nil {
|
||||
return "", "", fmt.Errorf("eb.SetUnsigned: %w", err)
|
||||
}
|
||||
|
||||
// It is possible for the request to include some "content" for the
|
||||
// event. We'll always overwrite the "membership" key, but the rest,
|
||||
// like "display_name" or "avatar_url", will be kept if supplied.
|
||||
if req.Content == nil {
|
||||
req.Content = map[string]interface{}{}
|
||||
}
|
||||
req.Content["membership"] = spec.Join
|
||||
if authorisedVia, aerr := r.populateAuthorisedViaUserForRestrictedJoin(ctx, req, senderID); aerr != nil {
|
||||
return "", "", aerr
|
||||
} else if authorisedVia != "" {
|
||||
req.Content["join_authorised_via_users_server"] = authorisedVia
|
||||
}
|
||||
if err = proto.SetContent(req.Content); err != nil {
|
||||
return "", "", fmt.Errorf("eb.SetContent: %w", err)
|
||||
}
|
||||
event, err := eventutil.QueryAndBuildEvent(ctx, &proto, identity, time.Now(), r.RSAPI, &buildRes)
|
||||
|
||||
switch err.(type) {
|
||||
case nil:
|
||||
// create user room key if needed
|
||||
if buildRes.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs {
|
||||
_, err = r.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, *userID, *roomID)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Error("GetOrCreateUserRoomPrivateKey failed")
|
||||
return "", "", fmt.Errorf("failed to get user room private key: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// The room join is local. Send the new join event into the
|
||||
// roomserver. First of all check that the user isn't already
|
||||
// a member of the room. This is best-effort (as in we won't
|
||||
|
|
|
@ -78,7 +78,11 @@ func (r *Leaver) performLeaveRoomByID(
|
|||
req *api.PerformLeaveRequest,
|
||||
res *api.PerformLeaveResponse, // nolint:unparam
|
||||
) ([]api.OutputEvent, error) {
|
||||
leaver, err := r.RSAPI.QuerySenderIDForUser(ctx, req.RoomID, req.Leaver)
|
||||
roomID, err := spec.NewRoomID(req.RoomID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
leaver, err := r.RSAPI.QuerySenderIDForUser(ctx, *roomID, req.Leaver)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("leaver %s has no matching senderID in this room", req.Leaver.String())
|
||||
}
|
||||
|
@ -87,7 +91,7 @@ func (r *Leaver) performLeaveRoomByID(
|
|||
// that.
|
||||
isInvitePending, senderUser, eventID, _, err := helpers.IsInvitePending(ctx, r.DB, req.RoomID, leaver)
|
||||
if err == nil && isInvitePending {
|
||||
sender, serr := r.RSAPI.QueryUserIDForSender(ctx, req.RoomID, senderUser)
|
||||
sender, serr := r.RSAPI.QueryUserIDForSender(ctx, *roomID, senderUser)
|
||||
if serr != nil || sender == nil {
|
||||
return nil, fmt.Errorf("sender %q has no matching userID", senderUser)
|
||||
}
|
||||
|
@ -133,7 +137,7 @@ func (r *Leaver) performLeaveRoomByID(
|
|||
},
|
||||
}
|
||||
latestRes := api.QueryLatestEventsAndStateResponse{}
|
||||
if err = helpers.QueryLatestEventsAndState(ctx, r.DB, &latestReq, &latestRes); err != nil {
|
||||
if err = helpers.QueryLatestEventsAndState(ctx, r.DB, r.RSAPI, &latestReq, &latestRes); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !latestRes.RoomExists {
|
||||
|
|
|
@ -54,7 +54,11 @@ func (r *Upgrader) performRoomUpgrade(
|
|||
return "", err
|
||||
}
|
||||
|
||||
senderID, err := r.URSAPI.QuerySenderIDForUser(ctx, roomID, userID)
|
||||
fullRoomID, err := spec.NewRoomID(roomID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
senderID, err := r.URSAPI.QuerySenderIDForUser(ctx, *fullRoomID, userID)
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("Failed getting senderID for user")
|
||||
return "", err
|
||||
|
@ -488,7 +492,7 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, send
|
|||
|
||||
}
|
||||
|
||||
if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
return r.URSAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
||||
}); err != nil {
|
||||
return fmt.Errorf("Failed to auth new %q event: %w", builder.Type, err)
|
||||
|
@ -569,7 +573,7 @@ func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, send
|
|||
stateEvents[i] = queryRes.StateEvents[i].PDU
|
||||
}
|
||||
provider := gomatrixserverlib.NewAuthEvents(stateEvents)
|
||||
if err = gomatrixserverlib.Allowed(headeredEvent.PDU, &provider, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
if err = gomatrixserverlib.Allowed(headeredEvent.PDU, &provider, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
return r.URSAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
||||
}); err != nil {
|
||||
return nil, api.ErrNotAllowed{Err: fmt.Errorf("failed to auth new %q event: %w", proto.Type, err)} // TODO: Is this error string comprehensible to the client?
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue