mirror of
https://github.com/hoernschen/dendrite.git
synced 2024-12-26 15:08:28 +00:00
* Implement OpenID module (#599) - Unrelated: change Riot references to Element in client API routing Signed-off-by: Bruce MacDonald <contact@bruce-macdonald.com> * OpenID module tweaks (#599) - specify expiry is ms rather than vague ts - add OpenID token lifetime to configuration - use Go naming conventions for the path params - store plaintext token rather than hash - remove openid table sqllite mutex * Add default OpenID token lifetime (#599) * Update dendrite-config.yaml Co-authored-by: Kegsay <kegsay@gmail.com> Co-authored-by: Kegsay <kegan@matrix.org>
This commit is contained in:
parent
f8d3a762c4
commit
d27607af78
23 changed files with 553 additions and 32 deletions
70
clientapi/routing/openid.go
Normal file
70
clientapi/routing/openid.go
Normal file
|
@ -0,0 +1,70 @@
|
||||||
|
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package routing
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||||
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
"github.com/matrix-org/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
type openIDTokenResponse struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
MatrixServerName string `json:"matrix_server_name"`
|
||||||
|
ExpiresIn int64 `json:"expires_in"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateOpenIDToken creates a new OpenID Connect (OIDC) token that a Matrix user
|
||||||
|
// can supply to an OpenID Relying Party to verify their identity
|
||||||
|
func CreateOpenIDToken(
|
||||||
|
req *http.Request,
|
||||||
|
userAPI api.UserInternalAPI,
|
||||||
|
device *api.Device,
|
||||||
|
userID string,
|
||||||
|
cfg *config.ClientAPI,
|
||||||
|
) util.JSONResponse {
|
||||||
|
// does the incoming user ID match the user that the token was issued for?
|
||||||
|
if userID != device.UserID {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusForbidden,
|
||||||
|
JSON: jsonerror.Forbidden("Cannot request tokens for other users"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
request := api.PerformOpenIDTokenCreationRequest{
|
||||||
|
UserID: userID, // this is the user ID from the incoming path
|
||||||
|
}
|
||||||
|
response := api.PerformOpenIDTokenCreationResponse{}
|
||||||
|
|
||||||
|
err := userAPI.PerformOpenIDTokenCreation(req.Context(), &request, &response)
|
||||||
|
if err != nil {
|
||||||
|
util.GetLogger(req.Context()).WithError(err).Error("userAPI.CreateOpenIDToken failed")
|
||||||
|
return jsonerror.InternalServerError()
|
||||||
|
}
|
||||||
|
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusOK,
|
||||||
|
JSON: openIDTokenResponse{
|
||||||
|
AccessToken: response.Token.Token,
|
||||||
|
TokenType: "Bearer",
|
||||||
|
MatrixServerName: string(cfg.Matrix.ServerName),
|
||||||
|
ExpiresIn: response.Token.ExpiresAtMS / 1000, // convert ms to s
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
|
@ -469,7 +469,7 @@ func Setup(
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
|
|
||||||
// Stub endpoints required by Riot
|
// Stub endpoints required by Element
|
||||||
|
|
||||||
r0mux.Handle("/login",
|
r0mux.Handle("/login",
|
||||||
httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse {
|
httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse {
|
||||||
|
@ -506,7 +506,7 @@ func Setup(
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodGet, http.MethodOptions)
|
).Methods(http.MethodGet, http.MethodOptions)
|
||||||
|
|
||||||
// Riot user settings
|
// Element user settings
|
||||||
|
|
||||||
r0mux.Handle("/profile/{userID}",
|
r0mux.Handle("/profile/{userID}",
|
||||||
httputil.MakeExternalAPI("profile", func(req *http.Request) util.JSONResponse {
|
httputil.MakeExternalAPI("profile", func(req *http.Request) util.JSONResponse {
|
||||||
|
@ -592,7 +592,7 @@ func Setup(
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
|
|
||||||
// Riot logs get flooded unless this is handled
|
// Element logs get flooded unless this is handled
|
||||||
r0mux.Handle("/presence/{userID}/status",
|
r0mux.Handle("/presence/{userID}/status",
|
||||||
httputil.MakeExternalAPI("presence", func(req *http.Request) util.JSONResponse {
|
httputil.MakeExternalAPI("presence", func(req *http.Request) util.JSONResponse {
|
||||||
if r := rateLimits.rateLimit(req); r != nil {
|
if r := rateLimits.rateLimit(req); r != nil {
|
||||||
|
@ -685,6 +685,19 @@ func Setup(
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodGet)
|
).Methods(http.MethodGet)
|
||||||
|
|
||||||
|
r0mux.Handle("/user/{userID}/openid/request_token",
|
||||||
|
httputil.MakeAuthAPI("openid_request_token", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
if r := rateLimits.rateLimit(req); r != nil {
|
||||||
|
return *r
|
||||||
|
}
|
||||||
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
|
if err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
return CreateOpenIDToken(req, userAPI, device, vars["userID"], cfg)
|
||||||
|
}),
|
||||||
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
|
|
||||||
r0mux.Handle("/user_directory/search",
|
r0mux.Handle("/user_directory/search",
|
||||||
httputil.MakeAuthAPI("userdirectory_search", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI("userdirectory_search", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
if r := rateLimits.rateLimit(req); r != nil {
|
if r := rateLimits.rateLimit(req); r != nil {
|
||||||
|
|
|
@ -58,7 +58,7 @@ func main() {
|
||||||
|
|
||||||
accountDB, err := accounts.NewDatabase(&config.DatabaseOptions{
|
accountDB, err := accounts.NewDatabase(&config.DatabaseOptions{
|
||||||
ConnectionString: cfg.UserAPI.AccountDatabase.ConnectionString,
|
ConnectionString: cfg.UserAPI.AccountDatabase.ConnectionString,
|
||||||
}, cfg.Global.ServerName, bcrypt.DefaultCost)
|
}, cfg.Global.ServerName, bcrypt.DefaultCost, cfg.UserAPI.OpenIDTokenLifetimeMS)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.Fatalln("Failed to connect to the database:", err.Error())
|
logrus.Fatalln("Failed to connect to the database:", err.Error())
|
||||||
}
|
}
|
||||||
|
|
|
@ -360,6 +360,11 @@ user_api:
|
||||||
max_open_conns: 10
|
max_open_conns: 10
|
||||||
max_idle_conns: 2
|
max_idle_conns: 2
|
||||||
conn_max_lifetime: -1
|
conn_max_lifetime: -1
|
||||||
|
# The length of time that a token issued for a relying party from
|
||||||
|
# /_matrix/client/r0/user/{userId}/openid/request_token endpoint
|
||||||
|
# is considered to be valid in milliseconds.
|
||||||
|
# The default lifetime is 3600000ms (60 minutes).
|
||||||
|
# openid_token_lifetime_ms: 3600000
|
||||||
|
|
||||||
# Configuration for Opentracing.
|
# Configuration for Opentracing.
|
||||||
# See https://github.com/matrix-org/dendrite/tree/master/docs/tracing for information on
|
# See https://github.com/matrix-org/dendrite/tree/master/docs/tracing for information on
|
||||||
|
|
65
federationapi/routing/openid.go
Normal file
65
federationapi/routing/openid.go
Normal file
|
@ -0,0 +1,65 @@
|
||||||
|
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package routing
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||||
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
"github.com/matrix-org/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
type openIDUserInfoResponse struct {
|
||||||
|
Sub string `json:"sub"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOpenIDUserInfo implements GET /_matrix/federation/v1/openid/userinfo
|
||||||
|
func GetOpenIDUserInfo(
|
||||||
|
httpReq *http.Request,
|
||||||
|
userAPI userapi.UserInternalAPI,
|
||||||
|
) util.JSONResponse {
|
||||||
|
token := httpReq.URL.Query().Get("access_token")
|
||||||
|
if len(token) == 0 {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusUnauthorized,
|
||||||
|
JSON: jsonerror.MissingArgument("access_token is missing"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
req := userapi.QueryOpenIDTokenRequest{
|
||||||
|
Token: token,
|
||||||
|
}
|
||||||
|
|
||||||
|
var openIDTokenAttrResponse userapi.QueryOpenIDTokenResponse
|
||||||
|
err := userAPI.QueryOpenIDToken(httpReq.Context(), &req, &openIDTokenAttrResponse)
|
||||||
|
if err != nil {
|
||||||
|
util.GetLogger(httpReq.Context()).WithError(err).Error("userAPI.QueryOpenIDToken failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
var res interface{} = openIDUserInfoResponse{Sub: openIDTokenAttrResponse.Sub}
|
||||||
|
code := http.StatusOK
|
||||||
|
nowMS := time.Now().UnixNano() / int64(time.Millisecond)
|
||||||
|
if openIDTokenAttrResponse.Sub == "" || nowMS > openIDTokenAttrResponse.ExpiresAtMS {
|
||||||
|
code = http.StatusUnauthorized
|
||||||
|
res = jsonerror.UnknownToken("Access Token unknown or expired")
|
||||||
|
}
|
||||||
|
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: code,
|
||||||
|
JSON: res,
|
||||||
|
}
|
||||||
|
}
|
|
@ -462,4 +462,10 @@ func Setup(
|
||||||
return QueryDeviceKeys(httpReq, request, keyAPI, cfg.Matrix.ServerName)
|
return QueryDeviceKeys(httpReq, request, keyAPI, cfg.Matrix.ServerName)
|
||||||
},
|
},
|
||||||
)).Methods(http.MethodPost)
|
)).Methods(http.MethodPost)
|
||||||
|
|
||||||
|
v1fedmux.Handle("/openid/userinfo",
|
||||||
|
httputil.MakeExternalAPI("federation_openid_userinfo", func(req *http.Request) util.JSONResponse {
|
||||||
|
return GetOpenIDUserInfo(req, userAPI)
|
||||||
|
}),
|
||||||
|
).Methods(http.MethodGet)
|
||||||
}
|
}
|
||||||
|
|
|
@ -280,7 +280,7 @@ func (b *BaseDendrite) KeyServerHTTPClient() keyserverAPI.KeyInternalAPI {
|
||||||
// CreateAccountsDB creates a new instance of the accounts database. Should only
|
// CreateAccountsDB creates a new instance of the accounts database. Should only
|
||||||
// be called once per component.
|
// be called once per component.
|
||||||
func (b *BaseDendrite) CreateAccountsDB() accounts.Database {
|
func (b *BaseDendrite) CreateAccountsDB() accounts.Database {
|
||||||
db, err := accounts.NewDatabase(&b.Cfg.UserAPI.AccountDatabase, b.Cfg.Global.ServerName, b.Cfg.UserAPI.BCryptCost)
|
db, err := accounts.NewDatabase(&b.Cfg.UserAPI.AccountDatabase, b.Cfg.Global.ServerName, b.Cfg.UserAPI.BCryptCost, b.Cfg.UserAPI.OpenIDTokenLifetimeMS)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).Panicf("failed to connect to accounts db")
|
logrus.WithError(err).Panicf("failed to connect to accounts db")
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,6 +10,9 @@ type UserAPI struct {
|
||||||
// The cost when hashing passwords.
|
// The cost when hashing passwords.
|
||||||
BCryptCost int `yaml:"bcrypt_cost"`
|
BCryptCost int `yaml:"bcrypt_cost"`
|
||||||
|
|
||||||
|
// The length of time an OpenID token is condidered valid in milliseconds
|
||||||
|
OpenIDTokenLifetimeMS int64 `yaml:"openid_token_lifetime_ms"`
|
||||||
|
|
||||||
// The Account database stores the login details and account information
|
// The Account database stores the login details and account information
|
||||||
// for local users. It is accessed by the UserAPI.
|
// for local users. It is accessed by the UserAPI.
|
||||||
AccountDatabase DatabaseOptions `yaml:"account_database"`
|
AccountDatabase DatabaseOptions `yaml:"account_database"`
|
||||||
|
@ -18,6 +21,8 @@ type UserAPI struct {
|
||||||
DeviceDatabase DatabaseOptions `yaml:"device_database"`
|
DeviceDatabase DatabaseOptions `yaml:"device_database"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const DefaultOpenIDTokenLifetimeMS = 3600000 // 60 minutes
|
||||||
|
|
||||||
func (c *UserAPI) Defaults() {
|
func (c *UserAPI) Defaults() {
|
||||||
c.InternalAPI.Listen = "http://localhost:7781"
|
c.InternalAPI.Listen = "http://localhost:7781"
|
||||||
c.InternalAPI.Connect = "http://localhost:7781"
|
c.InternalAPI.Connect = "http://localhost:7781"
|
||||||
|
@ -26,6 +31,7 @@ func (c *UserAPI) Defaults() {
|
||||||
c.AccountDatabase.ConnectionString = "file:userapi_accounts.db"
|
c.AccountDatabase.ConnectionString = "file:userapi_accounts.db"
|
||||||
c.DeviceDatabase.ConnectionString = "file:userapi_devices.db"
|
c.DeviceDatabase.ConnectionString = "file:userapi_devices.db"
|
||||||
c.BCryptCost = bcrypt.DefaultCost
|
c.BCryptCost = bcrypt.DefaultCost
|
||||||
|
c.OpenIDTokenLifetimeMS = DefaultOpenIDTokenLifetimeMS
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *UserAPI) Verify(configErrs *ConfigErrors, isMonolith bool) {
|
func (c *UserAPI) Verify(configErrs *ConfigErrors, isMonolith bool) {
|
||||||
|
@ -33,4 +39,5 @@ func (c *UserAPI) Verify(configErrs *ConfigErrors, isMonolith bool) {
|
||||||
checkURL(configErrs, "user_api.internal_api.connect", string(c.InternalAPI.Connect))
|
checkURL(configErrs, "user_api.internal_api.connect", string(c.InternalAPI.Connect))
|
||||||
checkNotEmpty(configErrs, "user_api.account_database.connection_string", string(c.AccountDatabase.ConnectionString))
|
checkNotEmpty(configErrs, "user_api.account_database.connection_string", string(c.AccountDatabase.ConnectionString))
|
||||||
checkNotEmpty(configErrs, "user_api.device_database.connection_string", string(c.DeviceDatabase.ConnectionString))
|
checkNotEmpty(configErrs, "user_api.device_database.connection_string", string(c.DeviceDatabase.ConnectionString))
|
||||||
|
checkPositive(configErrs, "user_api.openid_token_lifetime_ms", c.OpenIDTokenLifetimeMS)
|
||||||
}
|
}
|
||||||
|
|
|
@ -524,6 +524,9 @@ func (u *testUserAPI) PerformLastSeenUpdate(ctx context.Context, req *userapi.Pe
|
||||||
func (u *testUserAPI) PerformAccountDeactivation(ctx context.Context, req *userapi.PerformAccountDeactivationRequest, res *userapi.PerformAccountDeactivationResponse) error {
|
func (u *testUserAPI) PerformAccountDeactivation(ctx context.Context, req *userapi.PerformAccountDeactivationRequest, res *userapi.PerformAccountDeactivationResponse) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
func (u *testUserAPI) PerformOpenIDTokenCreation(ctx context.Context, req *userapi.PerformOpenIDTokenCreationRequest, res *userapi.PerformOpenIDTokenCreationResponse) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
func (u *testUserAPI) QueryProfile(ctx context.Context, req *userapi.QueryProfileRequest, res *userapi.QueryProfileResponse) error {
|
func (u *testUserAPI) QueryProfile(ctx context.Context, req *userapi.QueryProfileRequest, res *userapi.QueryProfileResponse) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -548,6 +551,9 @@ func (u *testUserAPI) QueryDeviceInfos(ctx context.Context, req *userapi.QueryDe
|
||||||
func (u *testUserAPI) QuerySearchProfiles(ctx context.Context, req *userapi.QuerySearchProfilesRequest, res *userapi.QuerySearchProfilesResponse) error {
|
func (u *testUserAPI) QuerySearchProfiles(ctx context.Context, req *userapi.QuerySearchProfilesRequest, res *userapi.QuerySearchProfilesResponse) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
func (u *testUserAPI) QueryOpenIDToken(ctx context.Context, req *userapi.QueryOpenIDTokenRequest, res *userapi.QueryOpenIDTokenResponse) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
type testRoomserverAPI struct {
|
type testRoomserverAPI struct {
|
||||||
// use a trace API as it implements method stubs so we don't need to have them here.
|
// use a trace API as it implements method stubs so we don't need to have them here.
|
||||||
|
|
|
@ -367,6 +367,9 @@ func (u *testUserAPI) PerformLastSeenUpdate(ctx context.Context, req *userapi.Pe
|
||||||
func (u *testUserAPI) PerformAccountDeactivation(ctx context.Context, req *userapi.PerformAccountDeactivationRequest, res *userapi.PerformAccountDeactivationResponse) error {
|
func (u *testUserAPI) PerformAccountDeactivation(ctx context.Context, req *userapi.PerformAccountDeactivationRequest, res *userapi.PerformAccountDeactivationResponse) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
func (u *testUserAPI) PerformOpenIDTokenCreation(ctx context.Context, req *userapi.PerformOpenIDTokenCreationRequest, res *userapi.PerformOpenIDTokenCreationResponse) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
func (u *testUserAPI) QueryProfile(ctx context.Context, req *userapi.QueryProfileRequest, res *userapi.QueryProfileResponse) error {
|
func (u *testUserAPI) QueryProfile(ctx context.Context, req *userapi.QueryProfileRequest, res *userapi.QueryProfileResponse) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -391,6 +394,9 @@ func (u *testUserAPI) QueryDeviceInfos(ctx context.Context, req *userapi.QueryDe
|
||||||
func (u *testUserAPI) QuerySearchProfiles(ctx context.Context, req *userapi.QuerySearchProfilesRequest, res *userapi.QuerySearchProfilesResponse) error {
|
func (u *testUserAPI) QuerySearchProfiles(ctx context.Context, req *userapi.QuerySearchProfilesRequest, res *userapi.QuerySearchProfilesResponse) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
func (u *testUserAPI) QueryOpenIDToken(ctx context.Context, req *userapi.QueryOpenIDTokenRequest, res *userapi.QueryOpenIDTokenResponse) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
type testRoomserverAPI struct {
|
type testRoomserverAPI struct {
|
||||||
// use a trace API as it implements method stubs so we don't need to have them here.
|
// use a trace API as it implements method stubs so we don't need to have them here.
|
||||||
|
|
|
@ -517,3 +517,6 @@ AS can set avatar for ghosted users
|
||||||
AS can set displayname for ghosted users
|
AS can set displayname for ghosted users
|
||||||
Ghost user must register before joining room
|
Ghost user must register before joining room
|
||||||
Inviting an AS-hosted user asks the AS server
|
Inviting an AS-hosted user asks the AS server
|
||||||
|
Can generate a openid access_token that can be exchanged for information about a user
|
||||||
|
Invalid openid access tokens are rejected
|
||||||
|
Requests to userinfo without access tokens are rejected
|
||||||
|
|
|
@ -32,12 +32,14 @@ type UserInternalAPI interface {
|
||||||
PerformLastSeenUpdate(ctx context.Context, req *PerformLastSeenUpdateRequest, res *PerformLastSeenUpdateResponse) error
|
PerformLastSeenUpdate(ctx context.Context, req *PerformLastSeenUpdateRequest, res *PerformLastSeenUpdateResponse) error
|
||||||
PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error
|
PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error
|
||||||
PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error
|
PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error
|
||||||
|
PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error
|
||||||
QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error
|
QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error
|
||||||
QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error
|
QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error
|
||||||
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
|
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
|
||||||
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
|
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
|
||||||
QueryDeviceInfos(ctx context.Context, req *QueryDeviceInfosRequest, res *QueryDeviceInfosResponse) error
|
QueryDeviceInfos(ctx context.Context, req *QueryDeviceInfosRequest, res *QueryDeviceInfosResponse) error
|
||||||
QuerySearchProfiles(ctx context.Context, req *QuerySearchProfilesRequest, res *QuerySearchProfilesResponse) error
|
QuerySearchProfiles(ctx context.Context, req *QuerySearchProfilesRequest, res *QuerySearchProfilesResponse) error
|
||||||
|
QueryOpenIDToken(ctx context.Context, req *QueryOpenIDTokenRequest, res *QueryOpenIDTokenResponse) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// InputAccountDataRequest is the request for InputAccountData
|
// InputAccountDataRequest is the request for InputAccountData
|
||||||
|
@ -226,6 +228,27 @@ type PerformAccountDeactivationResponse struct {
|
||||||
AccountDeactivated bool
|
AccountDeactivated bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PerformOpenIDTokenCreationRequest is the request for PerformOpenIDTokenCreation
|
||||||
|
type PerformOpenIDTokenCreationRequest struct {
|
||||||
|
UserID string
|
||||||
|
}
|
||||||
|
|
||||||
|
// PerformOpenIDTokenCreationResponse is the response for PerformOpenIDTokenCreation
|
||||||
|
type PerformOpenIDTokenCreationResponse struct {
|
||||||
|
Token OpenIDToken
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryOpenIDTokenRequest is the request for QueryOpenIDToken
|
||||||
|
type QueryOpenIDTokenRequest struct {
|
||||||
|
Token string
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryOpenIDTokenResponse is the response for QueryOpenIDToken
|
||||||
|
type QueryOpenIDTokenResponse struct {
|
||||||
|
Sub string // The Matrix User ID that generated the token
|
||||||
|
ExpiresAtMS int64
|
||||||
|
}
|
||||||
|
|
||||||
// Device represents a client's device (mobile, web, etc)
|
// Device represents a client's device (mobile, web, etc)
|
||||||
type Device struct {
|
type Device struct {
|
||||||
ID string
|
ID string
|
||||||
|
@ -256,6 +279,24 @@ type Account struct {
|
||||||
// TODO: Associations (e.g. with application services)
|
// TODO: Associations (e.g. with application services)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OpenIDToken represents an OpenID token
|
||||||
|
type OpenIDToken struct {
|
||||||
|
Token string
|
||||||
|
UserID string
|
||||||
|
ExpiresAtMS int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// OpenIDTokenInfo represents the attributes associated with an issued OpenID token
|
||||||
|
type OpenIDTokenAttributes struct {
|
||||||
|
UserID string
|
||||||
|
ExpiresAtMS int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserInfo is for returning information about the user an OpenID token was issued for
|
||||||
|
type UserInfo struct {
|
||||||
|
Sub string // The Matrix user's ID who generated the token
|
||||||
|
}
|
||||||
|
|
||||||
// ErrorForbidden is an error indicating that the supplied access token is forbidden
|
// ErrorForbidden is an error indicating that the supplied access token is forbidden
|
||||||
type ErrorForbidden struct {
|
type ErrorForbidden struct {
|
||||||
Message string
|
Message string
|
||||||
|
|
|
@ -414,3 +414,31 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a
|
||||||
res.AccountDeactivated = err == nil
|
res.AccountDeactivated = err == nil
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PerformOpenIDTokenCreation creates a new token that a relying party uses to authenticate a user
|
||||||
|
func (a *UserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, req *api.PerformOpenIDTokenCreationRequest, res *api.PerformOpenIDTokenCreationResponse) error {
|
||||||
|
token := util.RandomString(24)
|
||||||
|
|
||||||
|
exp, err := a.AccountDB.CreateOpenIDToken(ctx, token, req.UserID)
|
||||||
|
|
||||||
|
res.Token = api.OpenIDToken{
|
||||||
|
Token: token,
|
||||||
|
UserID: req.UserID,
|
||||||
|
ExpiresAtMS: exp,
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryOpenIDToken validates that the OpenID token was issued for the user, the replying party uses this for validation
|
||||||
|
func (a *UserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOpenIDTokenRequest, res *api.QueryOpenIDTokenResponse) error {
|
||||||
|
openIDTokenAttrs, err := a.AccountDB.GetOpenIDTokenAttributes(ctx, req.Token)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
res.Sub = openIDTokenAttrs.UserID
|
||||||
|
res.ExpiresAtMS = openIDTokenAttrs.ExpiresAtMS
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
@ -35,6 +35,7 @@ const (
|
||||||
PerformLastSeenUpdatePath = "/userapi/performLastSeenUpdate"
|
PerformLastSeenUpdatePath = "/userapi/performLastSeenUpdate"
|
||||||
PerformDeviceUpdatePath = "/userapi/performDeviceUpdate"
|
PerformDeviceUpdatePath = "/userapi/performDeviceUpdate"
|
||||||
PerformAccountDeactivationPath = "/userapi/performAccountDeactivation"
|
PerformAccountDeactivationPath = "/userapi/performAccountDeactivation"
|
||||||
|
PerformOpenIDTokenCreationPath = "/userapi/performOpenIDTokenCreation"
|
||||||
|
|
||||||
QueryProfilePath = "/userapi/queryProfile"
|
QueryProfilePath = "/userapi/queryProfile"
|
||||||
QueryAccessTokenPath = "/userapi/queryAccessToken"
|
QueryAccessTokenPath = "/userapi/queryAccessToken"
|
||||||
|
@ -42,6 +43,7 @@ const (
|
||||||
QueryAccountDataPath = "/userapi/queryAccountData"
|
QueryAccountDataPath = "/userapi/queryAccountData"
|
||||||
QueryDeviceInfosPath = "/userapi/queryDeviceInfos"
|
QueryDeviceInfosPath = "/userapi/queryDeviceInfos"
|
||||||
QuerySearchProfilesPath = "/userapi/querySearchProfiles"
|
QuerySearchProfilesPath = "/userapi/querySearchProfiles"
|
||||||
|
QueryOpenIDTokenPath = "/userapi/queryOpenIDToken"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API.
|
// NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API.
|
||||||
|
@ -148,6 +150,14 @@ func (h *httpUserInternalAPI) PerformAccountDeactivation(ctx context.Context, re
|
||||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *httpUserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, request *api.PerformOpenIDTokenCreationRequest, response *api.PerformOpenIDTokenCreationResponse) error {
|
||||||
|
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformOpenIDTokenCreation")
|
||||||
|
defer span.Finish()
|
||||||
|
|
||||||
|
apiURL := h.apiURL + PerformOpenIDTokenCreationPath
|
||||||
|
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
|
||||||
|
}
|
||||||
|
|
||||||
func (h *httpUserInternalAPI) QueryProfile(
|
func (h *httpUserInternalAPI) QueryProfile(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *api.QueryProfileRequest,
|
request *api.QueryProfileRequest,
|
||||||
|
@ -207,3 +217,11 @@ func (h *httpUserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.
|
||||||
apiURL := h.apiURL + QuerySearchProfilesPath
|
apiURL := h.apiURL + QuerySearchProfilesPath
|
||||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *httpUserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOpenIDTokenRequest, res *api.QueryOpenIDTokenResponse) error {
|
||||||
|
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryOpenIDToken")
|
||||||
|
defer span.Finish()
|
||||||
|
|
||||||
|
apiURL := h.apiURL + QueryOpenIDTokenPath
|
||||||
|
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||||
|
}
|
||||||
|
|
|
@ -117,6 +117,19 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
|
||||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
internalAPIMux.Handle(PerformOpenIDTokenCreationPath,
|
||||||
|
httputil.MakeInternalAPI("performOpenIDTokenCreation", func(req *http.Request) util.JSONResponse {
|
||||||
|
request := api.PerformOpenIDTokenCreationRequest{}
|
||||||
|
response := api.PerformOpenIDTokenCreationResponse{}
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||||
|
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||||
|
}
|
||||||
|
if err := s.PerformOpenIDTokenCreation(req.Context(), &request, &response); err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||||
|
}),
|
||||||
|
)
|
||||||
internalAPIMux.Handle(QueryProfilePath,
|
internalAPIMux.Handle(QueryProfilePath,
|
||||||
httputil.MakeInternalAPI("queryProfile", func(req *http.Request) util.JSONResponse {
|
httputil.MakeInternalAPI("queryProfile", func(req *http.Request) util.JSONResponse {
|
||||||
request := api.QueryProfileRequest{}
|
request := api.QueryProfileRequest{}
|
||||||
|
@ -195,6 +208,19 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
|
||||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
internalAPIMux.Handle(QueryOpenIDTokenPath,
|
||||||
|
httputil.MakeInternalAPI("queryOpenIDToken", func(req *http.Request) util.JSONResponse {
|
||||||
|
request := api.QueryOpenIDTokenRequest{}
|
||||||
|
response := api.QueryOpenIDTokenResponse{}
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||||
|
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||||
|
}
|
||||||
|
if err := s.QueryOpenIDToken(req.Context(), &request, &response); err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||||
|
}),
|
||||||
|
)
|
||||||
internalAPIMux.Handle(InputAccountDataPath,
|
internalAPIMux.Handle(InputAccountDataPath,
|
||||||
httputil.MakeInternalAPI("inputAccountDataPath", func(req *http.Request) util.JSONResponse {
|
httputil.MakeInternalAPI("inputAccountDataPath", func(req *http.Request) util.JSONResponse {
|
||||||
request := api.InputAccountDataRequest{}
|
request := api.InputAccountDataRequest{}
|
||||||
|
|
|
@ -52,6 +52,8 @@ type Database interface {
|
||||||
GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
|
GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
|
||||||
SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
|
SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
|
||||||
DeactivateAccount(ctx context.Context, localpart string) (err error)
|
DeactivateAccount(ctx context.Context, localpart string) (err error)
|
||||||
|
CreateOpenIDToken(ctx context.Context, token, localpart string) (exp int64, err error)
|
||||||
|
GetOpenIDTokenAttributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Err3PIDInUse is the error returned when trying to save an association involving
|
// Err3PIDInUse is the error returned when trying to save an association involving
|
||||||
|
|
84
userapi/storage/accounts/postgres/openid_table.go
Normal file
84
userapi/storage/accounts/postgres/openid_table.go
Normal file
|
@ -0,0 +1,84 @@
|
||||||
|
package postgres
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const openIDTokenSchema = `
|
||||||
|
-- Stores data about openid tokens issued for accounts.
|
||||||
|
CREATE TABLE IF NOT EXISTS open_id_tokens (
|
||||||
|
-- The value of the token issued to a user
|
||||||
|
token TEXT NOT NULL PRIMARY KEY,
|
||||||
|
-- The Matrix user ID for this account
|
||||||
|
localpart TEXT NOT NULL,
|
||||||
|
-- When the token expires, as a unix timestamp (ms resolution).
|
||||||
|
token_expires_at_ms BIGINT NOT NULL
|
||||||
|
);
|
||||||
|
`
|
||||||
|
|
||||||
|
const insertTokenSQL = "" +
|
||||||
|
"INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
|
||||||
|
|
||||||
|
const selectTokenSQL = "" +
|
||||||
|
"SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1"
|
||||||
|
|
||||||
|
type tokenStatements struct {
|
||||||
|
insertTokenStmt *sql.Stmt
|
||||||
|
selectTokenStmt *sql.Stmt
|
||||||
|
serverName gomatrixserverlib.ServerName
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
|
||||||
|
_, err = db.Exec(openIDTokenSchema)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.insertTokenStmt, err = db.Prepare(insertTokenSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.selectTokenStmt, err = db.Prepare(selectTokenSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.serverName = server
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// insertToken inserts a new OpenID Connect token to the DB.
|
||||||
|
// Returns new token, otherwise returns error if the token already exists.
|
||||||
|
func (s *tokenStatements) insertToken(
|
||||||
|
ctx context.Context,
|
||||||
|
txn *sql.Tx,
|
||||||
|
token, localpart string,
|
||||||
|
expiresAtMS int64,
|
||||||
|
) (err error) {
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.insertTokenStmt)
|
||||||
|
_, err = stmt.ExecContext(ctx, token, localpart, expiresAtMS)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// selectOpenIDTokenAtrributes gets the attributes associated with an OpenID token from the DB
|
||||||
|
// Returns the existing token's attributes, or err if no token is found
|
||||||
|
func (s *tokenStatements) selectOpenIDTokenAtrributes(
|
||||||
|
ctx context.Context,
|
||||||
|
token string,
|
||||||
|
) (*api.OpenIDTokenAttributes, error) {
|
||||||
|
var openIDTokenAttrs api.OpenIDTokenAttributes
|
||||||
|
err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan(
|
||||||
|
&openIDTokenAttrs.UserID,
|
||||||
|
&openIDTokenAttrs.ExpiresAtMS,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
if err != sql.ErrNoRows {
|
||||||
|
log.WithError(err).Error("Unable to retrieve token from the db")
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &openIDTokenAttrs, nil
|
||||||
|
}
|
|
@ -20,6 +20,7 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
@ -39,25 +40,28 @@ type Database struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer sqlutil.Writer
|
writer sqlutil.Writer
|
||||||
sqlutil.PartitionOffsetStatements
|
sqlutil.PartitionOffsetStatements
|
||||||
accounts accountsStatements
|
accounts accountsStatements
|
||||||
profiles profilesStatements
|
profiles profilesStatements
|
||||||
accountDatas accountDataStatements
|
accountDatas accountDataStatements
|
||||||
threepids threepidStatements
|
threepids threepidStatements
|
||||||
serverName gomatrixserverlib.ServerName
|
openIDTokens tokenStatements
|
||||||
bcryptCost int
|
serverName gomatrixserverlib.ServerName
|
||||||
|
bcryptCost int
|
||||||
|
openIDTokenLifetimeMS int64
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDatabase creates a new accounts and profiles database
|
// NewDatabase creates a new accounts and profiles database
|
||||||
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int) (*Database, error) {
|
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) {
|
||||||
db, err := sqlutil.Open(dbProperties)
|
db, err := sqlutil.Open(dbProperties)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
d := &Database{
|
d := &Database{
|
||||||
serverName: serverName,
|
serverName: serverName,
|
||||||
db: db,
|
db: db,
|
||||||
writer: sqlutil.NewDummyWriter(),
|
writer: sqlutil.NewDummyWriter(),
|
||||||
bcryptCost: bcryptCost,
|
bcryptCost: bcryptCost,
|
||||||
|
openIDTokenLifetimeMS: openIDTokenLifetimeMS,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create tables before executing migrations so we don't fail if the table is missing,
|
// Create tables before executing migrations so we don't fail if the table is missing,
|
||||||
|
@ -86,6 +90,9 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
|
||||||
if err = d.threepids.prepare(db); err != nil {
|
if err = d.threepids.prepare(db); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if err = d.openIDTokens.prepare(db, serverName); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
return d, nil
|
return d, nil
|
||||||
}
|
}
|
||||||
|
@ -341,3 +348,23 @@ func (d *Database) SearchProfiles(ctx context.Context, searchString string, limi
|
||||||
func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) {
|
func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) {
|
||||||
return d.accounts.deactivateAccount(ctx, localpart)
|
return d.accounts.deactivateAccount(ctx, localpart)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CreateOpenIDToken persists a new token that was issued through OpenID Connect
|
||||||
|
func (d *Database) CreateOpenIDToken(
|
||||||
|
ctx context.Context,
|
||||||
|
token, localpart string,
|
||||||
|
) (int64, error) {
|
||||||
|
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.openIDTokenLifetimeMS
|
||||||
|
err := sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||||
|
return d.openIDTokens.insertToken(ctx, txn, token, localpart, expiresAtMS)
|
||||||
|
})
|
||||||
|
return expiresAtMS, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOpenIDTokenAttributes gets the attributes of issued an OIDC auth token
|
||||||
|
func (d *Database) GetOpenIDTokenAttributes(
|
||||||
|
ctx context.Context,
|
||||||
|
token string,
|
||||||
|
) (*api.OpenIDTokenAttributes, error) {
|
||||||
|
return d.openIDTokens.selectOpenIDTokenAtrributes(ctx, token)
|
||||||
|
}
|
||||||
|
|
86
userapi/storage/accounts/sqlite3/openid_table.go
Normal file
86
userapi/storage/accounts/sqlite3/openid_table.go
Normal file
|
@ -0,0 +1,86 @@
|
||||||
|
package sqlite3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const openIDTokenSchema = `
|
||||||
|
-- Stores data about accounts.
|
||||||
|
CREATE TABLE IF NOT EXISTS open_id_tokens (
|
||||||
|
-- The value of the token issued to a user
|
||||||
|
token TEXT NOT NULL PRIMARY KEY,
|
||||||
|
-- The Matrix user ID for this account
|
||||||
|
localpart TEXT NOT NULL,
|
||||||
|
-- When the token expires, as a unix timestamp (ms resolution).
|
||||||
|
token_expires_at_ms BIGINT NOT NULL
|
||||||
|
);
|
||||||
|
`
|
||||||
|
|
||||||
|
const insertTokenSQL = "" +
|
||||||
|
"INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
|
||||||
|
|
||||||
|
const selectTokenSQL = "" +
|
||||||
|
"SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1"
|
||||||
|
|
||||||
|
type tokenStatements struct {
|
||||||
|
db *sql.DB
|
||||||
|
insertTokenStmt *sql.Stmt
|
||||||
|
selectTokenStmt *sql.Stmt
|
||||||
|
serverName gomatrixserverlib.ServerName
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
|
||||||
|
s.db = db
|
||||||
|
_, err = db.Exec(openIDTokenSchema)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if s.insertTokenStmt, err = db.Prepare(insertTokenSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.selectTokenStmt, err = db.Prepare(selectTokenSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.serverName = server
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// insertToken inserts a new OpenID Connect token to the DB.
|
||||||
|
// Returns new token, otherwise returns error if the token already exists.
|
||||||
|
func (s *tokenStatements) insertToken(
|
||||||
|
ctx context.Context,
|
||||||
|
txn *sql.Tx,
|
||||||
|
token, localpart string,
|
||||||
|
expiresAtMS int64,
|
||||||
|
) (err error) {
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.insertTokenStmt)
|
||||||
|
_, err = stmt.ExecContext(ctx, token, localpart, expiresAtMS)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// selectOpenIDTokenAtrributes gets the attributes associated with an OpenID token from the DB
|
||||||
|
// Returns the existing token's attributes, or err if no token is found
|
||||||
|
func (s *tokenStatements) selectOpenIDTokenAtrributes(
|
||||||
|
ctx context.Context,
|
||||||
|
token string,
|
||||||
|
) (*api.OpenIDTokenAttributes, error) {
|
||||||
|
var openIDTokenAttrs api.OpenIDTokenAttributes
|
||||||
|
err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan(
|
||||||
|
&openIDTokenAttrs.UserID,
|
||||||
|
&openIDTokenAttrs.ExpiresAtMS,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
if err != sql.ErrNoRows {
|
||||||
|
log.WithError(err).Error("Unable to retrieve token from the db")
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &openIDTokenAttrs, nil
|
||||||
|
}
|
|
@ -21,6 +21,7 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
@ -37,12 +38,14 @@ type Database struct {
|
||||||
writer sqlutil.Writer
|
writer sqlutil.Writer
|
||||||
|
|
||||||
sqlutil.PartitionOffsetStatements
|
sqlutil.PartitionOffsetStatements
|
||||||
accounts accountsStatements
|
accounts accountsStatements
|
||||||
profiles profilesStatements
|
profiles profilesStatements
|
||||||
accountDatas accountDataStatements
|
accountDatas accountDataStatements
|
||||||
threepids threepidStatements
|
threepids threepidStatements
|
||||||
serverName gomatrixserverlib.ServerName
|
openIDTokens tokenStatements
|
||||||
bcryptCost int
|
serverName gomatrixserverlib.ServerName
|
||||||
|
bcryptCost int
|
||||||
|
openIDTokenLifetimeMS int64
|
||||||
|
|
||||||
accountsMu sync.Mutex
|
accountsMu sync.Mutex
|
||||||
profilesMu sync.Mutex
|
profilesMu sync.Mutex
|
||||||
|
@ -51,16 +54,17 @@ type Database struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDatabase creates a new accounts and profiles database
|
// NewDatabase creates a new accounts and profiles database
|
||||||
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int) (*Database, error) {
|
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) {
|
||||||
db, err := sqlutil.Open(dbProperties)
|
db, err := sqlutil.Open(dbProperties)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
d := &Database{
|
d := &Database{
|
||||||
serverName: serverName,
|
serverName: serverName,
|
||||||
db: db,
|
db: db,
|
||||||
writer: sqlutil.NewExclusiveWriter(),
|
writer: sqlutil.NewExclusiveWriter(),
|
||||||
bcryptCost: bcryptCost,
|
bcryptCost: bcryptCost,
|
||||||
|
openIDTokenLifetimeMS: openIDTokenLifetimeMS,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create tables before executing migrations so we don't fail if the table is missing,
|
// Create tables before executing migrations so we don't fail if the table is missing,
|
||||||
|
@ -90,6 +94,9 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
|
||||||
if err = d.threepids.prepare(db); err != nil {
|
if err = d.threepids.prepare(db); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if err = d.openIDTokens.prepare(db, serverName); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
return d, nil
|
return d, nil
|
||||||
}
|
}
|
||||||
|
@ -379,3 +386,23 @@ func (d *Database) SearchProfiles(ctx context.Context, searchString string, limi
|
||||||
func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) {
|
func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) {
|
||||||
return d.accounts.deactivateAccount(ctx, localpart)
|
return d.accounts.deactivateAccount(ctx, localpart)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CreateOpenIDToken persists a new token that was issued for OpenID Connect
|
||||||
|
func (d *Database) CreateOpenIDToken(
|
||||||
|
ctx context.Context,
|
||||||
|
token, localpart string,
|
||||||
|
) (int64, error) {
|
||||||
|
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.openIDTokenLifetimeMS
|
||||||
|
err := d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
||||||
|
return d.openIDTokens.insertToken(ctx, txn, token, localpart, expiresAtMS)
|
||||||
|
})
|
||||||
|
return expiresAtMS, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOpenIDTokenAttributes gets the attributes of issued an OIDC auth token
|
||||||
|
func (d *Database) GetOpenIDTokenAttributes(
|
||||||
|
ctx context.Context,
|
||||||
|
token string,
|
||||||
|
) (*api.OpenIDTokenAttributes, error) {
|
||||||
|
return d.openIDTokens.selectOpenIDTokenAtrributes(ctx, token)
|
||||||
|
}
|
||||||
|
|
|
@ -27,12 +27,12 @@ import (
|
||||||
|
|
||||||
// NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
|
// NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
|
||||||
// and sets postgres connection parameters
|
// and sets postgres connection parameters
|
||||||
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int) (Database, error) {
|
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (Database, error) {
|
||||||
switch {
|
switch {
|
||||||
case dbProperties.ConnectionString.IsSQLite():
|
case dbProperties.ConnectionString.IsSQLite():
|
||||||
return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost)
|
return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS)
|
||||||
case dbProperties.ConnectionString.IsPostgres():
|
case dbProperties.ConnectionString.IsPostgres():
|
||||||
return postgres.NewDatabase(dbProperties, serverName, bcryptCost)
|
return postgres.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS)
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("unexpected database type")
|
return nil, fmt.Errorf("unexpected database type")
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,10 +26,11 @@ func NewDatabase(
|
||||||
dbProperties *config.DatabaseOptions,
|
dbProperties *config.DatabaseOptions,
|
||||||
serverName gomatrixserverlib.ServerName,
|
serverName gomatrixserverlib.ServerName,
|
||||||
bcryptCost int,
|
bcryptCost int,
|
||||||
|
openIDTokenLifetimeMS int64,
|
||||||
) (Database, error) {
|
) (Database, error) {
|
||||||
switch {
|
switch {
|
||||||
case dbProperties.ConnectionString.IsSQLite():
|
case dbProperties.ConnectionString.IsSQLite():
|
||||||
return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost)
|
return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS)
|
||||||
case dbProperties.ConnectionString.IsPostgres():
|
case dbProperties.ConnectionString.IsPostgres():
|
||||||
return nil, fmt.Errorf("can't use Postgres implementation")
|
return nil, fmt.Errorf("can't use Postgres implementation")
|
||||||
default:
|
default:
|
||||||
|
|
|
@ -26,7 +26,7 @@ const (
|
||||||
func MustMakeInternalAPI(t *testing.T) (api.UserInternalAPI, accounts.Database) {
|
func MustMakeInternalAPI(t *testing.T) (api.UserInternalAPI, accounts.Database) {
|
||||||
accountDB, err := accounts.NewDatabase(&config.DatabaseOptions{
|
accountDB, err := accounts.NewDatabase(&config.DatabaseOptions{
|
||||||
ConnectionString: "file::memory:",
|
ConnectionString: "file::memory:",
|
||||||
}, serverName, bcrypt.MinCost)
|
}, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create account DB: %s", err)
|
t.Fatalf("failed to create account DB: %s", err)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue