From 2c87972a3a84be400e5c69e2e5a727f21b4e457e Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Tue, 13 Jun 2023 14:19:31 +0200 Subject: [PATCH] Create user room key if needed (#3108) --- roomserver/api/api.go | 4 +++ roomserver/internal/api.go | 21 ++++++++++++++++ .../internal/perform/perform_create_room.go | 25 ++++++++++++++++++- roomserver/internal/perform/perform_invite.go | 8 ++++++ roomserver/internal/perform/perform_join.go | 9 +++++++ roomserver/storage/shared/storage.go | 2 +- 6 files changed, 67 insertions(+), 2 deletions(-) diff --git a/roomserver/api/api.go b/roomserver/api/api.go index bafde91c..fec28841 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -2,6 +2,7 @@ package api import ( "context" + "crypto/ed25519" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" @@ -66,6 +67,9 @@ type RoomserverInternalAPI interface { req *QueryAuthChainRequest, res *QueryAuthChainResponse, ) error + + // GetOrCreateUserRoomPrivateKey gets the user room key for the specified user. If no key exists yet, a new one is created. + GetOrCreateUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (ed25519.PrivateKey, error) } type InputRoomEventsAPI interface { diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index 35b7383a..4bcd3f3e 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -2,6 +2,7 @@ package internal import ( "context" + "crypto/ed25519" "github.com/getsentry/sentry-go" "github.com/matrix-org/gomatrixserverlib" @@ -270,3 +271,23 @@ func (r *RoomserverInternalAPI) PerformForget( ) error { return r.Forgetter.PerformForget(ctx, req, resp) } + +// GetOrCreateUserRoomPrivateKey gets the user room key for the specified user. If no key exists yet, a new one is created. +func (r *RoomserverInternalAPI) GetOrCreateUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (ed25519.PrivateKey, error) { + key, err := r.DB.SelectUserRoomPrivateKey(ctx, userID, roomID) + if err != nil { + return nil, err + } + // no key found, create one + if len(key) == 0 { + _, key, err = ed25519.GenerateKey(nil) + if err != nil { + return nil, err + } + key, err = r.DB.InsertUserRoomPrivatePublicKey(ctx, userID, roomID, key) + if err != nil { + return nil, err + } + } + return key, nil +} diff --git a/roomserver/internal/perform/perform_create_room.go b/roomserver/internal/perform/perform_create_room.go index 475418aa..121b257e 100644 --- a/roomserver/internal/perform/perform_create_room.go +++ b/roomserver/internal/perform/perform_create_room.go @@ -354,7 +354,30 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo SendAsServer: api.DoNotSendToOtherServers, }) } - if err = api.SendInputRoomEvents(ctx, c.RSAPI, userID.Domain(), inputs, false); err != nil { + + // first send the `m.room.create` event, so we have a roomNID + if err = api.SendInputRoomEvents(ctx, c.RSAPI, userID.Domain(), inputs[:1], false); err != nil { + util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed") + return "", &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + + // create user room key if needed + if createRequest.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs { + _, err = c.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, userID, roomID) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("GetOrCreateUserRoomPrivateKey failed") + return "", &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + } + + // send the remaining events + if err = api.SendInputRoomEvents(ctx, c.RSAPI, userID.Domain(), inputs[1:], false); err != nil { util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed") return "", &util.JSONResponse{ Code: http.StatusInternalServerError, diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index 1440daad..cc2c5c19 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -183,6 +183,14 @@ func (r *Inviter) PerformInvite( inviteEvent = event } + // if we invited a local user, we can also create a user room key, if it doesn't exist yet. + if isTargetLocal && event.Version() == gomatrixserverlib.RoomVersionPseudoIDs { + _, err = r.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, *invitedUser, *validRoomID) + if err != nil { + return fmt.Errorf("failed to get user room private key: %w", err) + } + } + // Send the invite event to the roomserver input stream. This will // notify existing users in the room about the invite, update the // membership table and ensure that the event is ready and available diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go index 83c3b7c3..74ed87c7 100644 --- a/roomserver/internal/perform/perform_join.go +++ b/roomserver/internal/perform/perform_join.go @@ -293,6 +293,15 @@ func (r *Joiner) performJoinRoomByID( switch err.(type) { case nil: + // create user room key if needed + if buildRes.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs { + _, err = r.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, *userID, *roomID) + if err != nil { + logrus.WithError(err).Error("GetOrCreateUserRoomPrivateKey failed") + return "", "", fmt.Errorf("failed to get user room private key: %w", err) + } + } + // The room join is local. Send the new join event into the // roomserver. First of all check that the user isn't already // a member of the room. This is best-effort (as in we won't diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 85a1ba7a..d7ca3cef 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -1686,7 +1686,7 @@ func (d *Database) SelectUserRoomPrivateKey(ctx context.Context, userID spec.Use return rErr } if roomInfo == nil { - return nil + return eventutil.ErrRoomNoExists{} } key, sErr = d.UserRoomKeyTable.SelectUserRoomPrivateKey(ctx, txn, stateKeyNID, roomInfo.RoomNID)