mirror of
https://github.com/hoernschen/dendrite.git
synced 2025-07-30 21:12:45 +00:00
feat: admin APIs for token authenticated registration (#3101)
### Pull Request Checklist <!-- Please read https://matrix-org.github.io/dendrite/development/contributing before submitting your pull request --> * [x] I have added Go unit tests or [Complement integration tests](https://github.com/matrix-org/complement) for this PR _or_ I have justified why this PR doesn't need tests * [x] Pull request includes a [sign off below using a legally identifiable name](https://matrix-org.github.io/dendrite/development/contributing#sign-off) _or_ I have already signed off privately Signed-off-by: `Santhoshivan Amudhan santhoshivan23@gmail.com`
This commit is contained in:
parent
a734b112c6
commit
45082d4dce
14 changed files with 1474 additions and 1 deletions
|
@ -27,6 +27,7 @@ import (
|
|||
"github.com/matrix-org/gomatrixserverlib/fclient"
|
||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||
|
||||
clientapi "github.com/matrix-org/dendrite/clientapi/api"
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
"github.com/matrix-org/dendrite/internal/pushrules"
|
||||
)
|
||||
|
@ -94,6 +95,11 @@ type ClientUserAPI interface {
|
|||
QueryPushers(ctx context.Context, req *QueryPushersRequest, res *QueryPushersResponse) error
|
||||
QueryPushRules(ctx context.Context, userID string) (*pushrules.AccountRuleSets, error)
|
||||
QueryAccountAvailability(ctx context.Context, req *QueryAccountAvailabilityRequest, res *QueryAccountAvailabilityResponse) error
|
||||
PerformAdminCreateRegistrationToken(ctx context.Context, registrationToken *clientapi.RegistrationToken) (bool, error)
|
||||
PerformAdminListRegistrationTokens(ctx context.Context, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error)
|
||||
PerformAdminGetRegistrationToken(ctx context.Context, tokenString string) (*clientapi.RegistrationToken, error)
|
||||
PerformAdminDeleteRegistrationToken(ctx context.Context, tokenString string) error
|
||||
PerformAdminUpdateRegistrationToken(ctx context.Context, tokenString string, newAttributes map[string]interface{}) (*clientapi.RegistrationToken, error)
|
||||
PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error
|
||||
PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error
|
||||
PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error
|
||||
|
|
|
@ -33,6 +33,7 @@ import (
|
|||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
clientapi "github.com/matrix-org/dendrite/clientapi/api"
|
||||
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||
"github.com/matrix-org/dendrite/internal/pushgateway"
|
||||
|
@ -63,6 +64,37 @@ type UserInternalAPI struct {
|
|||
Updater *DeviceListUpdater
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) PerformAdminCreateRegistrationToken(ctx context.Context, registrationToken *clientapi.RegistrationToken) (bool, error) {
|
||||
exists, err := a.DB.RegistrationTokenExists(ctx, *registrationToken.Token)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if exists {
|
||||
return false, fmt.Errorf("token: %s already exists", *registrationToken.Token)
|
||||
}
|
||||
_, err = a.DB.InsertRegistrationToken(ctx, registrationToken)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("Error creating token: %s"+err.Error(), *registrationToken.Token)
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) PerformAdminListRegistrationTokens(ctx context.Context, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error) {
|
||||
return a.DB.ListRegistrationTokens(ctx, returnAll, valid)
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) PerformAdminGetRegistrationToken(ctx context.Context, tokenString string) (*clientapi.RegistrationToken, error) {
|
||||
return a.DB.GetRegistrationToken(ctx, tokenString)
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) PerformAdminDeleteRegistrationToken(ctx context.Context, tokenString string) error {
|
||||
return a.DB.DeleteRegistrationToken(ctx, tokenString)
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) PerformAdminUpdateRegistrationToken(ctx context.Context, tokenString string, newAttributes map[string]interface{}) (*clientapi.RegistrationToken, error) {
|
||||
return a.DB.UpdateRegistrationToken(ctx, tokenString, newAttributes)
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error {
|
||||
local, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
|
||||
if err != nil {
|
||||
|
|
|
@ -23,6 +23,7 @@ import (
|
|||
"github.com/matrix-org/gomatrixserverlib/fclient"
|
||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||
|
||||
clientapi "github.com/matrix-org/dendrite/clientapi/api"
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
"github.com/matrix-org/dendrite/internal/pushrules"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
|
@ -30,6 +31,15 @@ import (
|
|||
"github.com/matrix-org/dendrite/userapi/types"
|
||||
)
|
||||
|
||||
type RegistrationTokens interface {
|
||||
RegistrationTokenExists(ctx context.Context, token string) (bool, error)
|
||||
InsertRegistrationToken(ctx context.Context, registrationToken *clientapi.RegistrationToken) (bool, error)
|
||||
ListRegistrationTokens(ctx context.Context, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error)
|
||||
GetRegistrationToken(ctx context.Context, tokenString string) (*clientapi.RegistrationToken, error)
|
||||
DeleteRegistrationToken(ctx context.Context, tokenString string) error
|
||||
UpdateRegistrationToken(ctx context.Context, tokenString string, newAttributes map[string]interface{}) (*clientapi.RegistrationToken, error)
|
||||
}
|
||||
|
||||
type Profile interface {
|
||||
GetProfileByLocalpart(ctx context.Context, localpart string, serverName spec.ServerName) (*authtypes.Profile, error)
|
||||
SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
|
||||
|
@ -144,6 +154,7 @@ type UserDatabase interface {
|
|||
Pusher
|
||||
Statistics
|
||||
ThreePID
|
||||
RegistrationTokens
|
||||
}
|
||||
|
||||
type KeyChangeDatabase interface {
|
||||
|
|
222
userapi/storage/postgres/registration_tokens_table.go
Normal file
222
userapi/storage/postgres/registration_tokens_table.go
Normal file
|
@ -0,0 +1,222 @@
|
|||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/api"
|
||||
internal "github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"golang.org/x/exp/constraints"
|
||||
)
|
||||
|
||||
const registrationTokensSchema = `
|
||||
CREATE TABLE IF NOT EXISTS userapi_registration_tokens (
|
||||
token TEXT PRIMARY KEY,
|
||||
pending BIGINT,
|
||||
completed BIGINT,
|
||||
uses_allowed BIGINT,
|
||||
expiry_time BIGINT
|
||||
);
|
||||
`
|
||||
|
||||
const selectTokenSQL = "" +
|
||||
"SELECT token FROM userapi_registration_tokens WHERE token = $1"
|
||||
|
||||
const insertTokenSQL = "" +
|
||||
"INSERT INTO userapi_registration_tokens (token, uses_allowed, expiry_time, pending, completed) VALUES ($1, $2, $3, $4, $5)"
|
||||
|
||||
const listAllTokensSQL = "" +
|
||||
"SELECT * FROM userapi_registration_tokens"
|
||||
|
||||
const listValidTokensSQL = "" +
|
||||
"SELECT * FROM userapi_registration_tokens WHERE" +
|
||||
"(uses_allowed > pending + completed OR uses_allowed IS NULL) AND" +
|
||||
"(expiry_time > $1 OR expiry_time IS NULL)"
|
||||
|
||||
const listInvalidTokensSQL = "" +
|
||||
"SELECT * FROM userapi_registration_tokens WHERE" +
|
||||
"(uses_allowed <= pending + completed OR expiry_time <= $1)"
|
||||
|
||||
const getTokenSQL = "" +
|
||||
"SELECT pending, completed, uses_allowed, expiry_time FROM userapi_registration_tokens WHERE token = $1"
|
||||
|
||||
const deleteTokenSQL = "" +
|
||||
"DELETE FROM userapi_registration_tokens WHERE token = $1"
|
||||
|
||||
const updateTokenUsesAllowedAndExpiryTimeSQL = "" +
|
||||
"UPDATE userapi_registration_tokens SET uses_allowed = $2, expiry_time = $3 WHERE token = $1"
|
||||
|
||||
const updateTokenUsesAllowedSQL = "" +
|
||||
"UPDATE userapi_registration_tokens SET uses_allowed = $2 WHERE token = $1"
|
||||
|
||||
const updateTokenExpiryTimeSQL = "" +
|
||||
"UPDATE userapi_registration_tokens SET expiry_time = $2 WHERE token = $1"
|
||||
|
||||
type registrationTokenStatements struct {
|
||||
selectTokenStatement *sql.Stmt
|
||||
insertTokenStatement *sql.Stmt
|
||||
listAllTokensStatement *sql.Stmt
|
||||
listValidTokensStatement *sql.Stmt
|
||||
listInvalidTokenStatement *sql.Stmt
|
||||
getTokenStatement *sql.Stmt
|
||||
deleteTokenStatement *sql.Stmt
|
||||
updateTokenUsesAllowedAndExpiryTimeStatement *sql.Stmt
|
||||
updateTokenUsesAllowedStatement *sql.Stmt
|
||||
updateTokenExpiryTimeStatement *sql.Stmt
|
||||
}
|
||||
|
||||
func NewPostgresRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTable, error) {
|
||||
s := ®istrationTokenStatements{}
|
||||
_, err := db.Exec(registrationTokensSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.selectTokenStatement, selectTokenSQL},
|
||||
{&s.insertTokenStatement, insertTokenSQL},
|
||||
{&s.listAllTokensStatement, listAllTokensSQL},
|
||||
{&s.listValidTokensStatement, listValidTokensSQL},
|
||||
{&s.listInvalidTokenStatement, listInvalidTokensSQL},
|
||||
{&s.getTokenStatement, getTokenSQL},
|
||||
{&s.deleteTokenStatement, deleteTokenSQL},
|
||||
{&s.updateTokenUsesAllowedAndExpiryTimeStatement, updateTokenUsesAllowedAndExpiryTimeSQL},
|
||||
{&s.updateTokenUsesAllowedStatement, updateTokenUsesAllowedSQL},
|
||||
{&s.updateTokenExpiryTimeStatement, updateTokenExpiryTimeSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *registrationTokenStatements) RegistrationTokenExists(ctx context.Context, tx *sql.Tx, token string) (bool, error) {
|
||||
var existingToken string
|
||||
stmt := sqlutil.TxStmt(tx, s.selectTokenStatement)
|
||||
err := stmt.QueryRowContext(ctx, token).Scan(&existingToken)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Context, tx *sql.Tx, registrationToken *api.RegistrationToken) (bool, error) {
|
||||
stmt := sqlutil.TxStmt(tx, s.insertTokenStatement)
|
||||
_, err := stmt.ExecContext(
|
||||
ctx,
|
||||
*registrationToken.Token,
|
||||
getInsertValue(registrationToken.UsesAllowed),
|
||||
getInsertValue(registrationToken.ExpiryTime),
|
||||
*registrationToken.Pending,
|
||||
*registrationToken.Completed)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func getInsertValue[t constraints.Integer](in *t) any {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
return *in
|
||||
}
|
||||
|
||||
func (s *registrationTokenStatements) ListRegistrationTokens(ctx context.Context, tx *sql.Tx, returnAll bool, valid bool) ([]api.RegistrationToken, error) {
|
||||
var stmt *sql.Stmt
|
||||
var tokens []api.RegistrationToken
|
||||
var tokenString string
|
||||
var pending, completed, usesAllowed *int32
|
||||
var expiryTime *int64
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
if returnAll {
|
||||
stmt = sqlutil.TxStmt(tx, s.listAllTokensStatement)
|
||||
rows, err = stmt.QueryContext(ctx)
|
||||
} else if valid {
|
||||
stmt = sqlutil.TxStmt(tx, s.listValidTokensStatement)
|
||||
rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond))
|
||||
} else {
|
||||
stmt = sqlutil.TxStmt(tx, s.listInvalidTokenStatement)
|
||||
rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond))
|
||||
}
|
||||
if err != nil {
|
||||
return tokens, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "ListRegistrationTokens: rows.close() failed")
|
||||
for rows.Next() {
|
||||
err = rows.Scan(&tokenString, &pending, &completed, &usesAllowed, &expiryTime)
|
||||
if err != nil {
|
||||
return tokens, err
|
||||
}
|
||||
tokenString := tokenString
|
||||
pending := pending
|
||||
completed := completed
|
||||
usesAllowed := usesAllowed
|
||||
expiryTime := expiryTime
|
||||
|
||||
tokenMap := api.RegistrationToken{
|
||||
Token: &tokenString,
|
||||
Pending: pending,
|
||||
Completed: completed,
|
||||
UsesAllowed: usesAllowed,
|
||||
ExpiryTime: expiryTime,
|
||||
}
|
||||
tokens = append(tokens, tokenMap)
|
||||
}
|
||||
return tokens, rows.Err()
|
||||
}
|
||||
|
||||
func (s *registrationTokenStatements) GetRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) (*api.RegistrationToken, error) {
|
||||
stmt := sqlutil.TxStmt(tx, s.getTokenStatement)
|
||||
var pending, completed, usesAllowed *int32
|
||||
var expiryTime *int64
|
||||
err := stmt.QueryRowContext(ctx, tokenString).Scan(&pending, &completed, &usesAllowed, &expiryTime)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
token := api.RegistrationToken{
|
||||
Token: &tokenString,
|
||||
Pending: pending,
|
||||
Completed: completed,
|
||||
UsesAllowed: usesAllowed,
|
||||
ExpiryTime: expiryTime,
|
||||
}
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
func (s *registrationTokenStatements) DeleteRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) error {
|
||||
stmt := sqlutil.TxStmt(tx, s.deleteTokenStatement)
|
||||
_, err := stmt.ExecContext(ctx, tokenString)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *registrationTokenStatements) UpdateRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string, newAttributes map[string]interface{}) (*api.RegistrationToken, error) {
|
||||
var stmt *sql.Stmt
|
||||
usesAllowed, usesAllowedPresent := newAttributes["usesAllowed"]
|
||||
expiryTime, expiryTimePresent := newAttributes["expiryTime"]
|
||||
if usesAllowedPresent && expiryTimePresent {
|
||||
stmt = sqlutil.TxStmt(tx, s.updateTokenUsesAllowedAndExpiryTimeStatement)
|
||||
_, err := stmt.ExecContext(ctx, tokenString, usesAllowed, expiryTime)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else if usesAllowedPresent {
|
||||
stmt = sqlutil.TxStmt(tx, s.updateTokenUsesAllowedStatement)
|
||||
_, err := stmt.ExecContext(ctx, tokenString, usesAllowed)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else if expiryTimePresent {
|
||||
stmt = sqlutil.TxStmt(tx, s.updateTokenExpiryTimeStatement)
|
||||
_, err := stmt.ExecContext(ctx, tokenString, expiryTime)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return s.GetRegistrationToken(ctx, tx, tokenString)
|
||||
}
|
|
@ -53,6 +53,10 @@ func NewDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperties *
|
|||
return nil, err
|
||||
}
|
||||
|
||||
registationTokensTable, err := NewPostgresRegistrationTokensTable(db)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewPostgresRegistrationsTokenTable: %w", err)
|
||||
}
|
||||
accountsTable, err := NewPostgresAccountsTable(db, serverName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewPostgresAccountsTable: %w", err)
|
||||
|
@ -125,6 +129,7 @@ func NewDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperties *
|
|||
ThreePIDs: threePIDTable,
|
||||
Pushers: pusherTable,
|
||||
Notifications: notificationsTable,
|
||||
RegistrationTokens: registationTokensTable,
|
||||
Stats: statsTable,
|
||||
ServerName: serverName,
|
||||
DB: db,
|
||||
|
|
|
@ -31,6 +31,7 @@ import (
|
|||
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
clientapi "github.com/matrix-org/dendrite/clientapi/api"
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
"github.com/matrix-org/dendrite/internal/pushrules"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
|
@ -43,6 +44,7 @@ import (
|
|||
type Database struct {
|
||||
DB *sql.DB
|
||||
Writer sqlutil.Writer
|
||||
RegistrationTokens tables.RegistrationTokensTable
|
||||
Accounts tables.AccountsTable
|
||||
Profiles tables.ProfileTable
|
||||
AccountDatas tables.AccountDataTable
|
||||
|
@ -78,6 +80,42 @@ const (
|
|||
loginTokenByteLength = 32
|
||||
)
|
||||
|
||||
func (d *Database) RegistrationTokenExists(ctx context.Context, token string) (bool, error) {
|
||||
return d.RegistrationTokens.RegistrationTokenExists(ctx, nil, token)
|
||||
}
|
||||
|
||||
func (d *Database) InsertRegistrationToken(ctx context.Context, registrationToken *clientapi.RegistrationToken) (created bool, err error) {
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
created, err = d.RegistrationTokens.InsertRegistrationToken(ctx, txn, registrationToken)
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (d *Database) ListRegistrationTokens(ctx context.Context, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error) {
|
||||
return d.RegistrationTokens.ListRegistrationTokens(ctx, nil, returnAll, valid)
|
||||
}
|
||||
|
||||
func (d *Database) GetRegistrationToken(ctx context.Context, tokenString string) (*clientapi.RegistrationToken, error) {
|
||||
return d.RegistrationTokens.GetRegistrationToken(ctx, nil, tokenString)
|
||||
}
|
||||
|
||||
func (d *Database) DeleteRegistrationToken(ctx context.Context, tokenString string) (err error) {
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
err = d.RegistrationTokens.DeleteRegistrationToken(ctx, nil, tokenString)
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (d *Database) UpdateRegistrationToken(ctx context.Context, tokenString string, newAttributes map[string]interface{}) (updatedToken *clientapi.RegistrationToken, err error) {
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
updatedToken, err = d.RegistrationTokens.UpdateRegistrationToken(ctx, txn, tokenString, newAttributes)
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// GetAccountByPassword returns the account associated with the given localpart and password.
|
||||
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
|
||||
func (d *Database) GetAccountByPassword(
|
||||
|
|
222
userapi/storage/sqlite3/registration_tokens_table.go
Normal file
222
userapi/storage/sqlite3/registration_tokens_table.go
Normal file
|
@ -0,0 +1,222 @@
|
|||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/api"
|
||||
internal "github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"golang.org/x/exp/constraints"
|
||||
)
|
||||
|
||||
const registrationTokensSchema = `
|
||||
CREATE TABLE IF NOT EXISTS userapi_registration_tokens (
|
||||
token TEXT PRIMARY KEY,
|
||||
pending BIGINT,
|
||||
completed BIGINT,
|
||||
uses_allowed BIGINT,
|
||||
expiry_time BIGINT
|
||||
);
|
||||
`
|
||||
|
||||
const selectTokenSQL = "" +
|
||||
"SELECT token FROM userapi_registration_tokens WHERE token = $1"
|
||||
|
||||
const insertTokenSQL = "" +
|
||||
"INSERT INTO userapi_registration_tokens (token, uses_allowed, expiry_time, pending, completed) VALUES ($1, $2, $3, $4, $5)"
|
||||
|
||||
const listAllTokensSQL = "" +
|
||||
"SELECT * FROM userapi_registration_tokens"
|
||||
|
||||
const listValidTokensSQL = "" +
|
||||
"SELECT * FROM userapi_registration_tokens WHERE" +
|
||||
"(uses_allowed > pending + completed OR uses_allowed IS NULL) AND" +
|
||||
"(expiry_time > $1 OR expiry_time IS NULL)"
|
||||
|
||||
const listInvalidTokensSQL = "" +
|
||||
"SELECT * FROM userapi_registration_tokens WHERE" +
|
||||
"(uses_allowed <= pending + completed OR expiry_time <= $1)"
|
||||
|
||||
const getTokenSQL = "" +
|
||||
"SELECT pending, completed, uses_allowed, expiry_time FROM userapi_registration_tokens WHERE token = $1"
|
||||
|
||||
const deleteTokenSQL = "" +
|
||||
"DELETE FROM userapi_registration_tokens WHERE token = $1"
|
||||
|
||||
const updateTokenUsesAllowedAndExpiryTimeSQL = "" +
|
||||
"UPDATE userapi_registration_tokens SET uses_allowed = $2, expiry_time = $3 WHERE token = $1"
|
||||
|
||||
const updateTokenUsesAllowedSQL = "" +
|
||||
"UPDATE userapi_registration_tokens SET uses_allowed = $2 WHERE token = $1"
|
||||
|
||||
const updateTokenExpiryTimeSQL = "" +
|
||||
"UPDATE userapi_registration_tokens SET expiry_time = $2 WHERE token = $1"
|
||||
|
||||
type registrationTokenStatements struct {
|
||||
selectTokenStatement *sql.Stmt
|
||||
insertTokenStatement *sql.Stmt
|
||||
listAllTokensStatement *sql.Stmt
|
||||
listValidTokensStatement *sql.Stmt
|
||||
listInvalidTokenStatement *sql.Stmt
|
||||
getTokenStatement *sql.Stmt
|
||||
deleteTokenStatement *sql.Stmt
|
||||
updateTokenUsesAllowedAndExpiryTimeStatement *sql.Stmt
|
||||
updateTokenUsesAllowedStatement *sql.Stmt
|
||||
updateTokenExpiryTimeStatement *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSQLiteRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTable, error) {
|
||||
s := ®istrationTokenStatements{}
|
||||
_, err := db.Exec(registrationTokensSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.selectTokenStatement, selectTokenSQL},
|
||||
{&s.insertTokenStatement, insertTokenSQL},
|
||||
{&s.listAllTokensStatement, listAllTokensSQL},
|
||||
{&s.listValidTokensStatement, listValidTokensSQL},
|
||||
{&s.listInvalidTokenStatement, listInvalidTokensSQL},
|
||||
{&s.getTokenStatement, getTokenSQL},
|
||||
{&s.deleteTokenStatement, deleteTokenSQL},
|
||||
{&s.updateTokenUsesAllowedAndExpiryTimeStatement, updateTokenUsesAllowedAndExpiryTimeSQL},
|
||||
{&s.updateTokenUsesAllowedStatement, updateTokenUsesAllowedSQL},
|
||||
{&s.updateTokenExpiryTimeStatement, updateTokenExpiryTimeSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *registrationTokenStatements) RegistrationTokenExists(ctx context.Context, tx *sql.Tx, token string) (bool, error) {
|
||||
var existingToken string
|
||||
stmt := sqlutil.TxStmt(tx, s.selectTokenStatement)
|
||||
err := stmt.QueryRowContext(ctx, token).Scan(&existingToken)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Context, tx *sql.Tx, registrationToken *api.RegistrationToken) (bool, error) {
|
||||
stmt := sqlutil.TxStmt(tx, s.insertTokenStatement)
|
||||
_, err := stmt.ExecContext(
|
||||
ctx,
|
||||
*registrationToken.Token,
|
||||
getInsertValue(registrationToken.UsesAllowed),
|
||||
getInsertValue(registrationToken.ExpiryTime),
|
||||
*registrationToken.Pending,
|
||||
*registrationToken.Completed)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func getInsertValue[t constraints.Integer](in *t) any {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
return *in
|
||||
}
|
||||
|
||||
func (s *registrationTokenStatements) ListRegistrationTokens(ctx context.Context, tx *sql.Tx, returnAll bool, valid bool) ([]api.RegistrationToken, error) {
|
||||
var stmt *sql.Stmt
|
||||
var tokens []api.RegistrationToken
|
||||
var tokenString string
|
||||
var pending, completed, usesAllowed *int32
|
||||
var expiryTime *int64
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
if returnAll {
|
||||
stmt = sqlutil.TxStmt(tx, s.listAllTokensStatement)
|
||||
rows, err = stmt.QueryContext(ctx)
|
||||
} else if valid {
|
||||
stmt = sqlutil.TxStmt(tx, s.listValidTokensStatement)
|
||||
rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond))
|
||||
} else {
|
||||
stmt = sqlutil.TxStmt(tx, s.listInvalidTokenStatement)
|
||||
rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond))
|
||||
}
|
||||
if err != nil {
|
||||
return tokens, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "ListRegistrationTokens: rows.close() failed")
|
||||
for rows.Next() {
|
||||
err = rows.Scan(&tokenString, &pending, &completed, &usesAllowed, &expiryTime)
|
||||
if err != nil {
|
||||
return tokens, err
|
||||
}
|
||||
tokenString := tokenString
|
||||
pending := pending
|
||||
completed := completed
|
||||
usesAllowed := usesAllowed
|
||||
expiryTime := expiryTime
|
||||
|
||||
tokenMap := api.RegistrationToken{
|
||||
Token: &tokenString,
|
||||
Pending: pending,
|
||||
Completed: completed,
|
||||
UsesAllowed: usesAllowed,
|
||||
ExpiryTime: expiryTime,
|
||||
}
|
||||
tokens = append(tokens, tokenMap)
|
||||
}
|
||||
return tokens, rows.Err()
|
||||
}
|
||||
|
||||
func (s *registrationTokenStatements) GetRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) (*api.RegistrationToken, error) {
|
||||
stmt := sqlutil.TxStmt(tx, s.getTokenStatement)
|
||||
var pending, completed, usesAllowed *int32
|
||||
var expiryTime *int64
|
||||
err := stmt.QueryRowContext(ctx, tokenString).Scan(&pending, &completed, &usesAllowed, &expiryTime)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
token := api.RegistrationToken{
|
||||
Token: &tokenString,
|
||||
Pending: pending,
|
||||
Completed: completed,
|
||||
UsesAllowed: usesAllowed,
|
||||
ExpiryTime: expiryTime,
|
||||
}
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
func (s *registrationTokenStatements) DeleteRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) error {
|
||||
stmt := sqlutil.TxStmt(tx, s.deleteTokenStatement)
|
||||
_, err := stmt.ExecContext(ctx, tokenString)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *registrationTokenStatements) UpdateRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string, newAttributes map[string]interface{}) (*api.RegistrationToken, error) {
|
||||
var stmt *sql.Stmt
|
||||
usesAllowed, usesAllowedPresent := newAttributes["usesAllowed"]
|
||||
expiryTime, expiryTimePresent := newAttributes["expiryTime"]
|
||||
if usesAllowedPresent && expiryTimePresent {
|
||||
stmt = sqlutil.TxStmt(tx, s.updateTokenUsesAllowedAndExpiryTimeStatement)
|
||||
_, err := stmt.ExecContext(ctx, tokenString, usesAllowed, expiryTime)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else if usesAllowedPresent {
|
||||
stmt = sqlutil.TxStmt(tx, s.updateTokenUsesAllowedStatement)
|
||||
_, err := stmt.ExecContext(ctx, tokenString, usesAllowed)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else if expiryTimePresent {
|
||||
stmt = sqlutil.TxStmt(tx, s.updateTokenExpiryTimeStatement)
|
||||
_, err := stmt.ExecContext(ctx, tokenString, expiryTime)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return s.GetRegistrationToken(ctx, tx, tokenString)
|
||||
}
|
|
@ -50,7 +50,10 @@ func NewUserDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperti
|
|||
if err = m.Up(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
registationTokensTable, err := NewSQLiteRegistrationTokensTable(db)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewSQLiteRegistrationsTokenTable: %w", err)
|
||||
}
|
||||
accountsTable, err := NewSQLiteAccountsTable(db, serverName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewSQLiteAccountsTable: %w", err)
|
||||
|
@ -130,6 +133,7 @@ func NewUserDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperti
|
|||
LoginTokenLifetime: loginTokenLifetime,
|
||||
BcryptCost: bcryptCost,
|
||||
OpenIDTokenLifetimeMS: openIDTokenLifetimeMS,
|
||||
RegistrationTokens: registationTokensTable,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -25,10 +25,20 @@ import (
|
|||
"github.com/matrix-org/gomatrixserverlib/fclient"
|
||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||
|
||||
clientapi "github.com/matrix-org/dendrite/clientapi/api"
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
"github.com/matrix-org/dendrite/userapi/types"
|
||||
)
|
||||
|
||||
type RegistrationTokensTable interface {
|
||||
RegistrationTokenExists(ctx context.Context, txn *sql.Tx, token string) (bool, error)
|
||||
InsertRegistrationToken(ctx context.Context, txn *sql.Tx, registrationToken *clientapi.RegistrationToken) (bool, error)
|
||||
ListRegistrationTokens(ctx context.Context, txn *sql.Tx, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error)
|
||||
GetRegistrationToken(ctx context.Context, txn *sql.Tx, tokenString string) (*clientapi.RegistrationToken, error)
|
||||
DeleteRegistrationToken(ctx context.Context, txn *sql.Tx, tokenString string) error
|
||||
UpdateRegistrationToken(ctx context.Context, txn *sql.Tx, tokenString string, newAttributes map[string]interface{}) (*clientapi.RegistrationToken, error)
|
||||
}
|
||||
|
||||
type AccountDataTable interface {
|
||||
InsertAccountData(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, roomID, dataType string, content json.RawMessage) error
|
||||
SelectAccountData(ctx context.Context, localpart string, serverName spec.ServerName) (map[string]json.RawMessage, map[string]map[string]json.RawMessage, error)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue