mirror of
https://github.com/hoernschen/dendrite.git
synced 2025-07-31 13:22:46 +00:00
Make userapi responsible for checking access tokens (#1133)
* Make userapi responsible for checking access tokens There's still plenty of dependencies on account/device DBs, but this is a start. This is a breaking change as it adds a required config value `listen.user_api`. * Cleanup * Review comments and test fix
This commit is contained in:
parent
57b7fa3db8
commit
9c77022513
66 changed files with 421 additions and 400 deletions
|
@ -17,14 +17,14 @@ package devices
|
|||
import (
|
||||
"context"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
)
|
||||
|
||||
type Database interface {
|
||||
GetDeviceByAccessToken(ctx context.Context, token string) (*authtypes.Device, error)
|
||||
GetDeviceByID(ctx context.Context, localpart, deviceID string) (*authtypes.Device, error)
|
||||
GetDevicesByLocalpart(ctx context.Context, localpart string) ([]authtypes.Device, error)
|
||||
CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string) (dev *authtypes.Device, returnErr error)
|
||||
GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error)
|
||||
GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error)
|
||||
GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error)
|
||||
CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string) (dev *api.Device, returnErr error)
|
||||
UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error
|
||||
RemoveDevice(ctx context.Context, deviceID, localpart string) error
|
||||
RemoveDevices(ctx context.Context, localpart string, devices []string) error
|
||||
|
|
|
@ -20,10 +20,10 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
|
@ -135,14 +135,14 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN
|
|||
func (s *devicesStatements) insertDevice(
|
||||
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
|
||||
displayName *string,
|
||||
) (*authtypes.Device, error) {
|
||||
) (*api.Device, error) {
|
||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||
var sessionID int64
|
||||
stmt := sqlutil.TxStmt(txn, s.insertDeviceStmt)
|
||||
if err := stmt.QueryRowContext(ctx, id, localpart, accessToken, createdTimeMS, displayName).Scan(&sessionID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &authtypes.Device{
|
||||
return &api.Device{
|
||||
ID: id,
|
||||
UserID: userutil.MakeUserID(localpart, s.serverName),
|
||||
AccessToken: accessToken,
|
||||
|
@ -189,8 +189,8 @@ func (s *devicesStatements) updateDeviceName(
|
|||
|
||||
func (s *devicesStatements) selectDeviceByToken(
|
||||
ctx context.Context, accessToken string,
|
||||
) (*authtypes.Device, error) {
|
||||
var dev authtypes.Device
|
||||
) (*api.Device, error) {
|
||||
var dev api.Device
|
||||
var localpart string
|
||||
stmt := s.selectDeviceByTokenStmt
|
||||
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart)
|
||||
|
@ -205,8 +205,8 @@ func (s *devicesStatements) selectDeviceByToken(
|
|||
// localpart and deviceID
|
||||
func (s *devicesStatements) selectDeviceByID(
|
||||
ctx context.Context, localpart, deviceID string,
|
||||
) (*authtypes.Device, error) {
|
||||
var dev authtypes.Device
|
||||
) (*api.Device, error) {
|
||||
var dev api.Device
|
||||
stmt := s.selectDeviceByIDStmt
|
||||
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&dev.DisplayName)
|
||||
if err == nil {
|
||||
|
@ -218,8 +218,8 @@ func (s *devicesStatements) selectDeviceByID(
|
|||
|
||||
func (s *devicesStatements) selectDevicesByLocalpart(
|
||||
ctx context.Context, localpart string,
|
||||
) ([]authtypes.Device, error) {
|
||||
devices := []authtypes.Device{}
|
||||
) ([]api.Device, error) {
|
||||
devices := []api.Device{}
|
||||
|
||||
rows, err := s.selectDevicesByLocalpartStmt.QueryContext(ctx, localpart)
|
||||
|
||||
|
@ -229,7 +229,7 @@ func (s *devicesStatements) selectDevicesByLocalpart(
|
|||
defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByLocalpart: rows.close() failed")
|
||||
|
||||
for rows.Next() {
|
||||
var dev authtypes.Device
|
||||
var dev api.Device
|
||||
var id, displayname sql.NullString
|
||||
err = rows.Scan(&id, &displayname)
|
||||
if err != nil {
|
||||
|
|
|
@ -20,8 +20,8 @@ import (
|
|||
"database/sql"
|
||||
"encoding/base64"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
|
@ -52,7 +52,7 @@ func NewDatabase(dataSourceName string, dbProperties sqlutil.DbProperties, serve
|
|||
// Returns sql.ErrNoRows if no matching device was found.
|
||||
func (d *Database) GetDeviceByAccessToken(
|
||||
ctx context.Context, token string,
|
||||
) (*authtypes.Device, error) {
|
||||
) (*api.Device, error) {
|
||||
return d.devices.selectDeviceByToken(ctx, token)
|
||||
}
|
||||
|
||||
|
@ -60,14 +60,14 @@ func (d *Database) GetDeviceByAccessToken(
|
|||
// Returns sql.ErrNoRows if no matching device was found.
|
||||
func (d *Database) GetDeviceByID(
|
||||
ctx context.Context, localpart, deviceID string,
|
||||
) (*authtypes.Device, error) {
|
||||
) (*api.Device, error) {
|
||||
return d.devices.selectDeviceByID(ctx, localpart, deviceID)
|
||||
}
|
||||
|
||||
// GetDevicesByLocalpart returns the devices matching the given localpart.
|
||||
func (d *Database) GetDevicesByLocalpart(
|
||||
ctx context.Context, localpart string,
|
||||
) ([]authtypes.Device, error) {
|
||||
) ([]api.Device, error) {
|
||||
return d.devices.selectDevicesByLocalpart(ctx, localpart)
|
||||
}
|
||||
|
||||
|
@ -80,7 +80,7 @@ func (d *Database) GetDevicesByLocalpart(
|
|||
func (d *Database) CreateDevice(
|
||||
ctx context.Context, localpart string, deviceID *string, accessToken string,
|
||||
displayName *string,
|
||||
) (dev *authtypes.Device, returnErr error) {
|
||||
) (dev *api.Device, returnErr error) {
|
||||
if deviceID != nil {
|
||||
returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||
var err error
|
||||
|
|
|
@ -21,8 +21,8 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
@ -125,7 +125,7 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN
|
|||
func (s *devicesStatements) insertDevice(
|
||||
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
|
||||
displayName *string,
|
||||
) (*authtypes.Device, error) {
|
||||
) (*api.Device, error) {
|
||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||
var sessionID int64
|
||||
countStmt := sqlutil.TxStmt(txn, s.selectDevicesCountStmt)
|
||||
|
@ -137,7 +137,7 @@ func (s *devicesStatements) insertDevice(
|
|||
if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &authtypes.Device{
|
||||
return &api.Device{
|
||||
ID: id,
|
||||
UserID: userutil.MakeUserID(localpart, s.serverName),
|
||||
AccessToken: accessToken,
|
||||
|
@ -190,8 +190,8 @@ func (s *devicesStatements) updateDeviceName(
|
|||
|
||||
func (s *devicesStatements) selectDeviceByToken(
|
||||
ctx context.Context, accessToken string,
|
||||
) (*authtypes.Device, error) {
|
||||
var dev authtypes.Device
|
||||
) (*api.Device, error) {
|
||||
var dev api.Device
|
||||
var localpart string
|
||||
stmt := s.selectDeviceByTokenStmt
|
||||
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart)
|
||||
|
@ -206,8 +206,8 @@ func (s *devicesStatements) selectDeviceByToken(
|
|||
// localpart and deviceID
|
||||
func (s *devicesStatements) selectDeviceByID(
|
||||
ctx context.Context, localpart, deviceID string,
|
||||
) (*authtypes.Device, error) {
|
||||
var dev authtypes.Device
|
||||
) (*api.Device, error) {
|
||||
var dev api.Device
|
||||
stmt := s.selectDeviceByIDStmt
|
||||
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&dev.DisplayName)
|
||||
if err == nil {
|
||||
|
@ -219,8 +219,8 @@ func (s *devicesStatements) selectDeviceByID(
|
|||
|
||||
func (s *devicesStatements) selectDevicesByLocalpart(
|
||||
ctx context.Context, localpart string,
|
||||
) ([]authtypes.Device, error) {
|
||||
devices := []authtypes.Device{}
|
||||
) ([]api.Device, error) {
|
||||
devices := []api.Device{}
|
||||
|
||||
rows, err := s.selectDevicesByLocalpartStmt.QueryContext(ctx, localpart)
|
||||
|
||||
|
@ -229,7 +229,7 @@ func (s *devicesStatements) selectDevicesByLocalpart(
|
|||
}
|
||||
|
||||
for rows.Next() {
|
||||
var dev authtypes.Device
|
||||
var dev api.Device
|
||||
var id, displayname sql.NullString
|
||||
err = rows.Scan(&id, &displayname)
|
||||
if err != nil {
|
||||
|
|
|
@ -20,8 +20,8 @@ import (
|
|||
"database/sql"
|
||||
"encoding/base64"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
|
@ -58,7 +58,7 @@ func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName)
|
|||
// Returns sql.ErrNoRows if no matching device was found.
|
||||
func (d *Database) GetDeviceByAccessToken(
|
||||
ctx context.Context, token string,
|
||||
) (*authtypes.Device, error) {
|
||||
) (*api.Device, error) {
|
||||
return d.devices.selectDeviceByToken(ctx, token)
|
||||
}
|
||||
|
||||
|
@ -66,14 +66,14 @@ func (d *Database) GetDeviceByAccessToken(
|
|||
// Returns sql.ErrNoRows if no matching device was found.
|
||||
func (d *Database) GetDeviceByID(
|
||||
ctx context.Context, localpart, deviceID string,
|
||||
) (*authtypes.Device, error) {
|
||||
) (*api.Device, error) {
|
||||
return d.devices.selectDeviceByID(ctx, localpart, deviceID)
|
||||
}
|
||||
|
||||
// GetDevicesByLocalpart returns the devices matching the given localpart.
|
||||
func (d *Database) GetDevicesByLocalpart(
|
||||
ctx context.Context, localpart string,
|
||||
) ([]authtypes.Device, error) {
|
||||
) ([]api.Device, error) {
|
||||
return d.devices.selectDevicesByLocalpart(ctx, localpart)
|
||||
}
|
||||
|
||||
|
@ -86,7 +86,7 @@ func (d *Database) GetDevicesByLocalpart(
|
|||
func (d *Database) CreateDevice(
|
||||
ctx context.Context, localpart string, deviceID *string, accessToken string,
|
||||
displayName *string,
|
||||
) (dev *authtypes.Device, returnErr error) {
|
||||
) (dev *api.Device, returnErr error) {
|
||||
if deviceID != nil {
|
||||
returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
||||
var err error
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue