mirror of
https://github.com/hoernschen/dendrite.git
synced 2024-12-26 15:08:28 +00:00
Key Backups (3/3) : Implement querying keys and various bugfixes (#1946)
* Add querying device keys Makes a bunch of sytests pass * Apparently only the current version supports uploading keys * Linting
This commit is contained in:
parent
b3754d68fc
commit
32bf14a37c
12 changed files with 362 additions and 101 deletions
|
@ -37,7 +37,7 @@ type keyBackupVersionCreateResponse struct {
|
||||||
type keyBackupVersionResponse struct {
|
type keyBackupVersionResponse struct {
|
||||||
Algorithm string `json:"algorithm"`
|
Algorithm string `json:"algorithm"`
|
||||||
AuthData json.RawMessage `json:"auth_data"`
|
AuthData json.RawMessage `json:"auth_data"`
|
||||||
Count int `json:"count"`
|
Count int64 `json:"count"`
|
||||||
ETag string `json:"etag"`
|
ETag string `json:"etag"`
|
||||||
Version string `json:"version"`
|
Version string `json:"version"`
|
||||||
}
|
}
|
||||||
|
@ -89,7 +89,10 @@ func CreateKeyBackupVersion(req *http.Request, userAPI userapi.UserInternalAPI,
|
||||||
// Implements GET /_matrix/client/r0/room_keys/version and GET /_matrix/client/r0/room_keys/version/{version}
|
// Implements GET /_matrix/client/r0/room_keys/version and GET /_matrix/client/r0/room_keys/version/{version}
|
||||||
func KeyBackupVersion(req *http.Request, userAPI userapi.UserInternalAPI, device *userapi.Device, version string) util.JSONResponse {
|
func KeyBackupVersion(req *http.Request, userAPI userapi.UserInternalAPI, device *userapi.Device, version string) util.JSONResponse {
|
||||||
var queryResp userapi.QueryKeyBackupResponse
|
var queryResp userapi.QueryKeyBackupResponse
|
||||||
userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{}, &queryResp)
|
userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{
|
||||||
|
UserID: device.UserID,
|
||||||
|
Version: version,
|
||||||
|
}, &queryResp)
|
||||||
if queryResp.Error != "" {
|
if queryResp.Error != "" {
|
||||||
return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %s", queryResp.Error))
|
return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %s", queryResp.Error))
|
||||||
}
|
}
|
||||||
|
@ -216,3 +219,73 @@ func UploadBackupKeys(
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get keys from a given backup version. Response returned varies depending on if roomID and sessionID are set.
|
||||||
|
func GetBackupKeys(
|
||||||
|
req *http.Request, userAPI userapi.UserInternalAPI, device *userapi.Device, version, roomID, sessionID string,
|
||||||
|
) util.JSONResponse {
|
||||||
|
var queryResp userapi.QueryKeyBackupResponse
|
||||||
|
userAPI.QueryKeyBackup(req.Context(), &userapi.QueryKeyBackupRequest{
|
||||||
|
UserID: device.UserID,
|
||||||
|
Version: version,
|
||||||
|
ReturnKeys: true,
|
||||||
|
KeysForRoomID: roomID,
|
||||||
|
KeysForSessionID: sessionID,
|
||||||
|
}, &queryResp)
|
||||||
|
if queryResp.Error != "" {
|
||||||
|
return util.ErrorResponse(fmt.Errorf("QueryKeyBackup: %s", queryResp.Error))
|
||||||
|
}
|
||||||
|
if !queryResp.Exists {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: 404,
|
||||||
|
JSON: jsonerror.NotFound("version not found"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if sessionID != "" {
|
||||||
|
// return the key itself if it was found
|
||||||
|
roomData, ok := queryResp.Keys[roomID]
|
||||||
|
if ok {
|
||||||
|
key, ok := roomData[sessionID]
|
||||||
|
if ok {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: 200,
|
||||||
|
JSON: key,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if roomID != "" {
|
||||||
|
roomData, ok := queryResp.Keys[roomID]
|
||||||
|
if ok {
|
||||||
|
// wrap response in "sessions"
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: 200,
|
||||||
|
JSON: struct {
|
||||||
|
Sessions map[string]userapi.KeyBackupSession `json:"sessions"`
|
||||||
|
}{
|
||||||
|
Sessions: roomData,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// response is the same as the upload request
|
||||||
|
var resp keyBackupSessionRequest
|
||||||
|
resp.Rooms = make(map[string]struct {
|
||||||
|
Sessions map[string]userapi.KeyBackupSession `json:"sessions"`
|
||||||
|
})
|
||||||
|
for roomID, roomData := range queryResp.Keys {
|
||||||
|
resp.Rooms[roomID] = struct {
|
||||||
|
Sessions map[string]userapi.KeyBackupSession `json:"sessions"`
|
||||||
|
}{
|
||||||
|
Sessions: roomData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: 200,
|
||||||
|
JSON: resp,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: 404,
|
||||||
|
JSON: jsonerror.NotFound("keys not found"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -896,11 +896,15 @@ func Setup(
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodGet, http.MethodOptions)
|
).Methods(http.MethodGet, http.MethodOptions)
|
||||||
|
|
||||||
// Key Backup Versions
|
// Key Backup Versions (Metadata)
|
||||||
r0mux.Handle("/room_keys/version/{versionID}",
|
|
||||||
|
r0mux.Handle("/room_keys/version/{version}",
|
||||||
httputil.MakeAuthAPI("get_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI("get_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
version := req.URL.Query().Get("version")
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
return KeyBackupVersion(req, userAPI, device, version)
|
if err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
return KeyBackupVersion(req, userAPI, device, vars["version"])
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodGet, http.MethodOptions)
|
).Methods(http.MethodGet, http.MethodOptions)
|
||||||
r0mux.Handle("/room_keys/version",
|
r0mux.Handle("/room_keys/version",
|
||||||
|
@ -908,28 +912,22 @@ func Setup(
|
||||||
return KeyBackupVersion(req, userAPI, device, "")
|
return KeyBackupVersion(req, userAPI, device, "")
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodGet, http.MethodOptions)
|
).Methods(http.MethodGet, http.MethodOptions)
|
||||||
r0mux.Handle("/room_keys/version/{versionID}",
|
r0mux.Handle("/room_keys/version/{version}",
|
||||||
httputil.MakeAuthAPI("put_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI("put_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
version := req.URL.Query().Get("version")
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
if version == "" {
|
if err != nil {
|
||||||
return util.JSONResponse{
|
return util.ErrorResponse(err)
|
||||||
Code: 400,
|
|
||||||
JSON: jsonerror.InvalidArgumentValue("version must be specified"),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return ModifyKeyBackupVersionAuthData(req, userAPI, device, version)
|
return ModifyKeyBackupVersionAuthData(req, userAPI, device, vars["version"])
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodPut)
|
).Methods(http.MethodPut)
|
||||||
r0mux.Handle("/room_keys/version/{versionID}",
|
r0mux.Handle("/room_keys/version/{version}",
|
||||||
httputil.MakeAuthAPI("delete_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI("delete_backup_keys_version", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
version := req.URL.Query().Get("version")
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
if version == "" {
|
if err != nil {
|
||||||
return util.JSONResponse{
|
return util.ErrorResponse(err)
|
||||||
Code: 400,
|
|
||||||
JSON: jsonerror.InvalidArgumentValue("version must be specified"),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return DeleteKeyBackupVersion(req, userAPI, device, version)
|
return DeleteKeyBackupVersion(req, userAPI, device, vars["version"])
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodDelete)
|
).Methods(http.MethodDelete)
|
||||||
r0mux.Handle("/room_keys/version",
|
r0mux.Handle("/room_keys/version",
|
||||||
|
@ -938,7 +936,8 @@ func Setup(
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
|
|
||||||
// E2E Backup Keys
|
// Inserting E2E Backup Keys
|
||||||
|
|
||||||
// Bulk room and session
|
// Bulk room and session
|
||||||
r0mux.Handle("/room_keys/keys",
|
r0mux.Handle("/room_keys/keys",
|
||||||
httputil.MakeAuthAPI("put_backup_keys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI("put_backup_keys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
@ -973,6 +972,9 @@ func Setup(
|
||||||
}
|
}
|
||||||
roomID := vars["roomID"]
|
roomID := vars["roomID"]
|
||||||
var reqBody keyBackupSessionRequest
|
var reqBody keyBackupSessionRequest
|
||||||
|
reqBody.Rooms = make(map[string]struct {
|
||||||
|
Sessions map[string]userapi.KeyBackupSession `json:"sessions"`
|
||||||
|
})
|
||||||
reqBody.Rooms[roomID] = struct {
|
reqBody.Rooms[roomID] = struct {
|
||||||
Sessions map[string]userapi.KeyBackupSession `json:"sessions"`
|
Sessions map[string]userapi.KeyBackupSession `json:"sessions"`
|
||||||
}{
|
}{
|
||||||
|
@ -989,7 +991,7 @@ func Setup(
|
||||||
).Methods(http.MethodPut)
|
).Methods(http.MethodPut)
|
||||||
// Single room, single session
|
// Single room, single session
|
||||||
r0mux.Handle("/room_keys/keys/{roomID}/{sessionID}",
|
r0mux.Handle("/room_keys/keys/{roomID}/{sessionID}",
|
||||||
httputil.MakeAuthAPI("put_backup_keys_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI("put_backup_keys_room_session", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
|
@ -1009,14 +1011,47 @@ func Setup(
|
||||||
roomID := vars["roomID"]
|
roomID := vars["roomID"]
|
||||||
sessionID := vars["sessionID"]
|
sessionID := vars["sessionID"]
|
||||||
var keyReq keyBackupSessionRequest
|
var keyReq keyBackupSessionRequest
|
||||||
|
keyReq.Rooms = make(map[string]struct {
|
||||||
|
Sessions map[string]userapi.KeyBackupSession `json:"sessions"`
|
||||||
|
})
|
||||||
keyReq.Rooms[roomID] = struct {
|
keyReq.Rooms[roomID] = struct {
|
||||||
Sessions map[string]userapi.KeyBackupSession `json:"sessions"`
|
Sessions map[string]userapi.KeyBackupSession `json:"sessions"`
|
||||||
}{}
|
}{
|
||||||
|
Sessions: make(map[string]userapi.KeyBackupSession),
|
||||||
|
}
|
||||||
keyReq.Rooms[roomID].Sessions[sessionID] = reqBody
|
keyReq.Rooms[roomID].Sessions[sessionID] = reqBody
|
||||||
return UploadBackupKeys(req, userAPI, device, version, &keyReq)
|
return UploadBackupKeys(req, userAPI, device, version, &keyReq)
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodPut)
|
).Methods(http.MethodPut)
|
||||||
|
|
||||||
|
// Querying E2E Backup Keys
|
||||||
|
|
||||||
|
r0mux.Handle("/room_keys/keys",
|
||||||
|
httputil.MakeAuthAPI("get_backup_keys", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
return GetBackupKeys(req, userAPI, device, req.URL.Query().Get("version"), "", "")
|
||||||
|
}),
|
||||||
|
).Methods(http.MethodGet, http.MethodOptions)
|
||||||
|
r0mux.Handle("/room_keys/keys/{roomID}",
|
||||||
|
httputil.MakeAuthAPI("get_backup_keys_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
|
if err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
return GetBackupKeys(req, userAPI, device, req.URL.Query().Get("version"), vars["roomID"], "")
|
||||||
|
}),
|
||||||
|
).Methods(http.MethodGet, http.MethodOptions)
|
||||||
|
r0mux.Handle("/room_keys/keys/{roomID}/{sessionID}",
|
||||||
|
httputil.MakeAuthAPI("get_backup_keys_room_session", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
|
if err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
return GetBackupKeys(req, userAPI, device, req.URL.Query().Get("version"), vars["roomID"], vars["sessionID"])
|
||||||
|
}),
|
||||||
|
).Methods(http.MethodGet, http.MethodOptions)
|
||||||
|
|
||||||
|
// Deleting E2E Backup Keys
|
||||||
|
|
||||||
// Supplying a device ID is deprecated.
|
// Supplying a device ID is deprecated.
|
||||||
r0mux.Handle("/keys/upload/{deviceID}",
|
r0mux.Handle("/keys/upload/{deviceID}",
|
||||||
httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
|
|
@ -540,3 +540,13 @@ Key notary server must not overwrite a valid key with a spurious result from the
|
||||||
GET /rooms/:room_id/aliases lists aliases
|
GET /rooms/:room_id/aliases lists aliases
|
||||||
Only room members can list aliases of a room
|
Only room members can list aliases of a room
|
||||||
Users with sufficient power-level can delete other's aliases
|
Users with sufficient power-level can delete other's aliases
|
||||||
|
Can create backup version
|
||||||
|
Can update backup version
|
||||||
|
Responds correctly when backup is empty
|
||||||
|
Can backup keys
|
||||||
|
Can update keys with better versions
|
||||||
|
Will not update keys with worse versions
|
||||||
|
Will not back up to an old backup version
|
||||||
|
Can create more than 10 backup versions
|
||||||
|
Can delete backup
|
||||||
|
Deleted & recreated backups are empty
|
||||||
|
|
|
@ -67,6 +67,23 @@ type KeyBackupSession struct {
|
||||||
SessionData json.RawMessage `json:"session_data"`
|
SessionData json.RawMessage `json:"session_data"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *KeyBackupSession) ShouldReplaceRoomKey(newKey *KeyBackupSession) bool {
|
||||||
|
// https://spec.matrix.org/unstable/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2
|
||||||
|
// "if the keys have different values for is_verified, then it will keep the key that has is_verified set to true"
|
||||||
|
if newKey.IsVerified && !a.IsVerified {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
// "if they have the same values for is_verified, then it will keep the key with a lower first_message_index"
|
||||||
|
if newKey.FirstMessageIndex < a.FirstMessageIndex {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
// "and finally, is is_verified and first_message_index are equal, then it will keep the key with a lower forwarded_count"
|
||||||
|
if newKey.ForwardedCount < a.ForwardedCount {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// Internal KeyBackupData for passing to/from the storage layer
|
// Internal KeyBackupData for passing to/from the storage layer
|
||||||
type InternalKeyBackupSession struct {
|
type InternalKeyBackupSession struct {
|
||||||
KeyBackupSession
|
KeyBackupSession
|
||||||
|
@ -88,6 +105,10 @@ type PerformKeyBackupResponse struct {
|
||||||
type QueryKeyBackupRequest struct {
|
type QueryKeyBackupRequest struct {
|
||||||
UserID string
|
UserID string
|
||||||
Version string // the version to query, if blank it means the latest
|
Version string // the version to query, if blank it means the latest
|
||||||
|
|
||||||
|
ReturnKeys bool // whether to return keys in the backup response or just the metadata
|
||||||
|
KeysForRoomID string // optional string to return keys which belong to this room
|
||||||
|
KeysForSessionID string // optional string to return keys which belong to this (room, session)
|
||||||
}
|
}
|
||||||
|
|
||||||
type QueryKeyBackupResponse struct {
|
type QueryKeyBackupResponse struct {
|
||||||
|
@ -96,9 +117,11 @@ type QueryKeyBackupResponse struct {
|
||||||
|
|
||||||
Algorithm string `json:"algorithm"`
|
Algorithm string `json:"algorithm"`
|
||||||
AuthData json.RawMessage `json:"auth_data"`
|
AuthData json.RawMessage `json:"auth_data"`
|
||||||
Count int `json:"count"`
|
Count int64 `json:"count"`
|
||||||
ETag string `json:"etag"`
|
ETag string `json:"etag"`
|
||||||
Version string `json:"version"`
|
Version string `json:"version"`
|
||||||
|
|
||||||
|
Keys map[string]map[string]KeyBackupSession // the keys if ReturnKeys=true
|
||||||
}
|
}
|
||||||
|
|
||||||
// InputAccountDataRequest is the request for InputAccountData
|
// InputAccountDataRequest is the request for InputAccountData
|
||||||
|
|
|
@ -475,6 +475,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
|
||||||
if err != nil {
|
if err != nil {
|
||||||
res.Error = fmt.Sprintf("failed to update backup: %s", err)
|
res.Error = fmt.Sprintf("failed to update backup: %s", err)
|
||||||
}
|
}
|
||||||
|
res.Exists = err == nil
|
||||||
res.Version = req.Version
|
res.Version = req.Version
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -483,8 +484,8 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) {
|
func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) {
|
||||||
// ensure the version metadata exists
|
// you can only upload keys for the CURRENT version
|
||||||
version, _, _, _, deleted, err := a.AccountDB.GetKeyBackup(ctx, req.UserID, req.Version)
|
version, _, _, _, deleted, err := a.AccountDB.GetKeyBackup(ctx, req.UserID, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
res.Error = fmt.Sprintf("failed to query version: %s", err)
|
res.Error = fmt.Sprintf("failed to query version: %s", err)
|
||||||
return
|
return
|
||||||
|
@ -493,6 +494,11 @@ func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.Perform
|
||||||
res.Error = "backup was deleted"
|
res.Error = "backup was deleted"
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if version != req.Version {
|
||||||
|
res.BadInput = true
|
||||||
|
res.Error = fmt.Sprintf("%s isn't the current version, %s is.", req.Version, version)
|
||||||
|
return
|
||||||
|
}
|
||||||
res.Exists = true
|
res.Exists = true
|
||||||
res.Version = version
|
res.Version = version
|
||||||
|
|
||||||
|
@ -529,9 +535,21 @@ func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyB
|
||||||
}
|
}
|
||||||
res.Algorithm = algorithm
|
res.Algorithm = algorithm
|
||||||
res.AuthData = authData
|
res.AuthData = authData
|
||||||
|
res.ETag = etag
|
||||||
res.Exists = !deleted
|
res.Exists = !deleted
|
||||||
|
|
||||||
// TODO:
|
if !req.ReturnKeys {
|
||||||
res.Count = 0
|
res.Count, err = a.AccountDB.CountBackupKeys(ctx, version, req.UserID)
|
||||||
res.ETag = etag
|
if err != nil {
|
||||||
|
res.Error = fmt.Sprintf("failed to count keys: %s", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := a.AccountDB.GetBackupKeys(ctx, version, req.UserID, req.KeysForRoomID, req.KeysForSessionID)
|
||||||
|
if err != nil {
|
||||||
|
res.Error = fmt.Sprintf("failed to query keys: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
res.Keys = result
|
||||||
}
|
}
|
||||||
|
|
|
@ -61,6 +61,8 @@ type Database interface {
|
||||||
DeleteKeyBackup(ctx context.Context, userID, version string) (exists bool, err error)
|
DeleteKeyBackup(ctx context.Context, userID, version string) (exists bool, err error)
|
||||||
GetKeyBackup(ctx context.Context, userID, version string) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error)
|
GetKeyBackup(ctx context.Context, userID, version string) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error)
|
||||||
UpsertBackupKeys(ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession) (count int64, etag string, err error)
|
UpsertBackupKeys(ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession) (count int64, etag string, err error)
|
||||||
|
GetBackupKeys(ctx context.Context, version, userID, filterRoomID, filterSessionID string) (result map[string]map[string]api.KeyBackupSession, err error)
|
||||||
|
CountBackupKeys(ctx context.Context, version, userID string) (count int64, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Err3PIDInUse is the error returned when trying to save an association involving
|
// Err3PIDInUse is the error returned when trying to save an association involving
|
||||||
|
|
|
@ -35,7 +35,8 @@ CREATE TABLE IF NOT EXISTS account_e2e_room_keys (
|
||||||
is_verified BOOLEAN NOT NULL,
|
is_verified BOOLEAN NOT NULL,
|
||||||
session_data TEXT NOT NULL
|
session_data TEXT NOT NULL
|
||||||
);
|
);
|
||||||
CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id);
|
CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id, version);
|
||||||
|
CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_versions_idx ON account_e2e_room_keys(user_id, version);
|
||||||
`
|
`
|
||||||
|
|
||||||
const insertBackupKeySQL = "" +
|
const insertBackupKeySQL = "" +
|
||||||
|
@ -53,14 +54,23 @@ const selectKeysSQL = "" +
|
||||||
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
|
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
|
||||||
"WHERE user_id = $1 AND version = $2"
|
"WHERE user_id = $1 AND version = $2"
|
||||||
|
|
||||||
|
const selectKeysByRoomIDSQL = "" +
|
||||||
|
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
|
||||||
|
"WHERE user_id = $1 AND version = $2 AND room_id = $3"
|
||||||
|
|
||||||
|
const selectKeysByRoomIDAndSessionIDSQL = "" +
|
||||||
|
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
|
||||||
|
"WHERE user_id = $1 AND version = $2 AND room_id = $3 AND session_id = $4"
|
||||||
|
|
||||||
type keyBackupStatements struct {
|
type keyBackupStatements struct {
|
||||||
insertBackupKeyStmt *sql.Stmt
|
insertBackupKeyStmt *sql.Stmt
|
||||||
updateBackupKeyStmt *sql.Stmt
|
updateBackupKeyStmt *sql.Stmt
|
||||||
countKeysStmt *sql.Stmt
|
countKeysStmt *sql.Stmt
|
||||||
selectKeysStmt *sql.Stmt
|
selectKeysStmt *sql.Stmt
|
||||||
|
selectKeysByRoomIDStmt *sql.Stmt
|
||||||
|
selectKeysByRoomIDAndSessionIDStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
// nolint:unused
|
|
||||||
func (s *keyBackupStatements) prepare(db *sql.DB) (err error) {
|
func (s *keyBackupStatements) prepare(db *sql.DB) (err error) {
|
||||||
_, err = db.Exec(keyBackupTableSchema)
|
_, err = db.Exec(keyBackupTableSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -78,6 +88,12 @@ func (s *keyBackupStatements) prepare(db *sql.DB) (err error) {
|
||||||
if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil {
|
if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if s.selectKeysByRoomIDStmt, err = db.Prepare(selectKeysByRoomIDSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.selectKeysByRoomIDAndSessionIDStmt, err = db.Prepare(selectKeysByRoomIDAndSessionIDSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -109,11 +125,35 @@ func (s *keyBackupStatements) updateBackupKey(
|
||||||
func (s *keyBackupStatements) selectKeys(
|
func (s *keyBackupStatements) selectKeys(
|
||||||
ctx context.Context, txn *sql.Tx, userID, version string,
|
ctx context.Context, txn *sql.Tx, userID, version string,
|
||||||
) (map[string]map[string]api.KeyBackupSession, error) {
|
) (map[string]map[string]api.KeyBackupSession, error) {
|
||||||
result := make(map[string]map[string]api.KeyBackupSession)
|
|
||||||
rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version)
|
rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
return unpackKeys(ctx, rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *keyBackupStatements) selectKeysByRoomID(
|
||||||
|
ctx context.Context, txn *sql.Tx, userID, version, roomID string,
|
||||||
|
) (map[string]map[string]api.KeyBackupSession, error) {
|
||||||
|
rows, err := txn.Stmt(s.selectKeysByRoomIDStmt).QueryContext(ctx, userID, version, roomID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return unpackKeys(ctx, rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *keyBackupStatements) selectKeysByRoomIDAndSessionID(
|
||||||
|
ctx context.Context, txn *sql.Tx, userID, version, roomID, sessionID string,
|
||||||
|
) (map[string]map[string]api.KeyBackupSession, error) {
|
||||||
|
rows, err := txn.Stmt(s.selectKeysByRoomIDAndSessionIDStmt).QueryContext(ctx, userID, version, roomID, sessionID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return unpackKeys(ctx, rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
func unpackKeys(ctx context.Context, rows *sql.Rows) (map[string]map[string]api.KeyBackupSession, error) {
|
||||||
|
result := make(map[string]map[string]api.KeyBackupSession)
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt.Close failed")
|
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt.Close failed")
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var key api.InternalKeyBackupSession
|
var key api.InternalKeyBackupSession
|
||||||
|
|
|
@ -67,7 +67,6 @@ type keyBackupVersionStatements struct {
|
||||||
updateKeyBackupETagStmt *sql.Stmt
|
updateKeyBackupETagStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
// nolint:unused
|
|
||||||
func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) {
|
func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) {
|
||||||
_, err = db.Exec(keyBackupVersionTableSchema)
|
_, err = db.Exec(keyBackupVersionTableSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -96,13 +96,12 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
|
||||||
if err = d.openIDTokens.prepare(db, serverName); err != nil {
|
if err = d.openIDTokens.prepare(db, serverName); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
/*
|
if err = d.keyBackupVersions.prepare(db); err != nil {
|
||||||
if err = d.keyBackupVersions.prepare(db); err != nil {
|
return nil, err
|
||||||
return nil, err
|
}
|
||||||
}
|
if err = d.keyBackups.prepare(db); err != nil {
|
||||||
if err = d.keyBackups.prepare(db); err != nil {
|
return nil, err
|
||||||
return nil, err
|
}
|
||||||
} */
|
|
||||||
|
|
||||||
return d, nil
|
return d, nil
|
||||||
}
|
}
|
||||||
|
@ -418,6 +417,37 @@ func (d *Database) GetKeyBackup(
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *Database) GetBackupKeys(
|
||||||
|
ctx context.Context, version, userID, filterRoomID, filterSessionID string,
|
||||||
|
) (result map[string]map[string]api.KeyBackupSession, err error) {
|
||||||
|
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
|
if filterSessionID != "" {
|
||||||
|
result, err = d.keyBackups.selectKeysByRoomIDAndSessionID(ctx, txn, userID, version, filterRoomID, filterSessionID)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if filterRoomID != "" {
|
||||||
|
result, err = d.keyBackups.selectKeysByRoomID(ctx, txn, userID, version, filterRoomID)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
result, err = d.keyBackups.selectKeys(ctx, txn, userID, version)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) CountBackupKeys(
|
||||||
|
ctx context.Context, version, userID string,
|
||||||
|
) (count int64, err error) {
|
||||||
|
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
|
count, err = d.keyBackups.countKeys(ctx, txn, userID, version)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// nolint:nakedret
|
// nolint:nakedret
|
||||||
func (d *Database) UpsertBackupKeys(
|
func (d *Database) UpsertBackupKeys(
|
||||||
ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession,
|
ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession,
|
||||||
|
@ -445,7 +475,7 @@ func (d *Database) UpsertBackupKeys(
|
||||||
if existingRoom != nil {
|
if existingRoom != nil {
|
||||||
existingSession, ok := existingRoom[newKey.SessionID]
|
existingSession, ok := existingRoom[newKey.SessionID]
|
||||||
if ok {
|
if ok {
|
||||||
if shouldReplaceRoomKey(existingSession, newKey.KeyBackupSession) {
|
if existingSession.ShouldReplaceRoomKey(&newKey.KeyBackupSession) {
|
||||||
err = d.keyBackups.updateBackupKey(ctx, txn, userID, version, newKey)
|
err = d.keyBackups.updateBackupKey(ctx, txn, userID, version, newKey)
|
||||||
changed = true
|
changed = true
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -489,22 +519,3 @@ func (d *Database) UpsertBackupKeys(
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO FIXME XXX : This logic really shouldn't live in the storage layer, but I don't know where else is sensible which won't
|
|
||||||
// create circular import loops
|
|
||||||
func shouldReplaceRoomKey(existing, uploaded api.KeyBackupSession) bool {
|
|
||||||
// https://spec.matrix.org/unstable/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2
|
|
||||||
// "if the keys have different values for is_verified, then it will keep the key that has is_verified set to true"
|
|
||||||
if uploaded.IsVerified && !existing.IsVerified {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
// "if they have the same values for is_verified, then it will keep the key with a lower first_message_index"
|
|
||||||
if uploaded.FirstMessageIndex < existing.FirstMessageIndex {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
// "and finally, is is_verified and first_message_index are equal, then it will keep the key with a lower forwarded_count"
|
|
||||||
if uploaded.ForwardedCount < existing.ForwardedCount {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
|
@ -35,7 +35,8 @@ CREATE TABLE IF NOT EXISTS account_e2e_room_keys (
|
||||||
is_verified BOOLEAN NOT NULL,
|
is_verified BOOLEAN NOT NULL,
|
||||||
session_data TEXT NOT NULL
|
session_data TEXT NOT NULL
|
||||||
);
|
);
|
||||||
CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id);
|
CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id, version);
|
||||||
|
CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_versions_idx ON account_e2e_room_keys(user_id, version);
|
||||||
`
|
`
|
||||||
|
|
||||||
const insertBackupKeySQL = "" +
|
const insertBackupKeySQL = "" +
|
||||||
|
@ -53,14 +54,23 @@ const selectKeysSQL = "" +
|
||||||
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
|
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
|
||||||
"WHERE user_id = $1 AND version = $2"
|
"WHERE user_id = $1 AND version = $2"
|
||||||
|
|
||||||
|
const selectKeysByRoomIDSQL = "" +
|
||||||
|
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
|
||||||
|
"WHERE user_id = $1 AND version = $2 AND room_id = $3"
|
||||||
|
|
||||||
|
const selectKeysByRoomIDAndSessionIDSQL = "" +
|
||||||
|
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
|
||||||
|
"WHERE user_id = $1 AND version = $2 AND room_id = $3 AND session_id = $4"
|
||||||
|
|
||||||
type keyBackupStatements struct {
|
type keyBackupStatements struct {
|
||||||
insertBackupKeyStmt *sql.Stmt
|
insertBackupKeyStmt *sql.Stmt
|
||||||
updateBackupKeyStmt *sql.Stmt
|
updateBackupKeyStmt *sql.Stmt
|
||||||
countKeysStmt *sql.Stmt
|
countKeysStmt *sql.Stmt
|
||||||
selectKeysStmt *sql.Stmt
|
selectKeysStmt *sql.Stmt
|
||||||
|
selectKeysByRoomIDStmt *sql.Stmt
|
||||||
|
selectKeysByRoomIDAndSessionIDStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
// nolint:unused
|
|
||||||
func (s *keyBackupStatements) prepare(db *sql.DB) (err error) {
|
func (s *keyBackupStatements) prepare(db *sql.DB) (err error) {
|
||||||
_, err = db.Exec(keyBackupTableSchema)
|
_, err = db.Exec(keyBackupTableSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -78,6 +88,12 @@ func (s *keyBackupStatements) prepare(db *sql.DB) (err error) {
|
||||||
if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil {
|
if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if s.selectKeysByRoomIDStmt, err = db.Prepare(selectKeysByRoomIDSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.selectKeysByRoomIDAndSessionIDStmt, err = db.Prepare(selectKeysByRoomIDAndSessionIDSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -109,11 +125,35 @@ func (s *keyBackupStatements) updateBackupKey(
|
||||||
func (s *keyBackupStatements) selectKeys(
|
func (s *keyBackupStatements) selectKeys(
|
||||||
ctx context.Context, txn *sql.Tx, userID, version string,
|
ctx context.Context, txn *sql.Tx, userID, version string,
|
||||||
) (map[string]map[string]api.KeyBackupSession, error) {
|
) (map[string]map[string]api.KeyBackupSession, error) {
|
||||||
result := make(map[string]map[string]api.KeyBackupSession)
|
|
||||||
rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version)
|
rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
return unpackKeys(ctx, rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *keyBackupStatements) selectKeysByRoomID(
|
||||||
|
ctx context.Context, txn *sql.Tx, userID, version, roomID string,
|
||||||
|
) (map[string]map[string]api.KeyBackupSession, error) {
|
||||||
|
rows, err := txn.Stmt(s.selectKeysByRoomIDStmt).QueryContext(ctx, userID, version, roomID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return unpackKeys(ctx, rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *keyBackupStatements) selectKeysByRoomIDAndSessionID(
|
||||||
|
ctx context.Context, txn *sql.Tx, userID, version, roomID, sessionID string,
|
||||||
|
) (map[string]map[string]api.KeyBackupSession, error) {
|
||||||
|
rows, err := txn.Stmt(s.selectKeysByRoomIDAndSessionIDStmt).QueryContext(ctx, userID, version, roomID, sessionID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return unpackKeys(ctx, rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
func unpackKeys(ctx context.Context, rows *sql.Rows) (map[string]map[string]api.KeyBackupSession, error) {
|
||||||
|
result := make(map[string]map[string]api.KeyBackupSession)
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt.Close failed")
|
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt.Close failed")
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var key api.InternalKeyBackupSession
|
var key api.InternalKeyBackupSession
|
||||||
|
|
|
@ -65,7 +65,6 @@ type keyBackupVersionStatements struct {
|
||||||
updateKeyBackupETagStmt *sql.Stmt
|
updateKeyBackupETagStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
// nolint:unused
|
|
||||||
func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) {
|
func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) {
|
||||||
_, err = db.Exec(keyBackupVersionTableSchema)
|
_, err = db.Exec(keyBackupVersionTableSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -100,13 +100,12 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
|
||||||
if err = d.openIDTokens.prepare(db, serverName); err != nil {
|
if err = d.openIDTokens.prepare(db, serverName); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
/*
|
if err = d.keyBackupVersions.prepare(db); err != nil {
|
||||||
if err = d.keyBackupVersions.prepare(db); err != nil {
|
return nil, err
|
||||||
return nil, err
|
}
|
||||||
}
|
if err = d.keyBackups.prepare(db); err != nil {
|
||||||
if err = d.keyBackups.prepare(db); err != nil {
|
return nil, err
|
||||||
return nil, err
|
}
|
||||||
} */
|
|
||||||
|
|
||||||
return d, nil
|
return d, nil
|
||||||
}
|
}
|
||||||
|
@ -459,6 +458,37 @@ func (d *Database) GetKeyBackup(
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *Database) GetBackupKeys(
|
||||||
|
ctx context.Context, version, userID, filterRoomID, filterSessionID string,
|
||||||
|
) (result map[string]map[string]api.KeyBackupSession, err error) {
|
||||||
|
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||||
|
if filterSessionID != "" {
|
||||||
|
result, err = d.keyBackups.selectKeysByRoomIDAndSessionID(ctx, txn, userID, version, filterRoomID, filterSessionID)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if filterRoomID != "" {
|
||||||
|
result, err = d.keyBackups.selectKeysByRoomID(ctx, txn, userID, version, filterRoomID)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
result, err = d.keyBackups.selectKeys(ctx, txn, userID, version)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) CountBackupKeys(
|
||||||
|
ctx context.Context, version, userID string,
|
||||||
|
) (count int64, err error) {
|
||||||
|
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||||
|
count, err = d.keyBackups.countKeys(ctx, txn, userID, version)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// nolint:nakedret
|
// nolint:nakedret
|
||||||
func (d *Database) UpsertBackupKeys(
|
func (d *Database) UpsertBackupKeys(
|
||||||
ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession,
|
ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession,
|
||||||
|
@ -486,7 +516,7 @@ func (d *Database) UpsertBackupKeys(
|
||||||
if existingRoom != nil {
|
if existingRoom != nil {
|
||||||
existingSession, ok := existingRoom[newKey.SessionID]
|
existingSession, ok := existingRoom[newKey.SessionID]
|
||||||
if ok {
|
if ok {
|
||||||
if shouldReplaceRoomKey(existingSession, newKey.KeyBackupSession) {
|
if existingSession.ShouldReplaceRoomKey(&newKey.KeyBackupSession) {
|
||||||
err = d.keyBackups.updateBackupKey(ctx, txn, userID, version, newKey)
|
err = d.keyBackups.updateBackupKey(ctx, txn, userID, version, newKey)
|
||||||
changed = true
|
changed = true
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -531,22 +561,3 @@ func (d *Database) UpsertBackupKeys(
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO FIXME XXX : This logic really shouldn't live in the storage layer, but I don't know where else is sensible which won't
|
|
||||||
// create circular import loops
|
|
||||||
func shouldReplaceRoomKey(existing, uploaded api.KeyBackupSession) bool {
|
|
||||||
// https://spec.matrix.org/unstable/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2
|
|
||||||
// "if the keys have different values for is_verified, then it will keep the key that has is_verified set to true"
|
|
||||||
if uploaded.IsVerified && !existing.IsVerified {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
// "if they have the same values for is_verified, then it will keep the key with a lower first_message_index"
|
|
||||||
if uploaded.FirstMessageIndex < existing.FirstMessageIndex {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
// "and finally, is is_verified and first_message_index are equal, then it will keep the key with a lower forwarded_count"
|
|
||||||
if uploaded.ForwardedCount < existing.ForwardedCount {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
Loading…
Reference in a new issue