mirror of
https://github.com/hoernschen/dendrite.git
synced 2024-12-27 07:28:27 +00:00
Merge both user API databases into one (#2186)
* Merge user API databases into one * Remove DeviceDatabase from config * Fix tests * Try that again * Clean up keyserver device keys when the devices no longer exist in the user API * Tweak ordering * Fix UserExists flag, device check * Allow including empty entries so we can clean them up * Remove logging
This commit is contained in:
parent
0a7dea4450
commit
153bfbbea5
76 changed files with 727 additions and 899 deletions
|
@ -23,7 +23,7 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
userdb "github.com/matrix-org/dendrite/userapi/storage"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -85,7 +85,7 @@ func RetrieveUserProfile(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
userID string,
|
userID string,
|
||||||
asAPI AppServiceQueryAPI,
|
asAPI AppServiceQueryAPI,
|
||||||
accountDB accounts.Database,
|
accountDB userdb.Database,
|
||||||
) (*authtypes.Profile, error) {
|
) (*authtypes.Profile, error) {
|
||||||
localpart, _, err := gomatrixserverlib.SplitID('@', userID)
|
localpart, _, err := gomatrixserverlib.SplitID('@', userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -283,8 +283,6 @@ func (m *DendriteMonolith) Start() {
|
||||||
cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID)
|
cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID)
|
||||||
cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/%s", m.StorageDirectory, prefix))
|
cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/%s", m.StorageDirectory, prefix))
|
||||||
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-account.db", m.StorageDirectory, prefix))
|
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-account.db", m.StorageDirectory, prefix))
|
||||||
cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-device.db", m.StorageDirectory, prefix))
|
|
||||||
cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-mediaapi.db", m.CacheDirectory, prefix))
|
|
||||||
cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-syncapi.db", m.StorageDirectory, prefix))
|
cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-syncapi.db", m.StorageDirectory, prefix))
|
||||||
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-roomserver.db", m.StorageDirectory, prefix))
|
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-roomserver.db", m.StorageDirectory, prefix))
|
||||||
cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-keyserver.db", m.StorageDirectory, prefix))
|
cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/%s-keyserver.db", m.StorageDirectory, prefix))
|
||||||
|
|
|
@ -88,7 +88,6 @@ func (m *DendriteMonolith) Start() {
|
||||||
cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID)
|
cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID)
|
||||||
cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", m.StorageDirectory))
|
cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", m.StorageDirectory))
|
||||||
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-account.db", m.StorageDirectory))
|
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-account.db", m.StorageDirectory))
|
||||||
cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-device.db", m.StorageDirectory))
|
|
||||||
cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-mediaapi.db", m.StorageDirectory))
|
cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-mediaapi.db", m.StorageDirectory))
|
||||||
cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-syncapi.db", m.StorageDirectory))
|
cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-syncapi.db", m.StorageDirectory))
|
||||||
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-roomserver.db", m.StorageDirectory))
|
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s/dendrite-p2p-roomserver.db", m.StorageDirectory))
|
||||||
|
|
|
@ -28,7 +28,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
"github.com/matrix-org/dendrite/setup/jetstream"
|
"github.com/matrix-org/dendrite/setup/jetstream"
|
||||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
userdb "github.com/matrix-org/dendrite/userapi/storage"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -37,7 +37,7 @@ func AddPublicRoutes(
|
||||||
router *mux.Router,
|
router *mux.Router,
|
||||||
synapseAdminRouter *mux.Router,
|
synapseAdminRouter *mux.Router,
|
||||||
cfg *config.ClientAPI,
|
cfg *config.ClientAPI,
|
||||||
accountsDB accounts.Database,
|
accountsDB userdb.Database,
|
||||||
federation *gomatrixserverlib.FederationClient,
|
federation *gomatrixserverlib.FederationClient,
|
||||||
rsAPI roomserverAPI.RoomserverInternalAPI,
|
rsAPI roomserverAPI.RoomserverInternalAPI,
|
||||||
eduInputAPI eduServerAPI.EDUServerInputAPI,
|
eduInputAPI eduServerAPI.EDUServerInputAPI,
|
||||||
|
|
|
@ -30,7 +30,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||||
"github.com/matrix-org/dendrite/internal/eventutil"
|
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
userdb "github.com/matrix-org/dendrite/userapi/storage"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
@ -137,7 +137,7 @@ type fledglingEvent struct {
|
||||||
func CreateRoom(
|
func CreateRoom(
|
||||||
req *http.Request, device *api.Device,
|
req *http.Request, device *api.Device,
|
||||||
cfg *config.ClientAPI,
|
cfg *config.ClientAPI,
|
||||||
accountDB accounts.Database, rsAPI roomserverAPI.RoomserverInternalAPI,
|
accountDB userdb.Database, rsAPI roomserverAPI.RoomserverInternalAPI,
|
||||||
asAPI appserviceAPI.AppServiceQueryAPI,
|
asAPI appserviceAPI.AppServiceQueryAPI,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
// TODO (#267): Check room ID doesn't clash with an existing one, and we
|
// TODO (#267): Check room ID doesn't clash with an existing one, and we
|
||||||
|
@ -151,7 +151,7 @@ func CreateRoom(
|
||||||
func createRoom(
|
func createRoom(
|
||||||
req *http.Request, device *api.Device,
|
req *http.Request, device *api.Device,
|
||||||
cfg *config.ClientAPI, roomID string,
|
cfg *config.ClientAPI, roomID string,
|
||||||
accountDB accounts.Database, rsAPI roomserverAPI.RoomserverInternalAPI,
|
accountDB userdb.Database, rsAPI roomserverAPI.RoomserverInternalAPI,
|
||||||
asAPI appserviceAPI.AppServiceQueryAPI,
|
asAPI appserviceAPI.AppServiceQueryAPI,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
logger := util.GetLogger(req.Context())
|
logger := util.GetLogger(req.Context())
|
||||||
|
|
|
@ -23,7 +23,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||||
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
userdb "github.com/matrix-org/dendrite/userapi/storage"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
)
|
)
|
||||||
|
@ -32,7 +32,7 @@ func JoinRoomByIDOrAlias(
|
||||||
req *http.Request,
|
req *http.Request,
|
||||||
device *api.Device,
|
device *api.Device,
|
||||||
rsAPI roomserverAPI.RoomserverInternalAPI,
|
rsAPI roomserverAPI.RoomserverInternalAPI,
|
||||||
accountDB accounts.Database,
|
accountDB userdb.Database,
|
||||||
roomIDOrAlias string,
|
roomIDOrAlias string,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
// Prepare to ask the roomserver to perform the room join.
|
// Prepare to ask the roomserver to perform the room join.
|
||||||
|
|
|
@ -24,7 +24,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/keyserver/api"
|
"github.com/matrix-org/dendrite/keyserver/api"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
userdb "github.com/matrix-org/dendrite/userapi/storage"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -36,7 +36,7 @@ type crossSigningRequest struct {
|
||||||
func UploadCrossSigningDeviceKeys(
|
func UploadCrossSigningDeviceKeys(
|
||||||
req *http.Request, userInteractiveAuth *auth.UserInteractive,
|
req *http.Request, userInteractiveAuth *auth.UserInteractive,
|
||||||
keyserverAPI api.KeyInternalAPI, device *userapi.Device,
|
keyserverAPI api.KeyInternalAPI, device *userapi.Device,
|
||||||
accountDB accounts.Database, cfg *config.ClientAPI,
|
accountDB userdb.Database, cfg *config.ClientAPI,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
uploadReq := &crossSigningRequest{}
|
uploadReq := &crossSigningRequest{}
|
||||||
uploadRes := &api.PerformUploadDeviceKeysResponse{}
|
uploadRes := &api.PerformUploadDeviceKeysResponse{}
|
||||||
|
|
|
@ -23,7 +23,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/clientapi/userutil"
|
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
userdb "github.com/matrix-org/dendrite/userapi/storage"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
)
|
)
|
||||||
|
@ -54,7 +54,7 @@ func passwordLogin() flows {
|
||||||
|
|
||||||
// Login implements GET and POST /login
|
// Login implements GET and POST /login
|
||||||
func Login(
|
func Login(
|
||||||
req *http.Request, accountDB accounts.Database, userAPI userapi.UserInternalAPI,
|
req *http.Request, accountDB userdb.Database, userAPI userapi.UserInternalAPI,
|
||||||
cfg *config.ClientAPI,
|
cfg *config.ClientAPI,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
if req.Method == http.MethodGet {
|
if req.Method == http.MethodGet {
|
||||||
|
|
|
@ -30,7 +30,7 @@ import (
|
||||||
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
userdb "github.com/matrix-org/dendrite/userapi/storage"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
|
@ -39,7 +39,7 @@ import (
|
||||||
var errMissingUserID = errors.New("'user_id' must be supplied")
|
var errMissingUserID = errors.New("'user_id' must be supplied")
|
||||||
|
|
||||||
func SendBan(
|
func SendBan(
|
||||||
req *http.Request, accountDB accounts.Database, device *userapi.Device,
|
req *http.Request, accountDB userdb.Database, device *userapi.Device,
|
||||||
roomID string, cfg *config.ClientAPI,
|
roomID string, cfg *config.ClientAPI,
|
||||||
rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI,
|
rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
|
@ -81,7 +81,7 @@ func SendBan(
|
||||||
return sendMembership(req.Context(), accountDB, device, roomID, "ban", body.Reason, cfg, body.UserID, evTime, roomVer, rsAPI, asAPI)
|
return sendMembership(req.Context(), accountDB, device, roomID, "ban", body.Reason, cfg, body.UserID, evTime, roomVer, rsAPI, asAPI)
|
||||||
}
|
}
|
||||||
|
|
||||||
func sendMembership(ctx context.Context, accountDB accounts.Database, device *userapi.Device,
|
func sendMembership(ctx context.Context, accountDB userdb.Database, device *userapi.Device,
|
||||||
roomID, membership, reason string, cfg *config.ClientAPI, targetUserID string, evTime time.Time,
|
roomID, membership, reason string, cfg *config.ClientAPI, targetUserID string, evTime time.Time,
|
||||||
roomVer gomatrixserverlib.RoomVersion,
|
roomVer gomatrixserverlib.RoomVersion,
|
||||||
rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI) util.JSONResponse {
|
rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI) util.JSONResponse {
|
||||||
|
@ -125,7 +125,7 @@ func sendMembership(ctx context.Context, accountDB accounts.Database, device *us
|
||||||
}
|
}
|
||||||
|
|
||||||
func SendKick(
|
func SendKick(
|
||||||
req *http.Request, accountDB accounts.Database, device *userapi.Device,
|
req *http.Request, accountDB userdb.Database, device *userapi.Device,
|
||||||
roomID string, cfg *config.ClientAPI,
|
roomID string, cfg *config.ClientAPI,
|
||||||
rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI,
|
rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
|
@ -165,7 +165,7 @@ func SendKick(
|
||||||
}
|
}
|
||||||
|
|
||||||
func SendUnban(
|
func SendUnban(
|
||||||
req *http.Request, accountDB accounts.Database, device *userapi.Device,
|
req *http.Request, accountDB userdb.Database, device *userapi.Device,
|
||||||
roomID string, cfg *config.ClientAPI,
|
roomID string, cfg *config.ClientAPI,
|
||||||
rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI,
|
rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
|
@ -200,7 +200,7 @@ func SendUnban(
|
||||||
}
|
}
|
||||||
|
|
||||||
func SendInvite(
|
func SendInvite(
|
||||||
req *http.Request, accountDB accounts.Database, device *userapi.Device,
|
req *http.Request, accountDB userdb.Database, device *userapi.Device,
|
||||||
roomID string, cfg *config.ClientAPI,
|
roomID string, cfg *config.ClientAPI,
|
||||||
rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI,
|
rsAPI roomserverAPI.RoomserverInternalAPI, asAPI appserviceAPI.AppServiceQueryAPI,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
|
@ -271,7 +271,7 @@ func SendInvite(
|
||||||
|
|
||||||
func buildMembershipEvent(
|
func buildMembershipEvent(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
targetUserID, reason string, accountDB accounts.Database,
|
targetUserID, reason string, accountDB userdb.Database,
|
||||||
device *userapi.Device,
|
device *userapi.Device,
|
||||||
membership, roomID string, isDirect bool,
|
membership, roomID string, isDirect bool,
|
||||||
cfg *config.ClientAPI, evTime time.Time,
|
cfg *config.ClientAPI, evTime time.Time,
|
||||||
|
@ -312,7 +312,7 @@ func loadProfile(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
userID string,
|
userID string,
|
||||||
cfg *config.ClientAPI,
|
cfg *config.ClientAPI,
|
||||||
accountDB accounts.Database,
|
accountDB userdb.Database,
|
||||||
asAPI appserviceAPI.AppServiceQueryAPI,
|
asAPI appserviceAPI.AppServiceQueryAPI,
|
||||||
) (*authtypes.Profile, error) {
|
) (*authtypes.Profile, error) {
|
||||||
_, serverName, err := gomatrixserverlib.SplitID('@', userID)
|
_, serverName, err := gomatrixserverlib.SplitID('@', userID)
|
||||||
|
@ -366,7 +366,7 @@ func checkAndProcessThreepid(
|
||||||
body *threepid.MembershipRequest,
|
body *threepid.MembershipRequest,
|
||||||
cfg *config.ClientAPI,
|
cfg *config.ClientAPI,
|
||||||
rsAPI roomserverAPI.RoomserverInternalAPI,
|
rsAPI roomserverAPI.RoomserverInternalAPI,
|
||||||
accountDB accounts.Database,
|
accountDB userdb.Database,
|
||||||
roomID string,
|
roomID string,
|
||||||
evTime time.Time,
|
evTime time.Time,
|
||||||
) (inviteStored bool, errRes *util.JSONResponse) {
|
) (inviteStored bool, errRes *util.JSONResponse) {
|
||||||
|
|
|
@ -9,7 +9,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
userdb "github.com/matrix-org/dendrite/userapi/storage"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
)
|
)
|
||||||
|
@ -29,7 +29,7 @@ type newPasswordAuth struct {
|
||||||
func Password(
|
func Password(
|
||||||
req *http.Request,
|
req *http.Request,
|
||||||
userAPI api.UserInternalAPI,
|
userAPI api.UserInternalAPI,
|
||||||
accountDB accounts.Database,
|
accountDB userdb.Database,
|
||||||
device *api.Device,
|
device *api.Device,
|
||||||
cfg *config.ClientAPI,
|
cfg *config.ClientAPI,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
|
|
|
@ -19,7 +19,7 @@ import (
|
||||||
|
|
||||||
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
userdb "github.com/matrix-org/dendrite/userapi/storage"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
)
|
)
|
||||||
|
@ -28,7 +28,7 @@ func PeekRoomByIDOrAlias(
|
||||||
req *http.Request,
|
req *http.Request,
|
||||||
device *api.Device,
|
device *api.Device,
|
||||||
rsAPI roomserverAPI.RoomserverInternalAPI,
|
rsAPI roomserverAPI.RoomserverInternalAPI,
|
||||||
accountDB accounts.Database,
|
accountDB userdb.Database,
|
||||||
roomIDOrAlias string,
|
roomIDOrAlias string,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
// if this is a remote roomIDOrAlias, we have to ask the roomserver (or federation sender?) to
|
// if this is a remote roomIDOrAlias, we have to ask the roomserver (or federation sender?) to
|
||||||
|
@ -82,7 +82,7 @@ func UnpeekRoomByID(
|
||||||
req *http.Request,
|
req *http.Request,
|
||||||
device *api.Device,
|
device *api.Device,
|
||||||
rsAPI roomserverAPI.RoomserverInternalAPI,
|
rsAPI roomserverAPI.RoomserverInternalAPI,
|
||||||
accountDB accounts.Database,
|
accountDB userdb.Database,
|
||||||
roomID string,
|
roomID string,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
unpeekReq := roomserverAPI.PerformUnpeekRequest{
|
unpeekReq := roomserverAPI.PerformUnpeekRequest{
|
||||||
|
|
|
@ -27,7 +27,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
userdb "github.com/matrix-org/dendrite/userapi/storage"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
||||||
"github.com/matrix-org/gomatrix"
|
"github.com/matrix-org/gomatrix"
|
||||||
|
@ -36,7 +36,7 @@ import (
|
||||||
|
|
||||||
// GetProfile implements GET /profile/{userID}
|
// GetProfile implements GET /profile/{userID}
|
||||||
func GetProfile(
|
func GetProfile(
|
||||||
req *http.Request, accountDB accounts.Database, cfg *config.ClientAPI,
|
req *http.Request, accountDB userdb.Database, cfg *config.ClientAPI,
|
||||||
userID string,
|
userID string,
|
||||||
asAPI appserviceAPI.AppServiceQueryAPI,
|
asAPI appserviceAPI.AppServiceQueryAPI,
|
||||||
federation *gomatrixserverlib.FederationClient,
|
federation *gomatrixserverlib.FederationClient,
|
||||||
|
@ -65,7 +65,7 @@ func GetProfile(
|
||||||
|
|
||||||
// GetAvatarURL implements GET /profile/{userID}/avatar_url
|
// GetAvatarURL implements GET /profile/{userID}/avatar_url
|
||||||
func GetAvatarURL(
|
func GetAvatarURL(
|
||||||
req *http.Request, accountDB accounts.Database, cfg *config.ClientAPI,
|
req *http.Request, accountDB userdb.Database, cfg *config.ClientAPI,
|
||||||
userID string, asAPI appserviceAPI.AppServiceQueryAPI,
|
userID string, asAPI appserviceAPI.AppServiceQueryAPI,
|
||||||
federation *gomatrixserverlib.FederationClient,
|
federation *gomatrixserverlib.FederationClient,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
|
@ -92,7 +92,7 @@ func GetAvatarURL(
|
||||||
|
|
||||||
// SetAvatarURL implements PUT /profile/{userID}/avatar_url
|
// SetAvatarURL implements PUT /profile/{userID}/avatar_url
|
||||||
func SetAvatarURL(
|
func SetAvatarURL(
|
||||||
req *http.Request, accountDB accounts.Database,
|
req *http.Request, accountDB userdb.Database,
|
||||||
device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.RoomserverInternalAPI,
|
device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.RoomserverInternalAPI,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
if userID != device.UserID {
|
if userID != device.UserID {
|
||||||
|
@ -182,7 +182,7 @@ func SetAvatarURL(
|
||||||
|
|
||||||
// GetDisplayName implements GET /profile/{userID}/displayname
|
// GetDisplayName implements GET /profile/{userID}/displayname
|
||||||
func GetDisplayName(
|
func GetDisplayName(
|
||||||
req *http.Request, accountDB accounts.Database, cfg *config.ClientAPI,
|
req *http.Request, accountDB userdb.Database, cfg *config.ClientAPI,
|
||||||
userID string, asAPI appserviceAPI.AppServiceQueryAPI,
|
userID string, asAPI appserviceAPI.AppServiceQueryAPI,
|
||||||
federation *gomatrixserverlib.FederationClient,
|
federation *gomatrixserverlib.FederationClient,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
|
@ -209,7 +209,7 @@ func GetDisplayName(
|
||||||
|
|
||||||
// SetDisplayName implements PUT /profile/{userID}/displayname
|
// SetDisplayName implements PUT /profile/{userID}/displayname
|
||||||
func SetDisplayName(
|
func SetDisplayName(
|
||||||
req *http.Request, accountDB accounts.Database,
|
req *http.Request, accountDB userdb.Database,
|
||||||
device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.RoomserverInternalAPI,
|
device *userapi.Device, userID string, cfg *config.ClientAPI, rsAPI api.RoomserverInternalAPI,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
if userID != device.UserID {
|
if userID != device.UserID {
|
||||||
|
@ -302,7 +302,7 @@ func SetDisplayName(
|
||||||
// Returns an error when something goes wrong or specifically
|
// Returns an error when something goes wrong or specifically
|
||||||
// eventutil.ErrProfileNoExists when the profile doesn't exist.
|
// eventutil.ErrProfileNoExists when the profile doesn't exist.
|
||||||
func getProfile(
|
func getProfile(
|
||||||
ctx context.Context, accountDB accounts.Database, cfg *config.ClientAPI,
|
ctx context.Context, accountDB userdb.Database, cfg *config.ClientAPI,
|
||||||
userID string,
|
userID string,
|
||||||
asAPI appserviceAPI.AppServiceQueryAPI,
|
asAPI appserviceAPI.AppServiceQueryAPI,
|
||||||
federation *gomatrixserverlib.FederationClient,
|
federation *gomatrixserverlib.FederationClient,
|
||||||
|
|
|
@ -44,7 +44,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||||
"github.com/matrix-org/dendrite/clientapi/userutil"
|
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
userdb "github.com/matrix-org/dendrite/userapi/storage"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -448,7 +448,7 @@ func validateApplicationService(
|
||||||
func Register(
|
func Register(
|
||||||
req *http.Request,
|
req *http.Request,
|
||||||
userAPI userapi.UserInternalAPI,
|
userAPI userapi.UserInternalAPI,
|
||||||
accountDB accounts.Database,
|
accountDB userdb.Database,
|
||||||
cfg *config.ClientAPI,
|
cfg *config.ClientAPI,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
var r registerRequest
|
var r registerRequest
|
||||||
|
@ -899,7 +899,7 @@ type availableResponse struct {
|
||||||
func RegisterAvailable(
|
func RegisterAvailable(
|
||||||
req *http.Request,
|
req *http.Request,
|
||||||
cfg *config.ClientAPI,
|
cfg *config.ClientAPI,
|
||||||
accountDB accounts.Database,
|
accountDB userdb.Database,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
username := req.URL.Query().Get("username")
|
username := req.URL.Query().Get("username")
|
||||||
|
|
||||||
|
|
|
@ -34,7 +34,7 @@ import (
|
||||||
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
userdb "github.com/matrix-org/dendrite/userapi/storage"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
@ -51,7 +51,7 @@ func Setup(
|
||||||
eduAPI eduServerAPI.EDUServerInputAPI,
|
eduAPI eduServerAPI.EDUServerInputAPI,
|
||||||
rsAPI roomserverAPI.RoomserverInternalAPI,
|
rsAPI roomserverAPI.RoomserverInternalAPI,
|
||||||
asAPI appserviceAPI.AppServiceQueryAPI,
|
asAPI appserviceAPI.AppServiceQueryAPI,
|
||||||
accountDB accounts.Database,
|
accountDB userdb.Database,
|
||||||
userAPI userapi.UserInternalAPI,
|
userAPI userapi.UserInternalAPI,
|
||||||
federation *gomatrixserverlib.FederationClient,
|
federation *gomatrixserverlib.FederationClient,
|
||||||
syncProducer *producers.SyncAPIProducer,
|
syncProducer *producers.SyncAPIProducer,
|
||||||
|
|
|
@ -20,7 +20,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/eduserver/api"
|
"github.com/matrix-org/dendrite/eduserver/api"
|
||||||
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
||||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
userdb "github.com/matrix-org/dendrite/userapi/storage"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -33,7 +33,7 @@ type typingContentJSON struct {
|
||||||
// sends the typing events to client API typingProducer
|
// sends the typing events to client API typingProducer
|
||||||
func SendTyping(
|
func SendTyping(
|
||||||
req *http.Request, device *userapi.Device, roomID string,
|
req *http.Request, device *userapi.Device, roomID string,
|
||||||
userID string, accountDB accounts.Database,
|
userID string, accountDB userdb.Database,
|
||||||
eduAPI api.EDUServerInputAPI,
|
eduAPI api.EDUServerInputAPI,
|
||||||
rsAPI roomserverAPI.RoomserverInternalAPI,
|
rsAPI roomserverAPI.RoomserverInternalAPI,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
|
|
|
@ -23,7 +23,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/clientapi/threepid"
|
"github.com/matrix-org/dendrite/clientapi/threepid"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
userdb "github.com/matrix-org/dendrite/userapi/storage"
|
||||||
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
|
@ -40,7 +40,7 @@ type threePIDsResponse struct {
|
||||||
// RequestEmailToken implements:
|
// RequestEmailToken implements:
|
||||||
// POST /account/3pid/email/requestToken
|
// POST /account/3pid/email/requestToken
|
||||||
// POST /register/email/requestToken
|
// POST /register/email/requestToken
|
||||||
func RequestEmailToken(req *http.Request, accountDB accounts.Database, cfg *config.ClientAPI) util.JSONResponse {
|
func RequestEmailToken(req *http.Request, accountDB userdb.Database, cfg *config.ClientAPI) util.JSONResponse {
|
||||||
var body threepid.EmailAssociationRequest
|
var body threepid.EmailAssociationRequest
|
||||||
if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil {
|
if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil {
|
||||||
return *reqErr
|
return *reqErr
|
||||||
|
@ -61,7 +61,7 @@ func RequestEmailToken(req *http.Request, accountDB accounts.Database, cfg *conf
|
||||||
Code: http.StatusBadRequest,
|
Code: http.StatusBadRequest,
|
||||||
JSON: jsonerror.MatrixError{
|
JSON: jsonerror.MatrixError{
|
||||||
ErrCode: "M_THREEPID_IN_USE",
|
ErrCode: "M_THREEPID_IN_USE",
|
||||||
Err: accounts.Err3PIDInUse.Error(),
|
Err: userdb.Err3PIDInUse.Error(),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -85,7 +85,7 @@ func RequestEmailToken(req *http.Request, accountDB accounts.Database, cfg *conf
|
||||||
|
|
||||||
// CheckAndSave3PIDAssociation implements POST /account/3pid
|
// CheckAndSave3PIDAssociation implements POST /account/3pid
|
||||||
func CheckAndSave3PIDAssociation(
|
func CheckAndSave3PIDAssociation(
|
||||||
req *http.Request, accountDB accounts.Database, device *api.Device,
|
req *http.Request, accountDB userdb.Database, device *api.Device,
|
||||||
cfg *config.ClientAPI,
|
cfg *config.ClientAPI,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
var body threepid.EmailAssociationCheckRequest
|
var body threepid.EmailAssociationCheckRequest
|
||||||
|
@ -149,7 +149,7 @@ func CheckAndSave3PIDAssociation(
|
||||||
|
|
||||||
// GetAssociated3PIDs implements GET /account/3pid
|
// GetAssociated3PIDs implements GET /account/3pid
|
||||||
func GetAssociated3PIDs(
|
func GetAssociated3PIDs(
|
||||||
req *http.Request, accountDB accounts.Database, device *api.Device,
|
req *http.Request, accountDB userdb.Database, device *api.Device,
|
||||||
) util.JSONResponse {
|
) util.JSONResponse {
|
||||||
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -170,7 +170,7 @@ func GetAssociated3PIDs(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Forget3PID implements POST /account/3pid/delete
|
// Forget3PID implements POST /account/3pid/delete
|
||||||
func Forget3PID(req *http.Request, accountDB accounts.Database) util.JSONResponse {
|
func Forget3PID(req *http.Request, accountDB userdb.Database) util.JSONResponse {
|
||||||
var body authtypes.ThreePID
|
var body authtypes.ThreePID
|
||||||
if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil {
|
if reqErr := httputil.UnmarshalJSONRequest(req, &body); reqErr != nil {
|
||||||
return *reqErr
|
return *reqErr
|
||||||
|
|
|
@ -29,7 +29,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
userdb "github.com/matrix-org/dendrite/userapi/storage"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -87,7 +87,7 @@ var (
|
||||||
func CheckAndProcessInvite(
|
func CheckAndProcessInvite(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
device *userapi.Device, body *MembershipRequest, cfg *config.ClientAPI,
|
device *userapi.Device, body *MembershipRequest, cfg *config.ClientAPI,
|
||||||
rsAPI api.RoomserverInternalAPI, db accounts.Database,
|
rsAPI api.RoomserverInternalAPI, db userdb.Database,
|
||||||
roomID string,
|
roomID string,
|
||||||
evTime time.Time,
|
evTime time.Time,
|
||||||
) (inviteStoredOnIDServer bool, err error) {
|
) (inviteStoredOnIDServer bool, err error) {
|
||||||
|
@ -137,7 +137,7 @@ func CheckAndProcessInvite(
|
||||||
// Returns an error if a check or a request failed.
|
// Returns an error if a check or a request failed.
|
||||||
func queryIDServer(
|
func queryIDServer(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
db accounts.Database, cfg *config.ClientAPI, device *userapi.Device,
|
db userdb.Database, cfg *config.ClientAPI, device *userapi.Device,
|
||||||
body *MembershipRequest, roomID string,
|
body *MembershipRequest, roomID string,
|
||||||
) (lookupRes *idServerLookupResponse, storeInviteRes *idServerStoreInviteResponse, err error) {
|
) (lookupRes *idServerLookupResponse, storeInviteRes *idServerStoreInviteResponse, err error) {
|
||||||
if err = isTrusted(body.IDServer, cfg); err != nil {
|
if err = isTrusted(body.IDServer, cfg); err != nil {
|
||||||
|
@ -206,7 +206,7 @@ func queryIDServerLookup(ctx context.Context, body *MembershipRequest) (*idServe
|
||||||
// Returns an error if the request failed to send or if the response couldn't be parsed.
|
// Returns an error if the request failed to send or if the response couldn't be parsed.
|
||||||
func queryIDServerStoreInvite(
|
func queryIDServerStoreInvite(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
db accounts.Database, cfg *config.ClientAPI, device *userapi.Device,
|
db userdb.Database, cfg *config.ClientAPI, device *userapi.Device,
|
||||||
body *MembershipRequest, roomID string,
|
body *MembershipRequest, roomID string,
|
||||||
) (*idServerStoreInviteResponse, error) {
|
) (*idServerStoreInviteResponse, error) {
|
||||||
// Retrieve the sender's profile to get their display name
|
// Retrieve the sender's profile to get their display name
|
||||||
|
|
|
@ -30,7 +30,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/setup"
|
"github.com/matrix-org/dendrite/setup"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
userdb "github.com/matrix-org/dendrite/userapi/storage"
|
||||||
)
|
)
|
||||||
|
|
||||||
const usage = `Usage: %s
|
const usage = `Usage: %s
|
||||||
|
@ -77,9 +77,14 @@ func main() {
|
||||||
|
|
||||||
pass := getPassword(password, pwdFile, pwdStdin, askPass, os.Stdin)
|
pass := getPassword(password, pwdFile, pwdStdin, askPass, os.Stdin)
|
||||||
|
|
||||||
accountDB, err := accounts.NewDatabase(&config.DatabaseOptions{
|
accountDB, err := userdb.NewDatabase(
|
||||||
ConnectionString: cfg.UserAPI.AccountDatabase.ConnectionString,
|
&config.DatabaseOptions{
|
||||||
}, cfg.Global.ServerName, bcrypt.DefaultCost, cfg.UserAPI.OpenIDTokenLifetimeMS)
|
ConnectionString: cfg.UserAPI.AccountDatabase.ConnectionString,
|
||||||
|
},
|
||||||
|
cfg.Global.ServerName, bcrypt.DefaultCost,
|
||||||
|
cfg.UserAPI.OpenIDTokenLifetimeMS,
|
||||||
|
api.DefaultLoginTokenLifetime,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.Fatalln("Failed to connect to the database:", err.Error())
|
logrus.Fatalln("Failed to connect to the database:", err.Error())
|
||||||
}
|
}
|
||||||
|
|
|
@ -126,7 +126,6 @@ func main() {
|
||||||
cfg.FederationAPI.FederationMaxRetries = 6
|
cfg.FederationAPI.FederationMaxRetries = 6
|
||||||
cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", *instanceName))
|
cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", *instanceName))
|
||||||
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName))
|
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName))
|
||||||
cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-device.db", *instanceName))
|
|
||||||
cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName))
|
cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName))
|
||||||
cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName))
|
cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName))
|
||||||
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName))
|
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName))
|
||||||
|
|
|
@ -160,7 +160,6 @@ func main() {
|
||||||
cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID)
|
cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID)
|
||||||
cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", *instanceName))
|
cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", *instanceName))
|
||||||
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName))
|
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName))
|
||||||
cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-device.db", *instanceName))
|
|
||||||
cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName))
|
cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName))
|
||||||
cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName))
|
cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName))
|
||||||
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName))
|
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName))
|
||||||
|
|
|
@ -79,7 +79,6 @@ func main() {
|
||||||
cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID)
|
cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID)
|
||||||
cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", *instanceName))
|
cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", *instanceName))
|
||||||
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName))
|
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName))
|
||||||
cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-device.db", *instanceName))
|
|
||||||
cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName))
|
cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", *instanceName))
|
||||||
cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName))
|
cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", *instanceName))
|
||||||
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName))
|
cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", *instanceName))
|
||||||
|
|
|
@ -164,7 +164,6 @@ func startup() {
|
||||||
cfg.Defaults(true)
|
cfg.Defaults(true)
|
||||||
cfg.UserAPI.AccountDatabase.ConnectionString = "file:/idb/dendritejs_account.db"
|
cfg.UserAPI.AccountDatabase.ConnectionString = "file:/idb/dendritejs_account.db"
|
||||||
cfg.AppServiceAPI.Database.ConnectionString = "file:/idb/dendritejs_appservice.db"
|
cfg.AppServiceAPI.Database.ConnectionString = "file:/idb/dendritejs_appservice.db"
|
||||||
cfg.UserAPI.DeviceDatabase.ConnectionString = "file:/idb/dendritejs_device.db"
|
|
||||||
cfg.FederationAPI.Database.ConnectionString = "file:/idb/dendritejs_fedsender.db"
|
cfg.FederationAPI.Database.ConnectionString = "file:/idb/dendritejs_fedsender.db"
|
||||||
cfg.MediaAPI.Database.ConnectionString = "file:/idb/dendritejs_mediaapi.db"
|
cfg.MediaAPI.Database.ConnectionString = "file:/idb/dendritejs_mediaapi.db"
|
||||||
cfg.RoomServer.Database.ConnectionString = "file:/idb/dendritejs_roomserver.db"
|
cfg.RoomServer.Database.ConnectionString = "file:/idb/dendritejs_roomserver.db"
|
||||||
|
|
|
@ -167,7 +167,6 @@ func main() {
|
||||||
cfg.Defaults(true)
|
cfg.Defaults(true)
|
||||||
cfg.UserAPI.AccountDatabase.ConnectionString = "file:/idb/dendritejs_account.db"
|
cfg.UserAPI.AccountDatabase.ConnectionString = "file:/idb/dendritejs_account.db"
|
||||||
cfg.AppServiceAPI.Database.ConnectionString = "file:/idb/dendritejs_appservice.db"
|
cfg.AppServiceAPI.Database.ConnectionString = "file:/idb/dendritejs_appservice.db"
|
||||||
cfg.UserAPI.DeviceDatabase.ConnectionString = "file:/idb/dendritejs_device.db"
|
|
||||||
cfg.FederationAPI.Database.ConnectionString = "file:/idb/dendritejs_fedsender.db"
|
cfg.FederationAPI.Database.ConnectionString = "file:/idb/dendritejs_fedsender.db"
|
||||||
cfg.MediaAPI.Database.ConnectionString = "file:/idb/dendritejs_mediaapi.db"
|
cfg.MediaAPI.Database.ConnectionString = "file:/idb/dendritejs_mediaapi.db"
|
||||||
cfg.RoomServer.Database.ConnectionString = "file:/idb/dendritejs_roomserver.db"
|
cfg.RoomServer.Database.ConnectionString = "file:/idb/dendritejs_roomserver.db"
|
||||||
|
|
|
@ -32,7 +32,6 @@ func main() {
|
||||||
cfg.RoomServer.Database.ConnectionString = config.DataSource(*dbURI)
|
cfg.RoomServer.Database.ConnectionString = config.DataSource(*dbURI)
|
||||||
cfg.SyncAPI.Database.ConnectionString = config.DataSource(*dbURI)
|
cfg.SyncAPI.Database.ConnectionString = config.DataSource(*dbURI)
|
||||||
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(*dbURI)
|
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(*dbURI)
|
||||||
cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(*dbURI)
|
|
||||||
}
|
}
|
||||||
cfg.Global.TrustedIDServers = []string{
|
cfg.Global.TrustedIDServers = []string{
|
||||||
"matrix.org",
|
"matrix.org",
|
||||||
|
|
|
@ -8,12 +8,11 @@ import (
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
pgaccounts "github.com/matrix-org/dendrite/userapi/storage/accounts/postgres/deltas"
|
|
||||||
slaccounts "github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3/deltas"
|
|
||||||
pgdevices "github.com/matrix-org/dendrite/userapi/storage/devices/postgres/deltas"
|
|
||||||
sldevices "github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3/deltas"
|
|
||||||
"github.com/pressly/goose"
|
"github.com/pressly/goose"
|
||||||
|
|
||||||
|
pgusers "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas"
|
||||||
|
slusers "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
|
||||||
|
|
||||||
_ "github.com/lib/pq"
|
_ "github.com/lib/pq"
|
||||||
_ "github.com/mattn/go-sqlite3"
|
_ "github.com/mattn/go-sqlite3"
|
||||||
)
|
)
|
||||||
|
@ -26,8 +25,7 @@ const (
|
||||||
RoomServer = "roomserver"
|
RoomServer = "roomserver"
|
||||||
SigningKeyServer = "signingkeyserver"
|
SigningKeyServer = "signingkeyserver"
|
||||||
SyncAPI = "syncapi"
|
SyncAPI = "syncapi"
|
||||||
UserAPIAccounts = "userapi_accounts"
|
UserAPI = "userapi"
|
||||||
UserAPIDevices = "userapi_devices"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -35,7 +33,7 @@ var (
|
||||||
flags = flag.NewFlagSet("goose", flag.ExitOnError)
|
flags = flag.NewFlagSet("goose", flag.ExitOnError)
|
||||||
component = flags.String("component", "", "dendrite component name")
|
component = flags.String("component", "", "dendrite component name")
|
||||||
knownDBs = []string{
|
knownDBs = []string{
|
||||||
AppService, FederationSender, KeyServer, MediaAPI, RoomServer, SigningKeyServer, SyncAPI, UserAPIAccounts, UserAPIDevices,
|
AppService, FederationSender, KeyServer, MediaAPI, RoomServer, SigningKeyServer, SyncAPI, UserAPI,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -143,18 +141,14 @@ Commands:
|
||||||
|
|
||||||
func loadSQLiteDeltas(component string) {
|
func loadSQLiteDeltas(component string) {
|
||||||
switch component {
|
switch component {
|
||||||
case UserAPIAccounts:
|
case UserAPI:
|
||||||
slaccounts.LoadFromGoose()
|
slusers.LoadFromGoose()
|
||||||
case UserAPIDevices:
|
|
||||||
sldevices.LoadFromGoose()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func loadPostgresDeltas(component string) {
|
func loadPostgresDeltas(component string) {
|
||||||
switch component {
|
switch component {
|
||||||
case UserAPIAccounts:
|
case UserAPI:
|
||||||
pgaccounts.LoadFromGoose()
|
pgusers.LoadFromGoose()
|
||||||
case UserAPIDevices:
|
|
||||||
pgdevices.LoadFromGoose()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -95,7 +95,6 @@ func MakeConfig(configDir, kafkaURI, database, host string, startPort int) (*con
|
||||||
cfg.RoomServer.Database.ConnectionString = config.DataSource(database)
|
cfg.RoomServer.Database.ConnectionString = config.DataSource(database)
|
||||||
cfg.SyncAPI.Database.ConnectionString = config.DataSource(database)
|
cfg.SyncAPI.Database.ConnectionString = config.DataSource(database)
|
||||||
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(database)
|
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(database)
|
||||||
cfg.UserAPI.DeviceDatabase.ConnectionString = config.DataSource(database)
|
|
||||||
|
|
||||||
cfg.AppServiceAPI.InternalAPI.Listen = assignAddress()
|
cfg.AppServiceAPI.InternalAPI.Listen = assignAddress()
|
||||||
cfg.EDUServer.InternalAPI.Listen = assignAddress()
|
cfg.EDUServer.InternalAPI.Listen = assignAddress()
|
||||||
|
|
|
@ -198,7 +198,7 @@ func (a *KeyInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOne
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) {
|
func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) {
|
||||||
msgs, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil)
|
msgs, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
res.Error = &api.KeyError{
|
res.Error = &api.KeyError{
|
||||||
Err: fmt.Sprintf("failed to query DB for device keys: %s", err),
|
Err: fmt.Sprintf("failed to query DB for device keys: %s", err),
|
||||||
|
@ -244,7 +244,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
|
||||||
domain := string(serverName)
|
domain := string(serverName)
|
||||||
// query local devices
|
// query local devices
|
||||||
if serverName == a.ThisServer {
|
if serverName == a.ThisServer {
|
||||||
deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs)
|
deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
res.Error = &api.KeyError{
|
res.Error = &api.KeyError{
|
||||||
Err: fmt.Sprintf("failed to query local device keys: %s", err),
|
Err: fmt.Sprintf("failed to query local device keys: %s", err),
|
||||||
|
@ -525,7 +525,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer(
|
||||||
func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
|
func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
|
||||||
ctx context.Context, res *api.QueryKeysResponse, userID string, deviceIDs []string,
|
ctx context.Context, res *api.QueryKeysResponse, userID string, deviceIDs []string,
|
||||||
) error {
|
) error {
|
||||||
keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs)
|
keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false)
|
||||||
// if we can't query the db or there are fewer keys than requested, fetch from remote.
|
// if we can't query the db or there are fewer keys than requested, fetch from remote.
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("DeviceKeysForUser %s %v failed: %w", userID, deviceIDs, err)
|
return fmt.Errorf("DeviceKeysForUser %s %v failed: %w", userID, deviceIDs, err)
|
||||||
|
@ -554,10 +554,60 @@ func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
|
func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
|
||||||
|
// get a list of devices from the user API that actually exist, as
|
||||||
|
// we won't store keys for devices that don't exist
|
||||||
|
uapidevices := &userapi.QueryDevicesResponse{}
|
||||||
|
if err := a.UserAPI.QueryDevices(ctx, &userapi.QueryDevicesRequest{UserID: req.UserID}, uapidevices); err != nil {
|
||||||
|
res.Error = &api.KeyError{
|
||||||
|
Err: err.Error(),
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !uapidevices.UserExists {
|
||||||
|
res.Error = &api.KeyError{
|
||||||
|
Err: fmt.Sprintf("user %q does not exist", req.UserID),
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
existingDeviceMap := make(map[string]struct{}, len(uapidevices.Devices))
|
||||||
|
for _, key := range uapidevices.Devices {
|
||||||
|
existingDeviceMap[key.ID] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get all of the user existing device keys so we can check for changes.
|
||||||
|
existingKeys, err := a.DB.DeviceKeysForUser(ctx, req.UserID, nil, true)
|
||||||
|
if err != nil {
|
||||||
|
res.Error = &api.KeyError{
|
||||||
|
Err: fmt.Sprintf("failed to query existing device keys: %s", err.Error()),
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Work out whether we have device keys in the keyserver for devices that
|
||||||
|
// no longer exist in the user API. This is mostly an exercise to ensure
|
||||||
|
// that we keep some integrity between the two.
|
||||||
|
var toClean []gomatrixserverlib.KeyID
|
||||||
|
for _, k := range existingKeys {
|
||||||
|
if _, ok := existingDeviceMap[k.DeviceID]; !ok {
|
||||||
|
toClean = append(toClean, gomatrixserverlib.KeyID(k.DeviceID))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(toClean) > 0 {
|
||||||
|
if err = a.DB.DeleteDeviceKeys(ctx, req.UserID, toClean); err != nil {
|
||||||
|
res.Error = &api.KeyError{
|
||||||
|
Err: fmt.Sprintf("failed to clean device keys: %s", err.Error()),
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logrus.WithField("user_id", req.UserID).Infof("Cleaned up %d stale keyserver device key entries", len(toClean))
|
||||||
|
}
|
||||||
|
|
||||||
var keysToStore []api.DeviceMessage
|
var keysToStore []api.DeviceMessage
|
||||||
// assert that the user ID / device ID are not lying for each key
|
// assert that the user ID / device ID are not lying for each key
|
||||||
for _, key := range req.DeviceKeys {
|
for _, key := range req.DeviceKeys {
|
||||||
_, serverName, err := gomatrixserverlib.SplitID('@', key.UserID)
|
var serverName gomatrixserverlib.ServerName
|
||||||
|
_, serverName, err = gomatrixserverlib.SplitID('@', key.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue // ignore invalid users
|
continue // ignore invalid users
|
||||||
}
|
}
|
||||||
|
@ -568,6 +618,11 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per
|
||||||
keysToStore = append(keysToStore, key.WithStreamID(0))
|
keysToStore = append(keysToStore, key.WithStreamID(0))
|
||||||
continue // deleted keys don't need sanity checking
|
continue // deleted keys don't need sanity checking
|
||||||
}
|
}
|
||||||
|
// check that the device in question actually exists in the user
|
||||||
|
// API before we try and store a key for it
|
||||||
|
if _, ok := existingDeviceMap[key.DeviceID]; !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
gotUserID := gjson.GetBytes(key.KeyJSON, "user_id").Str
|
gotUserID := gjson.GetBytes(key.KeyJSON, "user_id").Str
|
||||||
gotDeviceID := gjson.GetBytes(key.KeyJSON, "device_id").Str
|
gotDeviceID := gjson.GetBytes(key.KeyJSON, "device_id").Str
|
||||||
if gotUserID == key.UserID && gotDeviceID == key.DeviceID {
|
if gotUserID == key.UserID && gotDeviceID == key.DeviceID {
|
||||||
|
@ -583,29 +638,12 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// get existing device keys so we can check for changes
|
|
||||||
existingKeys := make([]api.DeviceMessage, len(keysToStore))
|
|
||||||
for i := range keysToStore {
|
|
||||||
existingKeys[i] = api.DeviceMessage{
|
|
||||||
Type: api.TypeDeviceKeyUpdate,
|
|
||||||
DeviceKeys: &api.DeviceKeys{
|
|
||||||
UserID: keysToStore[i].UserID,
|
|
||||||
DeviceID: keysToStore[i].DeviceID,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err := a.DB.DeviceKeysJSON(ctx, existingKeys); err != nil {
|
|
||||||
res.Error = &api.KeyError{
|
|
||||||
Err: fmt.Sprintf("failed to query existing device keys: %s", err.Error()),
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if req.OnlyDisplayNameUpdates {
|
if req.OnlyDisplayNameUpdates {
|
||||||
// add the display name field from keysToStore into existingKeys
|
// add the display name field from keysToStore into existingKeys
|
||||||
keysToStore = appendDisplayNames(existingKeys, keysToStore)
|
keysToStore = appendDisplayNames(existingKeys, keysToStore)
|
||||||
}
|
}
|
||||||
// store the device keys and emit changes
|
// store the device keys and emit changes
|
||||||
err := a.DB.StoreLocalDeviceKeys(ctx, keysToStore)
|
err = a.DB.StoreLocalDeviceKeys(ctx, keysToStore)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
res.Error = &api.KeyError{
|
res.Error = &api.KeyError{
|
||||||
Err: fmt.Sprintf("failed to store device keys: %s", err.Error()),
|
Err: fmt.Sprintf("failed to store device keys: %s", err.Error()),
|
||||||
|
|
|
@ -53,7 +53,7 @@ type Database interface {
|
||||||
|
|
||||||
// DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected.
|
// DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected.
|
||||||
// If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice.
|
// If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice.
|
||||||
DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error)
|
DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error)
|
||||||
|
|
||||||
// DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying
|
// DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying
|
||||||
// cross-signing signatures relating to that device.
|
// cross-signing signatures relating to that device.
|
||||||
|
|
|
@ -56,6 +56,9 @@ const selectDeviceKeysSQL = "" +
|
||||||
const selectBatchDeviceKeysSQL = "" +
|
const selectBatchDeviceKeysSQL = "" +
|
||||||
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''"
|
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''"
|
||||||
|
|
||||||
|
const selectBatchDeviceKeysWithEmptiesSQL = "" +
|
||||||
|
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1"
|
||||||
|
|
||||||
const selectMaxStreamForUserSQL = "" +
|
const selectMaxStreamForUserSQL = "" +
|
||||||
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
|
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
|
||||||
|
|
||||||
|
@ -69,14 +72,15 @@ const deleteAllDeviceKeysSQL = "" +
|
||||||
"DELETE FROM keyserver_device_keys WHERE user_id=$1"
|
"DELETE FROM keyserver_device_keys WHERE user_id=$1"
|
||||||
|
|
||||||
type deviceKeysStatements struct {
|
type deviceKeysStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
upsertDeviceKeysStmt *sql.Stmt
|
upsertDeviceKeysStmt *sql.Stmt
|
||||||
selectDeviceKeysStmt *sql.Stmt
|
selectDeviceKeysStmt *sql.Stmt
|
||||||
selectBatchDeviceKeysStmt *sql.Stmt
|
selectBatchDeviceKeysStmt *sql.Stmt
|
||||||
selectMaxStreamForUserStmt *sql.Stmt
|
selectBatchDeviceKeysWithEmptiesStmt *sql.Stmt
|
||||||
countStreamIDsForUserStmt *sql.Stmt
|
selectMaxStreamForUserStmt *sql.Stmt
|
||||||
deleteDeviceKeysStmt *sql.Stmt
|
countStreamIDsForUserStmt *sql.Stmt
|
||||||
deleteAllDeviceKeysStmt *sql.Stmt
|
deleteDeviceKeysStmt *sql.Stmt
|
||||||
|
deleteAllDeviceKeysStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
||||||
|
@ -96,6 +100,9 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
||||||
if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
|
if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if s.selectBatchDeviceKeysWithEmptiesStmt, err = db.Prepare(selectBatchDeviceKeysWithEmptiesSQL); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
|
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -180,8 +187,14 @@ func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
|
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
|
||||||
rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID)
|
var stmt *sql.Stmt
|
||||||
|
if includeEmpty {
|
||||||
|
stmt = s.selectBatchDeviceKeysWithEmptiesStmt
|
||||||
|
} else {
|
||||||
|
stmt = s.selectBatchDeviceKeysStmt
|
||||||
|
}
|
||||||
|
rows, err := stmt.QueryContext(ctx, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -108,8 +108,8 @@ func (d *Database) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMe
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
|
func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
|
||||||
return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs)
|
return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs, includeEmpty)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) {
|
func (d *Database) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) {
|
||||||
|
|
|
@ -52,6 +52,9 @@ const selectDeviceKeysSQL = "" +
|
||||||
const selectBatchDeviceKeysSQL = "" +
|
const selectBatchDeviceKeysSQL = "" +
|
||||||
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''"
|
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''"
|
||||||
|
|
||||||
|
const selectBatchDeviceKeysWithEmptiesSQL = "" +
|
||||||
|
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1"
|
||||||
|
|
||||||
const selectMaxStreamForUserSQL = "" +
|
const selectMaxStreamForUserSQL = "" +
|
||||||
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
|
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
|
||||||
|
|
||||||
|
@ -65,13 +68,14 @@ const deleteAllDeviceKeysSQL = "" +
|
||||||
"DELETE FROM keyserver_device_keys WHERE user_id=$1"
|
"DELETE FROM keyserver_device_keys WHERE user_id=$1"
|
||||||
|
|
||||||
type deviceKeysStatements struct {
|
type deviceKeysStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
upsertDeviceKeysStmt *sql.Stmt
|
upsertDeviceKeysStmt *sql.Stmt
|
||||||
selectDeviceKeysStmt *sql.Stmt
|
selectDeviceKeysStmt *sql.Stmt
|
||||||
selectBatchDeviceKeysStmt *sql.Stmt
|
selectBatchDeviceKeysStmt *sql.Stmt
|
||||||
selectMaxStreamForUserStmt *sql.Stmt
|
selectBatchDeviceKeysWithEmptiesStmt *sql.Stmt
|
||||||
deleteDeviceKeysStmt *sql.Stmt
|
selectMaxStreamForUserStmt *sql.Stmt
|
||||||
deleteAllDeviceKeysStmt *sql.Stmt
|
deleteDeviceKeysStmt *sql.Stmt
|
||||||
|
deleteAllDeviceKeysStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
||||||
|
@ -91,6 +95,9 @@ func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
||||||
if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
|
if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if s.selectBatchDeviceKeysWithEmptiesStmt, err = db.Prepare(selectBatchDeviceKeysWithEmptiesSQL); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
|
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -113,12 +120,18 @@ func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
|
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
|
||||||
deviceIDMap := make(map[string]bool)
|
deviceIDMap := make(map[string]bool)
|
||||||
for _, d := range deviceIDs {
|
for _, d := range deviceIDs {
|
||||||
deviceIDMap[d] = true
|
deviceIDMap[d] = true
|
||||||
}
|
}
|
||||||
rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID)
|
var stmt *sql.Stmt
|
||||||
|
if includeEmpty {
|
||||||
|
stmt = s.selectBatchDeviceKeysWithEmptiesStmt
|
||||||
|
} else {
|
||||||
|
stmt = s.selectBatchDeviceKeysStmt
|
||||||
|
}
|
||||||
|
rows, err := stmt.QueryContext(ctx, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -173,7 +173,7 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Querying for device keys returns the latest stream IDs
|
// Querying for device keys returns the latest stream IDs
|
||||||
msgs, err = db.DeviceKeysForUser(ctx, alice, []string{"AAA", "another_device"})
|
msgs, err = db.DeviceKeysForUser(ctx, alice, []string{"AAA", "another_device"}, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("DeviceKeysForUser returned error: %s", err)
|
t.Fatalf("DeviceKeysForUser returned error: %s", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -38,7 +38,7 @@ type DeviceKeys interface {
|
||||||
InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error
|
InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error
|
||||||
SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error)
|
SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error)
|
||||||
CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error)
|
CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error)
|
||||||
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error)
|
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error)
|
||||||
DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error
|
DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error
|
||||||
DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error
|
DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error
|
||||||
}
|
}
|
||||||
|
|
|
@ -38,7 +38,7 @@ import (
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/setup/process"
|
"github.com/matrix-org/dendrite/setup/process"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
userdb "github.com/matrix-org/dendrite/userapi/storage"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
|
||||||
|
@ -273,8 +273,14 @@ func (b *BaseDendrite) KeyServerHTTPClient() keyserverAPI.KeyInternalAPI {
|
||||||
|
|
||||||
// CreateAccountsDB creates a new instance of the accounts database. Should only
|
// CreateAccountsDB creates a new instance of the accounts database. Should only
|
||||||
// be called once per component.
|
// be called once per component.
|
||||||
func (b *BaseDendrite) CreateAccountsDB() accounts.Database {
|
func (b *BaseDendrite) CreateAccountsDB() userdb.Database {
|
||||||
db, err := accounts.NewDatabase(&b.Cfg.UserAPI.AccountDatabase, b.Cfg.Global.ServerName, b.Cfg.UserAPI.BCryptCost, b.Cfg.UserAPI.OpenIDTokenLifetimeMS)
|
db, err := userdb.NewDatabase(
|
||||||
|
&b.Cfg.UserAPI.AccountDatabase,
|
||||||
|
b.Cfg.Global.ServerName,
|
||||||
|
b.Cfg.UserAPI.BCryptCost,
|
||||||
|
b.Cfg.UserAPI.OpenIDTokenLifetimeMS,
|
||||||
|
userapi.DefaultLoginTokenLifetime,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).Panicf("failed to connect to accounts db")
|
logrus.WithError(err).Panicf("failed to connect to accounts db")
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,9 +16,6 @@ type UserAPI struct {
|
||||||
// The Account database stores the login details and account information
|
// The Account database stores the login details and account information
|
||||||
// for local users. It is accessed by the UserAPI.
|
// for local users. It is accessed by the UserAPI.
|
||||||
AccountDatabase DatabaseOptions `yaml:"account_database"`
|
AccountDatabase DatabaseOptions `yaml:"account_database"`
|
||||||
// The Device database stores session information for the devices of logged
|
|
||||||
// in local users. It is accessed by the UserAPI.
|
|
||||||
DeviceDatabase DatabaseOptions `yaml:"device_database"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const DefaultOpenIDTokenLifetimeMS = 3600000 // 60 minutes
|
const DefaultOpenIDTokenLifetimeMS = 3600000 // 60 minutes
|
||||||
|
@ -27,10 +24,8 @@ func (c *UserAPI) Defaults(generate bool) {
|
||||||
c.InternalAPI.Listen = "http://localhost:7781"
|
c.InternalAPI.Listen = "http://localhost:7781"
|
||||||
c.InternalAPI.Connect = "http://localhost:7781"
|
c.InternalAPI.Connect = "http://localhost:7781"
|
||||||
c.AccountDatabase.Defaults(10)
|
c.AccountDatabase.Defaults(10)
|
||||||
c.DeviceDatabase.Defaults(10)
|
|
||||||
if generate {
|
if generate {
|
||||||
c.AccountDatabase.ConnectionString = "file:userapi_accounts.db"
|
c.AccountDatabase.ConnectionString = "file:userapi_accounts.db"
|
||||||
c.DeviceDatabase.ConnectionString = "file:userapi_devices.db"
|
|
||||||
}
|
}
|
||||||
c.BCryptCost = bcrypt.DefaultCost
|
c.BCryptCost = bcrypt.DefaultCost
|
||||||
c.OpenIDTokenLifetimeMS = DefaultOpenIDTokenLifetimeMS
|
c.OpenIDTokenLifetimeMS = DefaultOpenIDTokenLifetimeMS
|
||||||
|
@ -40,6 +35,5 @@ func (c *UserAPI) Verify(configErrs *ConfigErrors, isMonolith bool) {
|
||||||
checkURL(configErrs, "user_api.internal_api.listen", string(c.InternalAPI.Listen))
|
checkURL(configErrs, "user_api.internal_api.listen", string(c.InternalAPI.Listen))
|
||||||
checkURL(configErrs, "user_api.internal_api.connect", string(c.InternalAPI.Connect))
|
checkURL(configErrs, "user_api.internal_api.connect", string(c.InternalAPI.Connect))
|
||||||
checkNotEmpty(configErrs, "user_api.account_database.connection_string", string(c.AccountDatabase.ConnectionString))
|
checkNotEmpty(configErrs, "user_api.account_database.connection_string", string(c.AccountDatabase.ConnectionString))
|
||||||
checkNotEmpty(configErrs, "user_api.device_database.connection_string", string(c.DeviceDatabase.ConnectionString))
|
|
||||||
checkPositive(configErrs, "user_api.openid_token_lifetime_ms", c.OpenIDTokenLifetimeMS)
|
checkPositive(configErrs, "user_api.openid_token_lifetime_ms", c.OpenIDTokenLifetimeMS)
|
||||||
}
|
}
|
||||||
|
|
|
@ -30,7 +30,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/setup/process"
|
"github.com/matrix-org/dendrite/setup/process"
|
||||||
"github.com/matrix-org/dendrite/syncapi"
|
"github.com/matrix-org/dendrite/syncapi"
|
||||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
userdb "github.com/matrix-org/dendrite/userapi/storage"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -38,7 +38,7 @@ import (
|
||||||
// all components of Dendrite, for use in monolith mode.
|
// all components of Dendrite, for use in monolith mode.
|
||||||
type Monolith struct {
|
type Monolith struct {
|
||||||
Config *config.Dendrite
|
Config *config.Dendrite
|
||||||
AccountDB accounts.Database
|
AccountDB userdb.Database
|
||||||
KeyRing *gomatrixserverlib.KeyRing
|
KeyRing *gomatrixserverlib.KeyRing
|
||||||
Client *gomatrixserverlib.Client
|
Client *gomatrixserverlib.Client
|
||||||
FedClient *gomatrixserverlib.FederationClient
|
FedClient *gomatrixserverlib.FederationClient
|
||||||
|
|
|
@ -19,6 +19,13 @@ import (
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// DefaultLoginTokenLifetime determines how old a valid token may be.
|
||||||
|
//
|
||||||
|
// NOTSPEC: The current spec says "SHOULD be limited to around five
|
||||||
|
// seconds". Since TCP retries are on the order of 3 s, 5 s sounds very low.
|
||||||
|
// Synapse uses 2 min (https://github.com/matrix-org/synapse/blob/78d5f91de1a9baf4dbb0a794cb49a799f29f7a38/synapse/handlers/auth.py#L1323-L1325).
|
||||||
|
const DefaultLoginTokenLifetime = 2 * time.Minute
|
||||||
|
|
||||||
type LoginTokenInternalAPI interface {
|
type LoginTokenInternalAPI interface {
|
||||||
// PerformLoginTokenCreation creates a new login token and associates it with the provided data.
|
// PerformLoginTokenCreation creates a new login token and associates it with the provided data.
|
||||||
PerformLoginTokenCreation(ctx context.Context, req *PerformLoginTokenCreationRequest, res *PerformLoginTokenCreationResponse) error
|
PerformLoginTokenCreation(ctx context.Context, req *PerformLoginTokenCreationRequest, res *PerformLoginTokenCreationResponse) error
|
||||||
|
|
|
@ -31,13 +31,11 @@ import (
|
||||||
keyapi "github.com/matrix-org/dendrite/keyserver/api"
|
keyapi "github.com/matrix-org/dendrite/keyserver/api"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
"github.com/matrix-org/dendrite/userapi/storage"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/devices"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type UserInternalAPI struct {
|
type UserInternalAPI struct {
|
||||||
AccountDB accounts.Database
|
DB storage.Database
|
||||||
DeviceDB devices.Database
|
|
||||||
ServerName gomatrixserverlib.ServerName
|
ServerName gomatrixserverlib.ServerName
|
||||||
// AppServices is the list of all registered AS
|
// AppServices is the list of all registered AS
|
||||||
AppServices []config.ApplicationService
|
AppServices []config.ApplicationService
|
||||||
|
@ -55,11 +53,11 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc
|
||||||
if req.DataType == "" {
|
if req.DataType == "" {
|
||||||
return fmt.Errorf("data type must not be empty")
|
return fmt.Errorf("data type must not be empty")
|
||||||
}
|
}
|
||||||
return a.AccountDB.SaveAccountData(ctx, local, req.RoomID, req.DataType, req.AccountData)
|
return a.DB.SaveAccountData(ctx, local, req.RoomID, req.DataType, req.AccountData)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error {
|
func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error {
|
||||||
acc, err := a.AccountDB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType)
|
acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists
|
if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists
|
||||||
switch req.OnConflict {
|
switch req.OnConflict {
|
||||||
|
@ -89,7 +87,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = a.AccountDB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil {
|
if err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -99,7 +97,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error {
|
func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error {
|
||||||
if err := a.AccountDB.SetPassword(ctx, req.Localpart, req.Password); err != nil {
|
if err := a.DB.SetPassword(ctx, req.Localpart, req.Password); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
res.PasswordUpdated = true
|
res.PasswordUpdated = true
|
||||||
|
@ -112,7 +110,7 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe
|
||||||
"device_id": req.DeviceID,
|
"device_id": req.DeviceID,
|
||||||
"display_name": req.DeviceDisplayName,
|
"display_name": req.DeviceDisplayName,
|
||||||
}).Info("PerformDeviceCreation")
|
}).Info("PerformDeviceCreation")
|
||||||
dev, err := a.DeviceDB.CreateDevice(ctx, req.Localpart, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent)
|
dev, err := a.DB.CreateDevice(ctx, req.Localpart, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -137,12 +135,12 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe
|
||||||
deletedDeviceIDs := req.DeviceIDs
|
deletedDeviceIDs := req.DeviceIDs
|
||||||
if len(req.DeviceIDs) == 0 {
|
if len(req.DeviceIDs) == 0 {
|
||||||
var devices []api.Device
|
var devices []api.Device
|
||||||
devices, err = a.DeviceDB.RemoveAllDevices(ctx, local, req.ExceptDeviceID)
|
devices, err = a.DB.RemoveAllDevices(ctx, local, req.ExceptDeviceID)
|
||||||
for _, d := range devices {
|
for _, d := range devices {
|
||||||
deletedDeviceIDs = append(deletedDeviceIDs, d.ID)
|
deletedDeviceIDs = append(deletedDeviceIDs, d.ID)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
err = a.DeviceDB.RemoveDevices(ctx, local, req.DeviceIDs)
|
err = a.DB.RemoveDevices(ctx, local, req.DeviceIDs)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -196,7 +194,7 @@ func (a *UserInternalAPI) PerformLastSeenUpdate(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
|
return fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
|
||||||
}
|
}
|
||||||
if err := a.DeviceDB.UpdateDeviceLastSeen(ctx, localpart, req.DeviceID, req.RemoteAddr); err != nil {
|
if err := a.DB.UpdateDeviceLastSeen(ctx, localpart, req.DeviceID, req.RemoteAddr); err != nil {
|
||||||
return fmt.Errorf("a.DeviceDB.UpdateDeviceLastSeen: %w", err)
|
return fmt.Errorf("a.DeviceDB.UpdateDeviceLastSeen: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
@ -208,7 +206,7 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
|
||||||
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed")
|
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
dev, err := a.DeviceDB.GetDeviceByID(ctx, localpart, req.DeviceID)
|
dev, err := a.DB.GetDeviceByID(ctx, localpart, req.DeviceID)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
res.DeviceExists = false
|
res.DeviceExists = false
|
||||||
return nil
|
return nil
|
||||||
|
@ -223,7 +221,7 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
err = a.DeviceDB.UpdateDevice(ctx, localpart, req.DeviceID, req.DisplayName)
|
err = a.DB.UpdateDevice(ctx, localpart, req.DeviceID, req.DisplayName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed")
|
util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed")
|
||||||
return err
|
return err
|
||||||
|
@ -261,7 +259,7 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil
|
||||||
if domain != a.ServerName {
|
if domain != a.ServerName {
|
||||||
return fmt.Errorf("cannot query profile of remote users: got %s want %s", domain, a.ServerName)
|
return fmt.Errorf("cannot query profile of remote users: got %s want %s", domain, a.ServerName)
|
||||||
}
|
}
|
||||||
prof, err := a.AccountDB.GetProfileByLocalpart(ctx, local)
|
prof, err := a.DB.GetProfileByLocalpart(ctx, local)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return nil
|
return nil
|
||||||
|
@ -275,7 +273,7 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *UserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.QuerySearchProfilesRequest, res *api.QuerySearchProfilesResponse) error {
|
func (a *UserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.QuerySearchProfilesRequest, res *api.QuerySearchProfilesResponse) error {
|
||||||
profiles, err := a.AccountDB.SearchProfiles(ctx, req.SearchString, req.Limit)
|
profiles, err := a.DB.SearchProfiles(ctx, req.SearchString, req.Limit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -284,7 +282,7 @@ func (a *UserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.Quer
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *UserInternalAPI) QueryDeviceInfos(ctx context.Context, req *api.QueryDeviceInfosRequest, res *api.QueryDeviceInfosResponse) error {
|
func (a *UserInternalAPI) QueryDeviceInfos(ctx context.Context, req *api.QueryDeviceInfosRequest, res *api.QueryDeviceInfosResponse) error {
|
||||||
devices, err := a.DeviceDB.GetDevicesByID(ctx, req.DeviceIDs)
|
devices, err := a.DB.GetDevicesByID(ctx, req.DeviceIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -312,10 +310,11 @@ func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevice
|
||||||
if domain != a.ServerName {
|
if domain != a.ServerName {
|
||||||
return fmt.Errorf("cannot query devices of remote users: got %s want %s", domain, a.ServerName)
|
return fmt.Errorf("cannot query devices of remote users: got %s want %s", domain, a.ServerName)
|
||||||
}
|
}
|
||||||
devs, err := a.DeviceDB.GetDevicesByLocalpart(ctx, local)
|
devs, err := a.DB.GetDevicesByLocalpart(ctx, local)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
res.UserExists = true
|
||||||
res.Devices = devs
|
res.Devices = devs
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -330,7 +329,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc
|
||||||
}
|
}
|
||||||
if req.DataType != "" {
|
if req.DataType != "" {
|
||||||
var data json.RawMessage
|
var data json.RawMessage
|
||||||
data, err = a.AccountDB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType)
|
data, err = a.DB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -348,7 +347,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
global, rooms, err := a.AccountDB.GetAccountData(ctx, local)
|
global, rooms, err := a.DB.GetAccountData(ctx, local)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -367,7 +366,7 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
device, err := a.DeviceDB.GetDeviceByAccessToken(ctx, req.AccessToken)
|
device, err := a.DB.GetDeviceByAccessToken(ctx, req.AccessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return nil
|
return nil
|
||||||
|
@ -378,7 +377,7 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
acc, err := a.AccountDB.GetAccountByLocalpart(ctx, localPart)
|
acc, err := a.DB.GetAccountByLocalpart(ctx, localPart)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -419,7 +418,7 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe
|
||||||
|
|
||||||
if localpart != "" { // AS is masquerading as another user
|
if localpart != "" { // AS is masquerading as another user
|
||||||
// Verify that the user is registered
|
// Verify that the user is registered
|
||||||
account, err := a.AccountDB.GetAccountByLocalpart(ctx, localpart)
|
account, err := a.DB.GetAccountByLocalpart(ctx, localpart)
|
||||||
// Verify that the account exists and either appServiceID matches or
|
// Verify that the account exists and either appServiceID matches or
|
||||||
// it belongs to the appservice user namespaces
|
// it belongs to the appservice user namespaces
|
||||||
if err == nil && (account.AppServiceID == appService.ID || appService.IsInterestedInUserID(appServiceUserID)) {
|
if err == nil && (account.AppServiceID == appService.ID || appService.IsInterestedInUserID(appServiceUserID)) {
|
||||||
|
@ -437,7 +436,7 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe
|
||||||
|
|
||||||
// PerformAccountDeactivation deactivates the user's account, removing all ability for the user to login again.
|
// PerformAccountDeactivation deactivates the user's account, removing all ability for the user to login again.
|
||||||
func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *api.PerformAccountDeactivationRequest, res *api.PerformAccountDeactivationResponse) error {
|
func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *api.PerformAccountDeactivationRequest, res *api.PerformAccountDeactivationResponse) error {
|
||||||
err := a.AccountDB.DeactivateAccount(ctx, req.Localpart)
|
err := a.DB.DeactivateAccount(ctx, req.Localpart)
|
||||||
res.AccountDeactivated = err == nil
|
res.AccountDeactivated = err == nil
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -446,7 +445,7 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a
|
||||||
func (a *UserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, req *api.PerformOpenIDTokenCreationRequest, res *api.PerformOpenIDTokenCreationResponse) error {
|
func (a *UserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, req *api.PerformOpenIDTokenCreationRequest, res *api.PerformOpenIDTokenCreationResponse) error {
|
||||||
token := util.RandomString(24)
|
token := util.RandomString(24)
|
||||||
|
|
||||||
exp, err := a.AccountDB.CreateOpenIDToken(ctx, token, req.UserID)
|
exp, err := a.DB.CreateOpenIDToken(ctx, token, req.UserID)
|
||||||
|
|
||||||
res.Token = api.OpenIDToken{
|
res.Token = api.OpenIDToken{
|
||||||
Token: token,
|
Token: token,
|
||||||
|
@ -459,7 +458,7 @@ func (a *UserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, req *a
|
||||||
|
|
||||||
// QueryOpenIDToken validates that the OpenID token was issued for the user, the replying party uses this for validation
|
// QueryOpenIDToken validates that the OpenID token was issued for the user, the replying party uses this for validation
|
||||||
func (a *UserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOpenIDTokenRequest, res *api.QueryOpenIDTokenResponse) error {
|
func (a *UserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOpenIDTokenRequest, res *api.QueryOpenIDTokenResponse) error {
|
||||||
openIDTokenAttrs, err := a.AccountDB.GetOpenIDTokenAttributes(ctx, req.Token)
|
openIDTokenAttrs, err := a.DB.GetOpenIDTokenAttributes(ctx, req.Token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -481,7 +480,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
exists, err := a.AccountDB.DeleteKeyBackup(ctx, req.UserID, req.Version)
|
exists, err := a.DB.DeleteKeyBackup(ctx, req.UserID, req.Version)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
res.Error = fmt.Sprintf("failed to delete backup: %s", err)
|
res.Error = fmt.Sprintf("failed to delete backup: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -494,7 +493,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
|
||||||
}
|
}
|
||||||
// Create metadata
|
// Create metadata
|
||||||
if req.Version == "" {
|
if req.Version == "" {
|
||||||
version, err := a.AccountDB.CreateKeyBackup(ctx, req.UserID, req.Algorithm, req.AuthData)
|
version, err := a.DB.CreateKeyBackup(ctx, req.UserID, req.Algorithm, req.AuthData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
res.Error = fmt.Sprintf("failed to create backup: %s", err)
|
res.Error = fmt.Sprintf("failed to create backup: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -507,7 +506,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
|
||||||
}
|
}
|
||||||
// Update metadata
|
// Update metadata
|
||||||
if len(req.Keys.Rooms) == 0 {
|
if len(req.Keys.Rooms) == 0 {
|
||||||
err := a.AccountDB.UpdateKeyBackupAuthData(ctx, req.UserID, req.Version, req.AuthData)
|
err := a.DB.UpdateKeyBackupAuthData(ctx, req.UserID, req.Version, req.AuthData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
res.Error = fmt.Sprintf("failed to update backup: %s", err)
|
res.Error = fmt.Sprintf("failed to update backup: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -528,7 +527,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
|
||||||
|
|
||||||
func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) {
|
func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) {
|
||||||
// you can only upload keys for the CURRENT version
|
// you can only upload keys for the CURRENT version
|
||||||
version, _, _, _, deleted, err := a.AccountDB.GetKeyBackup(ctx, req.UserID, "")
|
version, _, _, _, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
res.Error = fmt.Sprintf("failed to query version: %s", err)
|
res.Error = fmt.Sprintf("failed to query version: %s", err)
|
||||||
return
|
return
|
||||||
|
@ -556,7 +555,7 @@ func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.Perform
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
count, etag, err := a.AccountDB.UpsertBackupKeys(ctx, version, req.UserID, uploads)
|
count, etag, err := a.DB.UpsertBackupKeys(ctx, version, req.UserID, uploads)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
res.Error = fmt.Sprintf("failed to upsert keys: %s", err)
|
res.Error = fmt.Sprintf("failed to upsert keys: %s", err)
|
||||||
return
|
return
|
||||||
|
@ -566,7 +565,7 @@ func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.Perform
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) {
|
func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) {
|
||||||
version, algorithm, authData, etag, deleted, err := a.AccountDB.GetKeyBackup(ctx, req.UserID, req.Version)
|
version, algorithm, authData, etag, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, req.Version)
|
||||||
res.Version = version
|
res.Version = version
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
|
@ -582,14 +581,14 @@ func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyB
|
||||||
res.Exists = !deleted
|
res.Exists = !deleted
|
||||||
|
|
||||||
if !req.ReturnKeys {
|
if !req.ReturnKeys {
|
||||||
res.Count, err = a.AccountDB.CountBackupKeys(ctx, version, req.UserID)
|
res.Count, err = a.DB.CountBackupKeys(ctx, version, req.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
res.Error = fmt.Sprintf("failed to count keys: %s", err)
|
res.Error = fmt.Sprintf("failed to count keys: %s", err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := a.AccountDB.GetBackupKeys(ctx, version, req.UserID, req.KeysForRoomID, req.KeysForSessionID)
|
result, err := a.DB.GetBackupKeys(ctx, version, req.UserID, req.KeysForRoomID, req.KeysForSessionID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
res.Error = fmt.Sprintf("failed to query keys: %s", err)
|
res.Error = fmt.Sprintf("failed to query keys: %s", err)
|
||||||
return
|
return
|
||||||
|
|
|
@ -34,7 +34,7 @@ func (a *UserInternalAPI) PerformLoginTokenCreation(ctx context.Context, req *ap
|
||||||
if domain != a.ServerName {
|
if domain != a.ServerName {
|
||||||
return fmt.Errorf("cannot create a login token for a remote user: got %s want %s", domain, a.ServerName)
|
return fmt.Errorf("cannot create a login token for a remote user: got %s want %s", domain, a.ServerName)
|
||||||
}
|
}
|
||||||
tokenMeta, err := a.DeviceDB.CreateLoginToken(ctx, &req.Data)
|
tokenMeta, err := a.DB.CreateLoginToken(ctx, &req.Data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -45,13 +45,13 @@ func (a *UserInternalAPI) PerformLoginTokenCreation(ctx context.Context, req *ap
|
||||||
// PerformLoginTokenDeletion ensures the token doesn't exist.
|
// PerformLoginTokenDeletion ensures the token doesn't exist.
|
||||||
func (a *UserInternalAPI) PerformLoginTokenDeletion(ctx context.Context, req *api.PerformLoginTokenDeletionRequest, res *api.PerformLoginTokenDeletionResponse) error {
|
func (a *UserInternalAPI) PerformLoginTokenDeletion(ctx context.Context, req *api.PerformLoginTokenDeletionRequest, res *api.PerformLoginTokenDeletionResponse) error {
|
||||||
util.GetLogger(ctx).Info("PerformLoginTokenDeletion")
|
util.GetLogger(ctx).Info("PerformLoginTokenDeletion")
|
||||||
return a.DeviceDB.RemoveLoginToken(ctx, req.Token)
|
return a.DB.RemoveLoginToken(ctx, req.Token)
|
||||||
}
|
}
|
||||||
|
|
||||||
// QueryLoginToken returns the data associated with a login token. If
|
// QueryLoginToken returns the data associated with a login token. If
|
||||||
// the token is not valid, success is returned, but res.Data == nil.
|
// the token is not valid, success is returned, but res.Data == nil.
|
||||||
func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLoginTokenRequest, res *api.QueryLoginTokenResponse) error {
|
func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLoginTokenRequest, res *api.QueryLoginTokenResponse) error {
|
||||||
tokenData, err := a.DeviceDB.GetLoginTokenDataByToken(ctx, req.Token)
|
tokenData, err := a.DB.GetLoginTokenDataByToken(ctx, req.Token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
res.Data = nil
|
res.Data = nil
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
|
@ -66,7 +66,7 @@ func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLog
|
||||||
if domain != a.ServerName {
|
if domain != a.ServerName {
|
||||||
return fmt.Errorf("cannot return a login token for a remote user: got %s want %s", domain, a.ServerName)
|
return fmt.Errorf("cannot return a login token for a remote user: got %s want %s", domain, a.ServerName)
|
||||||
}
|
}
|
||||||
if _, err := a.AccountDB.GetAccountByLocalpart(ctx, localpart); err != nil {
|
if _, err := a.DB.GetAccountByLocalpart(ctx, localpart); err != nil {
|
||||||
res.Data = nil
|
res.Data = nil
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -1,52 +0,0 @@
|
||||||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package devices
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Database interface {
|
|
||||||
GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error)
|
|
||||||
GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error)
|
|
||||||
GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error)
|
|
||||||
GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error)
|
|
||||||
// CreateDevice makes a new device associated with the given user ID localpart.
|
|
||||||
// If there is already a device with the same device ID for this user, that access token will be revoked
|
|
||||||
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
|
|
||||||
// an error will be returned.
|
|
||||||
// If no device ID is given one is generated.
|
|
||||||
// Returns the device on success.
|
|
||||||
CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error)
|
|
||||||
UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error
|
|
||||||
UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error
|
|
||||||
RemoveDevice(ctx context.Context, deviceID, localpart string) error
|
|
||||||
RemoveDevices(ctx context.Context, localpart string, devices []string) error
|
|
||||||
// RemoveAllDevices deleted all devices for this user. Returns the devices deleted.
|
|
||||||
RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error)
|
|
||||||
|
|
||||||
// CreateLoginToken generates a token, stores and returns it. The lifetime is
|
|
||||||
// determined by the loginTokenLifetime given to the Database constructor.
|
|
||||||
CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error)
|
|
||||||
|
|
||||||
// RemoveLoginToken removes the named token (and may clean up other expired tokens).
|
|
||||||
RemoveLoginToken(ctx context.Context, token string) error
|
|
||||||
|
|
||||||
// GetLoginTokenDataByToken returns the data associated with the given token.
|
|
||||||
// May return sql.ErrNoRows.
|
|
||||||
GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error)
|
|
||||||
}
|
|
|
@ -1,270 +0,0 @@
|
||||||
// Copyright 2017 Vector Creations Ltd
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package postgres
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto/rand"
|
|
||||||
"database/sql"
|
|
||||||
"encoding/base64"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/devices/postgres/deltas"
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
// The length of generated device IDs
|
|
||||||
deviceIDByteLength = 6
|
|
||||||
loginTokenByteLength = 32
|
|
||||||
)
|
|
||||||
|
|
||||||
// Database represents a device database.
|
|
||||||
type Database struct {
|
|
||||||
db *sql.DB
|
|
||||||
devices devicesStatements
|
|
||||||
loginTokens loginTokenStatements
|
|
||||||
loginTokenLifetime time.Duration
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewDatabase creates a new device database
|
|
||||||
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, loginTokenLifetime time.Duration) (*Database, error) {
|
|
||||||
db, err := sqlutil.Open(dbProperties)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
var d devicesStatements
|
|
||||||
var lt loginTokenStatements
|
|
||||||
|
|
||||||
// Create tables before executing migrations so we don't fail if the table is missing,
|
|
||||||
// and THEN prepare statements so we don't fail due to referencing new columns
|
|
||||||
if err = d.execSchema(db); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err = lt.execSchema(db); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
m := sqlutil.NewMigrations()
|
|
||||||
deltas.LoadLastSeenTSIP(m)
|
|
||||||
if err = m.RunDeltas(db, dbProperties); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = d.prepare(db, serverName); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err = lt.prepare(db); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &Database{db, d, lt, loginTokenLifetime}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetDeviceByAccessToken returns the device matching the given access token.
|
|
||||||
// Returns sql.ErrNoRows if no matching device was found.
|
|
||||||
func (d *Database) GetDeviceByAccessToken(
|
|
||||||
ctx context.Context, token string,
|
|
||||||
) (*api.Device, error) {
|
|
||||||
return d.devices.selectDeviceByToken(ctx, token)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetDeviceByID returns the device matching the given ID.
|
|
||||||
// Returns sql.ErrNoRows if no matching device was found.
|
|
||||||
func (d *Database) GetDeviceByID(
|
|
||||||
ctx context.Context, localpart, deviceID string,
|
|
||||||
) (*api.Device, error) {
|
|
||||||
return d.devices.selectDeviceByID(ctx, localpart, deviceID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetDevicesByLocalpart returns the devices matching the given localpart.
|
|
||||||
func (d *Database) GetDevicesByLocalpart(
|
|
||||||
ctx context.Context, localpart string,
|
|
||||||
) ([]api.Device, error) {
|
|
||||||
return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
|
|
||||||
return d.devices.selectDevicesByID(ctx, deviceIDs)
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateDevice makes a new device associated with the given user ID localpart.
|
|
||||||
// If there is already a device with the same device ID for this user, that access token will be revoked
|
|
||||||
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
|
|
||||||
// an error will be returned.
|
|
||||||
// If no device ID is given one is generated.
|
|
||||||
// Returns the device on success.
|
|
||||||
func (d *Database) CreateDevice(
|
|
||||||
ctx context.Context, localpart string, deviceID *string, accessToken string,
|
|
||||||
displayName *string, ipAddr, userAgent string,
|
|
||||||
) (dev *api.Device, returnErr error) {
|
|
||||||
if deviceID != nil {
|
|
||||||
returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
||||||
var err error
|
|
||||||
// Revoke existing tokens for this device
|
|
||||||
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
// We generate device IDs in a loop in case its already taken.
|
|
||||||
// We cap this at going round 5 times to ensure we don't spin forever
|
|
||||||
var newDeviceID string
|
|
||||||
for i := 1; i <= 5; i++ {
|
|
||||||
newDeviceID, returnErr = generateDeviceID()
|
|
||||||
if returnErr != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
||||||
var err error
|
|
||||||
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
if returnErr == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// generateDeviceID creates a new device id. Returns an error if failed to generate
|
|
||||||
// random bytes.
|
|
||||||
func generateDeviceID() (string, error) {
|
|
||||||
b := make([]byte, deviceIDByteLength)
|
|
||||||
_, err := rand.Read(b)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
// url-safe no padding
|
|
||||||
return base64.RawURLEncoding.EncodeToString(b), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateDevice updates the given device with the display name.
|
|
||||||
// Returns SQL error if there are problems and nil on success.
|
|
||||||
func (d *Database) UpdateDevice(
|
|
||||||
ctx context.Context, localpart, deviceID string, displayName *string,
|
|
||||||
) error {
|
|
||||||
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
||||||
return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveDevice revokes a device by deleting the entry in the database
|
|
||||||
// matching with the given device ID and user ID localpart.
|
|
||||||
// If the device doesn't exist, it will not return an error
|
|
||||||
// If something went wrong during the deletion, it will return the SQL error.
|
|
||||||
func (d *Database) RemoveDevice(
|
|
||||||
ctx context.Context, deviceID, localpart string,
|
|
||||||
) error {
|
|
||||||
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
||||||
if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveDevices revokes one or more devices by deleting the entry in the database
|
|
||||||
// matching with the given device IDs and user ID localpart.
|
|
||||||
// If the devices don't exist, it will not return an error
|
|
||||||
// If something went wrong during the deletion, it will return the SQL error.
|
|
||||||
func (d *Database) RemoveDevices(
|
|
||||||
ctx context.Context, localpart string, devices []string,
|
|
||||||
) error {
|
|
||||||
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
||||||
if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveAllDevices revokes devices by deleting the entry in the
|
|
||||||
// database matching the given user ID localpart.
|
|
||||||
// If something went wrong during the deletion, it will return the SQL error.
|
|
||||||
func (d *Database) RemoveAllDevices(
|
|
||||||
ctx context.Context, localpart, exceptDeviceID string,
|
|
||||||
) (devices []api.Device, err error) {
|
|
||||||
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
||||||
devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address
|
|
||||||
func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error {
|
|
||||||
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
||||||
return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateLoginToken generates a token, stores and returns it. The lifetime is
|
|
||||||
// determined by the loginTokenLifetime given to the Database constructor.
|
|
||||||
func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) {
|
|
||||||
tok, err := generateLoginToken()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
meta := &api.LoginTokenMetadata{
|
|
||||||
Token: tok,
|
|
||||||
Expiration: time.Now().Add(d.loginTokenLifetime),
|
|
||||||
}
|
|
||||||
|
|
||||||
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
||||||
return d.loginTokens.insert(ctx, txn, meta, data)
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return meta, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func generateLoginToken() (string, error) {
|
|
||||||
b := make([]byte, loginTokenByteLength)
|
|
||||||
_, err := rand.Read(b)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return base64.RawURLEncoding.EncodeToString(b), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveLoginToken removes the named token (and may clean up other expired tokens).
|
|
||||||
func (d *Database) RemoveLoginToken(ctx context.Context, token string) error {
|
|
||||||
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
||||||
return d.loginTokens.deleteByToken(ctx, txn, token)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetLoginTokenDataByToken returns the data associated with the given token.
|
|
||||||
// May return sql.ErrNoRows.
|
|
||||||
func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
|
|
||||||
return d.loginTokens.selectByToken(ctx, token)
|
|
||||||
}
|
|
|
@ -1,271 +0,0 @@
|
||||||
// Copyright 2017 Vector Creations Ltd
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package sqlite3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto/rand"
|
|
||||||
"database/sql"
|
|
||||||
"encoding/base64"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3/deltas"
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
// The length of generated device IDs
|
|
||||||
deviceIDByteLength = 6
|
|
||||||
|
|
||||||
loginTokenByteLength = 32
|
|
||||||
)
|
|
||||||
|
|
||||||
// Database represents a device database.
|
|
||||||
type Database struct {
|
|
||||||
db *sql.DB
|
|
||||||
writer sqlutil.Writer
|
|
||||||
devices devicesStatements
|
|
||||||
loginTokens loginTokenStatements
|
|
||||||
loginTokenLifetime time.Duration
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewDatabase creates a new device database
|
|
||||||
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, loginTokenLifetime time.Duration) (*Database, error) {
|
|
||||||
db, err := sqlutil.Open(dbProperties)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
writer := sqlutil.NewExclusiveWriter()
|
|
||||||
var d devicesStatements
|
|
||||||
var lt loginTokenStatements
|
|
||||||
|
|
||||||
// Create tables before executing migrations so we don't fail if the table is missing,
|
|
||||||
// and THEN prepare statements so we don't fail due to referencing new columns
|
|
||||||
if err = d.execSchema(db); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err = lt.execSchema(db); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
m := sqlutil.NewMigrations()
|
|
||||||
deltas.LoadLastSeenTSIP(m)
|
|
||||||
if err = m.RunDeltas(db, dbProperties); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err = d.prepare(db, writer, serverName); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err = lt.prepare(db); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &Database{db, writer, d, lt, loginTokenLifetime}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetDeviceByAccessToken returns the device matching the given access token.
|
|
||||||
// Returns sql.ErrNoRows if no matching device was found.
|
|
||||||
func (d *Database) GetDeviceByAccessToken(
|
|
||||||
ctx context.Context, token string,
|
|
||||||
) (*api.Device, error) {
|
|
||||||
return d.devices.selectDeviceByToken(ctx, token)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetDeviceByID returns the device matching the given ID.
|
|
||||||
// Returns sql.ErrNoRows if no matching device was found.
|
|
||||||
func (d *Database) GetDeviceByID(
|
|
||||||
ctx context.Context, localpart, deviceID string,
|
|
||||||
) (*api.Device, error) {
|
|
||||||
return d.devices.selectDeviceByID(ctx, localpart, deviceID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetDevicesByLocalpart returns the devices matching the given localpart.
|
|
||||||
func (d *Database) GetDevicesByLocalpart(
|
|
||||||
ctx context.Context, localpart string,
|
|
||||||
) ([]api.Device, error) {
|
|
||||||
return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
|
|
||||||
return d.devices.selectDevicesByID(ctx, deviceIDs)
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateDevice makes a new device associated with the given user ID localpart.
|
|
||||||
// If there is already a device with the same device ID for this user, that access token will be revoked
|
|
||||||
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
|
|
||||||
// an error will be returned.
|
|
||||||
// If no device ID is given one is generated.
|
|
||||||
// Returns the device on success.
|
|
||||||
func (d *Database) CreateDevice(
|
|
||||||
ctx context.Context, localpart string, deviceID *string, accessToken string,
|
|
||||||
displayName *string, ipAddr, userAgent string,
|
|
||||||
) (dev *api.Device, returnErr error) {
|
|
||||||
if deviceID != nil {
|
|
||||||
returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
|
||||||
var err error
|
|
||||||
// Revoke existing tokens for this device
|
|
||||||
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
// We generate device IDs in a loop in case its already taken.
|
|
||||||
// We cap this at going round 5 times to ensure we don't spin forever
|
|
||||||
var newDeviceID string
|
|
||||||
for i := 1; i <= 5; i++ {
|
|
||||||
newDeviceID, returnErr = generateDeviceID()
|
|
||||||
if returnErr != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
|
||||||
var err error
|
|
||||||
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
if returnErr == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// generateDeviceID creates a new device id. Returns an error if failed to generate
|
|
||||||
// random bytes.
|
|
||||||
func generateDeviceID() (string, error) {
|
|
||||||
b := make([]byte, deviceIDByteLength)
|
|
||||||
_, err := rand.Read(b)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
// url-safe no padding
|
|
||||||
return base64.RawURLEncoding.EncodeToString(b), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateDevice updates the given device with the display name.
|
|
||||||
// Returns SQL error if there are problems and nil on success.
|
|
||||||
func (d *Database) UpdateDevice(
|
|
||||||
ctx context.Context, localpart, deviceID string, displayName *string,
|
|
||||||
) error {
|
|
||||||
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
|
||||||
return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveDevice revokes a device by deleting the entry in the database
|
|
||||||
// matching with the given device ID and user ID localpart.
|
|
||||||
// If the device doesn't exist, it will not return an error
|
|
||||||
// If something went wrong during the deletion, it will return the SQL error.
|
|
||||||
func (d *Database) RemoveDevice(
|
|
||||||
ctx context.Context, deviceID, localpart string,
|
|
||||||
) error {
|
|
||||||
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
|
||||||
if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveDevices revokes one or more devices by deleting the entry in the database
|
|
||||||
// matching with the given device IDs and user ID localpart.
|
|
||||||
// If the devices don't exist, it will not return an error
|
|
||||||
// If something went wrong during the deletion, it will return the SQL error.
|
|
||||||
func (d *Database) RemoveDevices(
|
|
||||||
ctx context.Context, localpart string, devices []string,
|
|
||||||
) error {
|
|
||||||
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
|
||||||
if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveAllDevices revokes devices by deleting the entry in the
|
|
||||||
// database matching the given user ID localpart.
|
|
||||||
// If something went wrong during the deletion, it will return the SQL error.
|
|
||||||
func (d *Database) RemoveAllDevices(
|
|
||||||
ctx context.Context, localpart, exceptDeviceID string,
|
|
||||||
) (devices []api.Device, err error) {
|
|
||||||
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
|
||||||
devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address
|
|
||||||
func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error {
|
|
||||||
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
|
||||||
return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateLoginToken generates a token, stores and returns it. The lifetime is
|
|
||||||
// determined by the loginTokenLifetime given to the Database constructor.
|
|
||||||
func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) {
|
|
||||||
tok, err := generateLoginToken()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
meta := &api.LoginTokenMetadata{
|
|
||||||
Token: tok,
|
|
||||||
Expiration: time.Now().Add(d.loginTokenLifetime),
|
|
||||||
}
|
|
||||||
|
|
||||||
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
|
||||||
return d.loginTokens.insert(ctx, txn, meta, data)
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return meta, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func generateLoginToken() (string, error) {
|
|
||||||
b := make([]byte, loginTokenByteLength)
|
|
||||||
_, err := rand.Read(b)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return base64.RawURLEncoding.EncodeToString(b), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveLoginToken removes the named token (and may clean up other expired tokens).
|
|
||||||
func (d *Database) RemoveLoginToken(ctx context.Context, token string) error {
|
|
||||||
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
|
||||||
return d.loginTokens.deleteByToken(ctx, txn, token)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetLoginTokenDataByToken returns the data associated with the given token.
|
|
||||||
// May return sql.ErrNoRows.
|
|
||||||
func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
|
|
||||||
return d.loginTokens.selectByToken(ctx, token)
|
|
||||||
}
|
|
|
@ -1,42 +0,0 @@
|
||||||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
//go:build !wasm
|
|
||||||
// +build !wasm
|
|
||||||
|
|
||||||
package devices
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/devices/postgres"
|
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3"
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
|
||||||
)
|
|
||||||
|
|
||||||
// NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
|
|
||||||
// and sets postgres connection parameters. loginTokenLifetime determines how long a
|
|
||||||
// login token from CreateLoginToken is valid.
|
|
||||||
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, loginTokenLifetime time.Duration) (Database, error) {
|
|
||||||
switch {
|
|
||||||
case dbProperties.ConnectionString.IsSQLite():
|
|
||||||
return sqlite3.NewDatabase(dbProperties, serverName, loginTokenLifetime)
|
|
||||||
case dbProperties.ConnectionString.IsPostgres():
|
|
||||||
return postgres.NewDatabase(dbProperties, serverName, loginTokenLifetime)
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("unexpected database type")
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,39 +0,0 @@
|
||||||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package devices
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3"
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
|
||||||
)
|
|
||||||
|
|
||||||
func NewDatabase(
|
|
||||||
dbProperties *config.DatabaseOptions,
|
|
||||||
serverName gomatrixserverlib.ServerName,
|
|
||||||
loginTokenLifetime time.Duration,
|
|
||||||
) (Database, error) {
|
|
||||||
switch {
|
|
||||||
case dbProperties.ConnectionString.IsSQLite():
|
|
||||||
return sqlite3.NewDatabase(dbProperties, serverName, loginTokenLifetime)
|
|
||||||
case dbProperties.ConnectionString.IsPostgres():
|
|
||||||
return nil, fmt.Errorf("can't use Postgres implementation")
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("unexpected database type")
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -12,7 +12,7 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
package accounts
|
package storage
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
@ -60,6 +60,35 @@ type Database interface {
|
||||||
UpsertBackupKeys(ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession) (count int64, etag string, err error)
|
UpsertBackupKeys(ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession) (count int64, etag string, err error)
|
||||||
GetBackupKeys(ctx context.Context, version, userID, filterRoomID, filterSessionID string) (result map[string]map[string]api.KeyBackupSession, err error)
|
GetBackupKeys(ctx context.Context, version, userID, filterRoomID, filterSessionID string) (result map[string]map[string]api.KeyBackupSession, err error)
|
||||||
CountBackupKeys(ctx context.Context, version, userID string) (count int64, err error)
|
CountBackupKeys(ctx context.Context, version, userID string) (count int64, err error)
|
||||||
|
|
||||||
|
GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error)
|
||||||
|
GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error)
|
||||||
|
GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error)
|
||||||
|
GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error)
|
||||||
|
// CreateDevice makes a new device associated with the given user ID localpart.
|
||||||
|
// If there is already a device with the same device ID for this user, that access token will be revoked
|
||||||
|
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
|
||||||
|
// an error will be returned.
|
||||||
|
// If no device ID is given one is generated.
|
||||||
|
// Returns the device on success.
|
||||||
|
CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error)
|
||||||
|
UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error
|
||||||
|
UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error
|
||||||
|
RemoveDevice(ctx context.Context, deviceID, localpart string) error
|
||||||
|
RemoveDevices(ctx context.Context, localpart string, devices []string) error
|
||||||
|
// RemoveAllDevices deleted all devices for this user. Returns the devices deleted.
|
||||||
|
RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error)
|
||||||
|
|
||||||
|
// CreateLoginToken generates a token, stores and returns it. The lifetime is
|
||||||
|
// determined by the loginTokenLifetime given to the Database constructor.
|
||||||
|
CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error)
|
||||||
|
|
||||||
|
// RemoveLoginToken removes the named token (and may clean up other expired tokens).
|
||||||
|
RemoveLoginToken(ctx context.Context, token string) error
|
||||||
|
|
||||||
|
// GetLoginTokenDataByToken returns the data associated with the given token.
|
||||||
|
// May return sql.ErrNoRows.
|
||||||
|
GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Err3PIDInUse is the error returned when trying to save an association involving
|
// Err3PIDInUse is the error returned when trying to save an association involving
|
|
@ -5,13 +5,8 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/pressly/goose"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func LoadFromGoose() {
|
|
||||||
goose.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP)
|
|
||||||
}
|
|
||||||
|
|
||||||
func LoadLastSeenTSIP(m *sqlutil.Migrations) {
|
func LoadLastSeenTSIP(m *sqlutil.Migrations) {
|
||||||
m.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP)
|
m.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP)
|
||||||
}
|
}
|
|
@ -117,6 +117,9 @@ func (s *devicesStatements) execSchema(db *sql.DB) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
|
func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
|
||||||
|
if err = s.execSchema(db); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil {
|
if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
|
@ -51,6 +51,9 @@ CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_exp
|
||||||
|
|
||||||
// prepare runs statement preparation.
|
// prepare runs statement preparation.
|
||||||
func (s *loginTokenStatements) prepare(db *sql.DB) error {
|
func (s *loginTokenStatements) prepare(db *sql.DB) error {
|
||||||
|
if err := s.execSchema(db); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
return sqlutil.StatementList{
|
return sqlutil.StatementList{
|
||||||
{&s.insertStmt, "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"},
|
{&s.insertStmt, "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"},
|
||||||
{&s.deleteStmt, "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"},
|
{&s.deleteStmt, "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"},
|
|
@ -16,7 +16,9 @@ package postgres
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -30,7 +32,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/accounts/postgres/deltas"
|
"github.com/matrix-org/dendrite/userapi/storage/postgres/deltas"
|
||||||
|
|
||||||
// Import the postgres database driver.
|
// Import the postgres database driver.
|
||||||
_ "github.com/lib/pq"
|
_ "github.com/lib/pq"
|
||||||
|
@ -47,14 +49,23 @@ type Database struct {
|
||||||
threepids threepidStatements
|
threepids threepidStatements
|
||||||
openIDTokens tokenStatements
|
openIDTokens tokenStatements
|
||||||
keyBackupVersions keyBackupVersionStatements
|
keyBackupVersions keyBackupVersionStatements
|
||||||
|
devices devicesStatements
|
||||||
|
loginTokens loginTokenStatements
|
||||||
|
loginTokenLifetime time.Duration
|
||||||
keyBackups keyBackupStatements
|
keyBackups keyBackupStatements
|
||||||
serverName gomatrixserverlib.ServerName
|
serverName gomatrixserverlib.ServerName
|
||||||
bcryptCost int
|
bcryptCost int
|
||||||
openIDTokenLifetimeMS int64
|
openIDTokenLifetimeMS int64
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
// The length of generated device IDs
|
||||||
|
deviceIDByteLength = 6
|
||||||
|
loginTokenByteLength = 32
|
||||||
|
)
|
||||||
|
|
||||||
// NewDatabase creates a new accounts and profiles database
|
// NewDatabase creates a new accounts and profiles database
|
||||||
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) {
|
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration) (*Database, error) {
|
||||||
db, err := sqlutil.Open(dbProperties)
|
db, err := sqlutil.Open(dbProperties)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -63,6 +74,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
|
||||||
serverName: serverName,
|
serverName: serverName,
|
||||||
db: db,
|
db: db,
|
||||||
writer: sqlutil.NewDummyWriter(),
|
writer: sqlutil.NewDummyWriter(),
|
||||||
|
loginTokenLifetime: loginTokenLifetime,
|
||||||
bcryptCost: bcryptCost,
|
bcryptCost: bcryptCost,
|
||||||
openIDTokenLifetimeMS: openIDTokenLifetimeMS,
|
openIDTokenLifetimeMS: openIDTokenLifetimeMS,
|
||||||
}
|
}
|
||||||
|
@ -74,6 +86,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
|
||||||
}
|
}
|
||||||
m := sqlutil.NewMigrations()
|
m := sqlutil.NewMigrations()
|
||||||
deltas.LoadIsActive(m)
|
deltas.LoadIsActive(m)
|
||||||
|
//deltas.LoadLastSeenTSIP(m)
|
||||||
deltas.LoadAddAccountType(m)
|
deltas.LoadAddAccountType(m)
|
||||||
if err = m.RunDeltas(db, dbProperties); err != nil {
|
if err = m.RunDeltas(db, dbProperties); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -103,6 +116,12 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
|
||||||
if err = d.keyBackups.prepare(db); err != nil {
|
if err = d.keyBackups.prepare(db); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if err = d.devices.prepare(db, serverName); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err = d.loginTokens.prepare(db); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
return d, nil
|
return d, nil
|
||||||
}
|
}
|
||||||
|
@ -515,3 +534,196 @@ func (d *Database) UpsertBackupKeys(
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetDeviceByAccessToken returns the device matching the given access token.
|
||||||
|
// Returns sql.ErrNoRows if no matching device was found.
|
||||||
|
func (d *Database) GetDeviceByAccessToken(
|
||||||
|
ctx context.Context, token string,
|
||||||
|
) (*api.Device, error) {
|
||||||
|
return d.devices.selectDeviceByToken(ctx, token)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDeviceByID returns the device matching the given ID.
|
||||||
|
// Returns sql.ErrNoRows if no matching device was found.
|
||||||
|
func (d *Database) GetDeviceByID(
|
||||||
|
ctx context.Context, localpart, deviceID string,
|
||||||
|
) (*api.Device, error) {
|
||||||
|
return d.devices.selectDeviceByID(ctx, localpart, deviceID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDevicesByLocalpart returns the devices matching the given localpart.
|
||||||
|
func (d *Database) GetDevicesByLocalpart(
|
||||||
|
ctx context.Context, localpart string,
|
||||||
|
) ([]api.Device, error) {
|
||||||
|
return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
|
||||||
|
return d.devices.selectDevicesByID(ctx, deviceIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateDevice makes a new device associated with the given user ID localpart.
|
||||||
|
// If there is already a device with the same device ID for this user, that access token will be revoked
|
||||||
|
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
|
||||||
|
// an error will be returned.
|
||||||
|
// If no device ID is given one is generated.
|
||||||
|
// Returns the device on success.
|
||||||
|
func (d *Database) CreateDevice(
|
||||||
|
ctx context.Context, localpart string, deviceID *string, accessToken string,
|
||||||
|
displayName *string, ipAddr, userAgent string,
|
||||||
|
) (dev *api.Device, returnErr error) {
|
||||||
|
if deviceID != nil {
|
||||||
|
returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
|
var err error
|
||||||
|
// Revoke existing tokens for this device
|
||||||
|
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
// We generate device IDs in a loop in case its already taken.
|
||||||
|
// We cap this at going round 5 times to ensure we don't spin forever
|
||||||
|
var newDeviceID string
|
||||||
|
for i := 1; i <= 5; i++ {
|
||||||
|
newDeviceID, returnErr = generateDeviceID()
|
||||||
|
if returnErr != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
|
var err error
|
||||||
|
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
if returnErr == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateDeviceID creates a new device id. Returns an error if failed to generate
|
||||||
|
// random bytes.
|
||||||
|
func generateDeviceID() (string, error) {
|
||||||
|
b := make([]byte, deviceIDByteLength)
|
||||||
|
_, err := rand.Read(b)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
// url-safe no padding
|
||||||
|
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateDevice updates the given device with the display name.
|
||||||
|
// Returns SQL error if there are problems and nil on success.
|
||||||
|
func (d *Database) UpdateDevice(
|
||||||
|
ctx context.Context, localpart, deviceID string, displayName *string,
|
||||||
|
) error {
|
||||||
|
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
|
return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveDevice revokes a device by deleting the entry in the database
|
||||||
|
// matching with the given device ID and user ID localpart.
|
||||||
|
// If the device doesn't exist, it will not return an error
|
||||||
|
// If something went wrong during the deletion, it will return the SQL error.
|
||||||
|
func (d *Database) RemoveDevice(
|
||||||
|
ctx context.Context, deviceID, localpart string,
|
||||||
|
) error {
|
||||||
|
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
|
if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveDevices revokes one or more devices by deleting the entry in the database
|
||||||
|
// matching with the given device IDs and user ID localpart.
|
||||||
|
// If the devices don't exist, it will not return an error
|
||||||
|
// If something went wrong during the deletion, it will return the SQL error.
|
||||||
|
func (d *Database) RemoveDevices(
|
||||||
|
ctx context.Context, localpart string, devices []string,
|
||||||
|
) error {
|
||||||
|
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
|
if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveAllDevices revokes devices by deleting the entry in the
|
||||||
|
// database matching the given user ID localpart.
|
||||||
|
// If something went wrong during the deletion, it will return the SQL error.
|
||||||
|
func (d *Database) RemoveAllDevices(
|
||||||
|
ctx context.Context, localpart, exceptDeviceID string,
|
||||||
|
) (devices []api.Device, err error) {
|
||||||
|
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
|
devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address
|
||||||
|
func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error {
|
||||||
|
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
|
return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateLoginToken generates a token, stores and returns it. The lifetime is
|
||||||
|
// determined by the loginTokenLifetime given to the Database constructor.
|
||||||
|
func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) {
|
||||||
|
tok, err := generateLoginToken()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
meta := &api.LoginTokenMetadata{
|
||||||
|
Token: tok,
|
||||||
|
Expiration: time.Now().Add(d.loginTokenLifetime),
|
||||||
|
}
|
||||||
|
|
||||||
|
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
|
return d.loginTokens.insert(ctx, txn, meta, data)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return meta, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateLoginToken() (string, error) {
|
||||||
|
b := make([]byte, loginTokenByteLength)
|
||||||
|
_, err := rand.Read(b)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveLoginToken removes the named token (and may clean up other expired tokens).
|
||||||
|
func (d *Database) RemoveLoginToken(ctx context.Context, token string) error {
|
||||||
|
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
|
return d.loginTokens.deleteByToken(ctx, txn, token)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLoginTokenDataByToken returns the data associated with the given token.
|
||||||
|
// May return sql.ErrNoRows.
|
||||||
|
func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
|
||||||
|
return d.loginTokens.selectByToken(ctx, token)
|
||||||
|
}
|
|
@ -5,13 +5,8 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/pressly/goose"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func LoadFromGoose() {
|
|
||||||
goose.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP)
|
|
||||||
}
|
|
||||||
|
|
||||||
func LoadLastSeenTSIP(m *sqlutil.Migrations) {
|
func LoadLastSeenTSIP(m *sqlutil.Migrations) {
|
||||||
m.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP)
|
m.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP)
|
||||||
}
|
}
|
|
@ -106,6 +106,9 @@ func (s *devicesStatements) execSchema(db *sql.DB) error {
|
||||||
func (s *devicesStatements) prepare(db *sql.DB, writer sqlutil.Writer, server gomatrixserverlib.ServerName) (err error) {
|
func (s *devicesStatements) prepare(db *sql.DB, writer sqlutil.Writer, server gomatrixserverlib.ServerName) (err error) {
|
||||||
s.db = db
|
s.db = db
|
||||||
s.writer = writer
|
s.writer = writer
|
||||||
|
if err = s.execSchema(db); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil {
|
if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
|
@ -51,6 +51,9 @@ CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_exp
|
||||||
|
|
||||||
// prepare runs statement preparation.
|
// prepare runs statement preparation.
|
||||||
func (s *loginTokenStatements) prepare(db *sql.DB) error {
|
func (s *loginTokenStatements) prepare(db *sql.DB) error {
|
||||||
|
if err := s.execSchema(db); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
return sqlutil.StatementList{
|
return sqlutil.StatementList{
|
||||||
{&s.insertStmt, "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"},
|
{&s.insertStmt, "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"},
|
||||||
{&s.deleteStmt, "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"},
|
{&s.deleteStmt, "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"},
|
|
@ -16,7 +16,9 @@ package sqlite3
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -31,7 +33,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3/deltas"
|
"github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Database represents an account database
|
// Database represents an account database
|
||||||
|
@ -47,6 +49,9 @@ type Database struct {
|
||||||
openIDTokens tokenStatements
|
openIDTokens tokenStatements
|
||||||
keyBackupVersions keyBackupVersionStatements
|
keyBackupVersions keyBackupVersionStatements
|
||||||
keyBackups keyBackupStatements
|
keyBackups keyBackupStatements
|
||||||
|
devices devicesStatements
|
||||||
|
loginTokens loginTokenStatements
|
||||||
|
loginTokenLifetime time.Duration
|
||||||
serverName gomatrixserverlib.ServerName
|
serverName gomatrixserverlib.ServerName
|
||||||
bcryptCost int
|
bcryptCost int
|
||||||
openIDTokenLifetimeMS int64
|
openIDTokenLifetimeMS int64
|
||||||
|
@ -57,8 +62,14 @@ type Database struct {
|
||||||
threepidsMu sync.Mutex
|
threepidsMu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
// The length of generated device IDs
|
||||||
|
deviceIDByteLength = 6
|
||||||
|
loginTokenByteLength = 32
|
||||||
|
)
|
||||||
|
|
||||||
// NewDatabase creates a new accounts and profiles database
|
// NewDatabase creates a new accounts and profiles database
|
||||||
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) {
|
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration) (*Database, error) {
|
||||||
db, err := sqlutil.Open(dbProperties)
|
db, err := sqlutil.Open(dbProperties)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -67,6 +78,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
|
||||||
serverName: serverName,
|
serverName: serverName,
|
||||||
db: db,
|
db: db,
|
||||||
writer: sqlutil.NewExclusiveWriter(),
|
writer: sqlutil.NewExclusiveWriter(),
|
||||||
|
loginTokenLifetime: loginTokenLifetime,
|
||||||
bcryptCost: bcryptCost,
|
bcryptCost: bcryptCost,
|
||||||
openIDTokenLifetimeMS: openIDTokenLifetimeMS,
|
openIDTokenLifetimeMS: openIDTokenLifetimeMS,
|
||||||
}
|
}
|
||||||
|
@ -78,6 +90,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
|
||||||
}
|
}
|
||||||
m := sqlutil.NewMigrations()
|
m := sqlutil.NewMigrations()
|
||||||
deltas.LoadIsActive(m)
|
deltas.LoadIsActive(m)
|
||||||
|
//deltas.LoadLastSeenTSIP(m)
|
||||||
deltas.LoadAddAccountType(m)
|
deltas.LoadAddAccountType(m)
|
||||||
if err = m.RunDeltas(db, dbProperties); err != nil {
|
if err = m.RunDeltas(db, dbProperties); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -108,6 +121,12 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
|
||||||
if err = d.keyBackups.prepare(db); err != nil {
|
if err = d.keyBackups.prepare(db); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if err = d.devices.prepare(db, d.writer, serverName); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err = d.loginTokens.prepare(db); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
return d, nil
|
return d, nil
|
||||||
}
|
}
|
||||||
|
@ -547,3 +566,196 @@ func (d *Database) UpsertBackupKeys(
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetDeviceByAccessToken returns the device matching the given access token.
|
||||||
|
// Returns sql.ErrNoRows if no matching device was found.
|
||||||
|
func (d *Database) GetDeviceByAccessToken(
|
||||||
|
ctx context.Context, token string,
|
||||||
|
) (*api.Device, error) {
|
||||||
|
return d.devices.selectDeviceByToken(ctx, token)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDeviceByID returns the device matching the given ID.
|
||||||
|
// Returns sql.ErrNoRows if no matching device was found.
|
||||||
|
func (d *Database) GetDeviceByID(
|
||||||
|
ctx context.Context, localpart, deviceID string,
|
||||||
|
) (*api.Device, error) {
|
||||||
|
return d.devices.selectDeviceByID(ctx, localpart, deviceID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDevicesByLocalpart returns the devices matching the given localpart.
|
||||||
|
func (d *Database) GetDevicesByLocalpart(
|
||||||
|
ctx context.Context, localpart string,
|
||||||
|
) ([]api.Device, error) {
|
||||||
|
return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
|
||||||
|
return d.devices.selectDevicesByID(ctx, deviceIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateDevice makes a new device associated with the given user ID localpart.
|
||||||
|
// If there is already a device with the same device ID for this user, that access token will be revoked
|
||||||
|
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
|
||||||
|
// an error will be returned.
|
||||||
|
// If no device ID is given one is generated.
|
||||||
|
// Returns the device on success.
|
||||||
|
func (d *Database) CreateDevice(
|
||||||
|
ctx context.Context, localpart string, deviceID *string, accessToken string,
|
||||||
|
displayName *string, ipAddr, userAgent string,
|
||||||
|
) (dev *api.Device, returnErr error) {
|
||||||
|
if deviceID != nil {
|
||||||
|
returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||||
|
var err error
|
||||||
|
// Revoke existing tokens for this device
|
||||||
|
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
// We generate device IDs in a loop in case its already taken.
|
||||||
|
// We cap this at going round 5 times to ensure we don't spin forever
|
||||||
|
var newDeviceID string
|
||||||
|
for i := 1; i <= 5; i++ {
|
||||||
|
newDeviceID, returnErr = generateDeviceID()
|
||||||
|
if returnErr != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||||
|
var err error
|
||||||
|
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
if returnErr == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateDeviceID creates a new device id. Returns an error if failed to generate
|
||||||
|
// random bytes.
|
||||||
|
func generateDeviceID() (string, error) {
|
||||||
|
b := make([]byte, deviceIDByteLength)
|
||||||
|
_, err := rand.Read(b)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
// url-safe no padding
|
||||||
|
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateDevice updates the given device with the display name.
|
||||||
|
// Returns SQL error if there are problems and nil on success.
|
||||||
|
func (d *Database) UpdateDevice(
|
||||||
|
ctx context.Context, localpart, deviceID string, displayName *string,
|
||||||
|
) error {
|
||||||
|
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||||
|
return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveDevice revokes a device by deleting the entry in the database
|
||||||
|
// matching with the given device ID and user ID localpart.
|
||||||
|
// If the device doesn't exist, it will not return an error
|
||||||
|
// If something went wrong during the deletion, it will return the SQL error.
|
||||||
|
func (d *Database) RemoveDevice(
|
||||||
|
ctx context.Context, deviceID, localpart string,
|
||||||
|
) error {
|
||||||
|
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||||
|
if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveDevices revokes one or more devices by deleting the entry in the database
|
||||||
|
// matching with the given device IDs and user ID localpart.
|
||||||
|
// If the devices don't exist, it will not return an error
|
||||||
|
// If something went wrong during the deletion, it will return the SQL error.
|
||||||
|
func (d *Database) RemoveDevices(
|
||||||
|
ctx context.Context, localpart string, devices []string,
|
||||||
|
) error {
|
||||||
|
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||||
|
if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveAllDevices revokes devices by deleting the entry in the
|
||||||
|
// database matching the given user ID localpart.
|
||||||
|
// If something went wrong during the deletion, it will return the SQL error.
|
||||||
|
func (d *Database) RemoveAllDevices(
|
||||||
|
ctx context.Context, localpart, exceptDeviceID string,
|
||||||
|
) (devices []api.Device, err error) {
|
||||||
|
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||||
|
devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address
|
||||||
|
func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error {
|
||||||
|
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||||
|
return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateLoginToken generates a token, stores and returns it. The lifetime is
|
||||||
|
// determined by the loginTokenLifetime given to the Database constructor.
|
||||||
|
func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) {
|
||||||
|
tok, err := generateLoginToken()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
meta := &api.LoginTokenMetadata{
|
||||||
|
Token: tok,
|
||||||
|
Expiration: time.Now().Add(d.loginTokenLifetime),
|
||||||
|
}
|
||||||
|
|
||||||
|
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||||
|
return d.loginTokens.insert(ctx, txn, meta, data)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return meta, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateLoginToken() (string, error) {
|
||||||
|
b := make([]byte, loginTokenByteLength)
|
||||||
|
_, err := rand.Read(b)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveLoginToken removes the named token (and may clean up other expired tokens).
|
||||||
|
func (d *Database) RemoveLoginToken(ctx context.Context, token string) error {
|
||||||
|
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||||
|
return d.loginTokens.deleteByToken(ctx, txn, token)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLoginTokenDataByToken returns the data associated with the given token.
|
||||||
|
// May return sql.ErrNoRows.
|
||||||
|
func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
|
||||||
|
return d.loginTokens.selectByToken(ctx, token)
|
||||||
|
}
|
|
@ -15,26 +15,27 @@
|
||||||
//go:build !wasm
|
//go:build !wasm
|
||||||
// +build !wasm
|
// +build !wasm
|
||||||
|
|
||||||
package accounts
|
package storage
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/accounts/postgres"
|
"github.com/matrix-org/dendrite/userapi/storage/postgres"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3"
|
"github.com/matrix-org/dendrite/userapi/storage/sqlite3"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
|
// NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
|
||||||
// and sets postgres connection parameters
|
// and sets postgres connection parameters
|
||||||
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (Database, error) {
|
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration) (Database, error) {
|
||||||
switch {
|
switch {
|
||||||
case dbProperties.ConnectionString.IsSQLite():
|
case dbProperties.ConnectionString.IsSQLite():
|
||||||
return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS)
|
return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime)
|
||||||
case dbProperties.ConnectionString.IsPostgres():
|
case dbProperties.ConnectionString.IsPostgres():
|
||||||
return postgres.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS)
|
return postgres.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime)
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("unexpected database type")
|
return nil, fmt.Errorf("unexpected database type")
|
||||||
}
|
}
|
|
@ -12,13 +12,14 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
package accounts
|
package storage
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3"
|
"github.com/matrix-org/dendrite/userapi/storage/sqlite3"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -27,10 +28,11 @@ func NewDatabase(
|
||||||
serverName gomatrixserverlib.ServerName,
|
serverName gomatrixserverlib.ServerName,
|
||||||
bcryptCost int,
|
bcryptCost int,
|
||||||
openIDTokenLifetimeMS int64,
|
openIDTokenLifetimeMS int64,
|
||||||
|
loginTokenLifetime time.Duration,
|
||||||
) (Database, error) {
|
) (Database, error) {
|
||||||
switch {
|
switch {
|
||||||
case dbProperties.ConnectionString.IsSQLite():
|
case dbProperties.ConnectionString.IsSQLite():
|
||||||
return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS)
|
return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime)
|
||||||
case dbProperties.ConnectionString.IsPostgres():
|
case dbProperties.ConnectionString.IsPostgres():
|
||||||
return nil, fmt.Errorf("can't use Postgres implementation")
|
return nil, fmt.Errorf("can't use Postgres implementation")
|
||||||
default:
|
default:
|
|
@ -23,18 +23,10 @@ import (
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/dendrite/userapi/internal"
|
"github.com/matrix-org/dendrite/userapi/internal"
|
||||||
"github.com/matrix-org/dendrite/userapi/inthttp"
|
"github.com/matrix-org/dendrite/userapi/inthttp"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
"github.com/matrix-org/dendrite/userapi/storage"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/devices"
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// defaultLoginTokenLifetime determines how old a valid token may be.
|
|
||||||
//
|
|
||||||
// NOTSPEC: The current spec says "SHOULD be limited to around five
|
|
||||||
// seconds". Since TCP retries are on the order of 3 s, 5 s sounds very low.
|
|
||||||
// Synapse uses 2 min (https://github.com/matrix-org/synapse/blob/78d5f91de1a9baf4dbb0a794cb49a799f29f7a38/synapse/handlers/auth.py#L1323-L1325).
|
|
||||||
const defaultLoginTokenLifetime = 2 * time.Minute
|
|
||||||
|
|
||||||
// AddInternalRoutes registers HTTP handlers for the internal API. Invokes functions
|
// AddInternalRoutes registers HTTP handlers for the internal API. Invokes functions
|
||||||
// on the given input API.
|
// on the given input API.
|
||||||
func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) {
|
func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) {
|
||||||
|
@ -44,26 +36,24 @@ func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) {
|
||||||
// NewInternalAPI returns a concerete implementation of the internal API. Callers
|
// NewInternalAPI returns a concerete implementation of the internal API. Callers
|
||||||
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
|
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
|
||||||
func NewInternalAPI(
|
func NewInternalAPI(
|
||||||
accountDB accounts.Database, cfg *config.UserAPI, appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI,
|
accountDB storage.Database, cfg *config.UserAPI, appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI,
|
||||||
) api.UserInternalAPI {
|
) api.UserInternalAPI {
|
||||||
deviceDB, err := devices.NewDatabase(&cfg.DeviceDatabase, cfg.Matrix.ServerName, defaultLoginTokenLifetime)
|
db, err := storage.NewDatabase(&cfg.AccountDatabase, cfg.Matrix.ServerName, cfg.BCryptCost, int64(api.DefaultLoginTokenLifetime*time.Millisecond), api.DefaultLoginTokenLifetime)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).Panicf("failed to connect to device db")
|
logrus.WithError(err).Panicf("failed to connect to device db")
|
||||||
}
|
}
|
||||||
|
|
||||||
return newInternalAPI(accountDB, deviceDB, cfg, appServices, keyAPI)
|
return newInternalAPI(db, cfg, appServices, keyAPI)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newInternalAPI(
|
func newInternalAPI(
|
||||||
accountDB accounts.Database,
|
db storage.Database,
|
||||||
deviceDB devices.Database,
|
|
||||||
cfg *config.UserAPI,
|
cfg *config.UserAPI,
|
||||||
appServices []config.ApplicationService,
|
appServices []config.ApplicationService,
|
||||||
keyAPI keyapi.KeyInternalAPI,
|
keyAPI keyapi.KeyInternalAPI,
|
||||||
) api.UserInternalAPI {
|
) api.UserInternalAPI {
|
||||||
return &internal.UserInternalAPI{
|
return &internal.UserInternalAPI{
|
||||||
AccountDB: accountDB,
|
DB: db,
|
||||||
DeviceDB: deviceDB,
|
|
||||||
ServerName: cfg.Matrix.ServerName,
|
ServerName: cfg.Matrix.ServerName,
|
||||||
AppServices: appServices,
|
AppServices: appServices,
|
||||||
KeyAPI: keyAPI,
|
KeyAPI: keyAPI,
|
||||||
|
|
|
@ -31,8 +31,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
"github.com/matrix-org/dendrite/userapi/inthttp"
|
"github.com/matrix-org/dendrite/userapi/inthttp"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
"github.com/matrix-org/dendrite/userapi/storage"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/devices"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -43,23 +42,19 @@ type apiTestOpts struct {
|
||||||
loginTokenLifetime time.Duration
|
loginTokenLifetime time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
func MustMakeInternalAPI(t *testing.T, opts apiTestOpts) (api.UserInternalAPI, accounts.Database) {
|
func MustMakeInternalAPI(t *testing.T, opts apiTestOpts) (api.UserInternalAPI, storage.Database) {
|
||||||
if opts.loginTokenLifetime == 0 {
|
if opts.loginTokenLifetime == 0 {
|
||||||
opts.loginTokenLifetime = defaultLoginTokenLifetime
|
opts.loginTokenLifetime = api.DefaultLoginTokenLifetime * time.Millisecond
|
||||||
}
|
}
|
||||||
dbopts := &config.DatabaseOptions{
|
dbopts := &config.DatabaseOptions{
|
||||||
ConnectionString: "file::memory:",
|
ConnectionString: "file::memory:",
|
||||||
MaxOpenConnections: 1,
|
MaxOpenConnections: 1,
|
||||||
MaxIdleConnections: 1,
|
MaxIdleConnections: 1,
|
||||||
}
|
}
|
||||||
accountDB, err := accounts.NewDatabase(dbopts, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS)
|
accountDB, err := storage.NewDatabase(dbopts, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create account DB: %s", err)
|
t.Fatalf("failed to create account DB: %s", err)
|
||||||
}
|
}
|
||||||
deviceDB, err := devices.NewDatabase(dbopts, serverName, opts.loginTokenLifetime)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to create device DB: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cfg := &config.UserAPI{
|
cfg := &config.UserAPI{
|
||||||
Matrix: &config.Global{
|
Matrix: &config.Global{
|
||||||
|
@ -67,7 +62,7 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts) (api.UserInternalAPI, a
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
return newInternalAPI(accountDB, deviceDB, cfg, nil, nil), accountDB
|
return newInternalAPI(accountDB, cfg, nil, nil), accountDB
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestQueryProfile(t *testing.T) {
|
func TestQueryProfile(t *testing.T) {
|
||||||
|
|
Loading…
Reference in a new issue