Make userapi responsible for checking access tokens (#1133)

* Make userapi responsible for checking access tokens

There's still plenty of dependencies on account/device DBs, but this
is a start. This is a breaking change as it adds a required config
value `listen.user_api`.

* Cleanup

* Review comments and test fix
This commit is contained in:
Kegsay 2020-06-16 14:10:55 +01:00 committed by GitHub
parent 57b7fa3db8
commit 9c77022513
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
66 changed files with 421 additions and 400 deletions

View file

@ -17,14 +17,14 @@ package devices
import (
"context"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/userapi/api"
)
type Database interface {
GetDeviceByAccessToken(ctx context.Context, token string) (*authtypes.Device, error)
GetDeviceByID(ctx context.Context, localpart, deviceID string) (*authtypes.Device, error)
GetDevicesByLocalpart(ctx context.Context, localpart string) ([]authtypes.Device, error)
CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string) (dev *authtypes.Device, returnErr error)
GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error)
GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error)
GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error)
CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string) (dev *api.Device, returnErr error)
UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error
RemoveDevice(ctx context.Context, deviceID, localpart string) error
RemoveDevices(ctx context.Context, localpart string, devices []string) error

View file

@ -20,10 +20,10 @@ import (
"time"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
)
@ -135,14 +135,14 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN
func (s *devicesStatements) insertDevice(
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
displayName *string,
) (*authtypes.Device, error) {
) (*api.Device, error) {
createdTimeMS := time.Now().UnixNano() / 1000000
var sessionID int64
stmt := sqlutil.TxStmt(txn, s.insertDeviceStmt)
if err := stmt.QueryRowContext(ctx, id, localpart, accessToken, createdTimeMS, displayName).Scan(&sessionID); err != nil {
return nil, err
}
return &authtypes.Device{
return &api.Device{
ID: id,
UserID: userutil.MakeUserID(localpart, s.serverName),
AccessToken: accessToken,
@ -189,8 +189,8 @@ func (s *devicesStatements) updateDeviceName(
func (s *devicesStatements) selectDeviceByToken(
ctx context.Context, accessToken string,
) (*authtypes.Device, error) {
var dev authtypes.Device
) (*api.Device, error) {
var dev api.Device
var localpart string
stmt := s.selectDeviceByTokenStmt
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart)
@ -205,8 +205,8 @@ func (s *devicesStatements) selectDeviceByToken(
// localpart and deviceID
func (s *devicesStatements) selectDeviceByID(
ctx context.Context, localpart, deviceID string,
) (*authtypes.Device, error) {
var dev authtypes.Device
) (*api.Device, error) {
var dev api.Device
stmt := s.selectDeviceByIDStmt
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&dev.DisplayName)
if err == nil {
@ -218,8 +218,8 @@ func (s *devicesStatements) selectDeviceByID(
func (s *devicesStatements) selectDevicesByLocalpart(
ctx context.Context, localpart string,
) ([]authtypes.Device, error) {
devices := []authtypes.Device{}
) ([]api.Device, error) {
devices := []api.Device{}
rows, err := s.selectDevicesByLocalpartStmt.QueryContext(ctx, localpart)
@ -229,7 +229,7 @@ func (s *devicesStatements) selectDevicesByLocalpart(
defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByLocalpart: rows.close() failed")
for rows.Next() {
var dev authtypes.Device
var dev api.Device
var id, displayname sql.NullString
err = rows.Scan(&id, &displayname)
if err != nil {

View file

@ -20,8 +20,8 @@ import (
"database/sql"
"encoding/base64"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
)
@ -52,7 +52,7 @@ func NewDatabase(dataSourceName string, dbProperties sqlutil.DbProperties, serve
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByAccessToken(
ctx context.Context, token string,
) (*authtypes.Device, error) {
) (*api.Device, error) {
return d.devices.selectDeviceByToken(ctx, token)
}
@ -60,14 +60,14 @@ func (d *Database) GetDeviceByAccessToken(
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByID(
ctx context.Context, localpart, deviceID string,
) (*authtypes.Device, error) {
) (*api.Device, error) {
return d.devices.selectDeviceByID(ctx, localpart, deviceID)
}
// GetDevicesByLocalpart returns the devices matching the given localpart.
func (d *Database) GetDevicesByLocalpart(
ctx context.Context, localpart string,
) ([]authtypes.Device, error) {
) ([]api.Device, error) {
return d.devices.selectDevicesByLocalpart(ctx, localpart)
}
@ -80,7 +80,7 @@ func (d *Database) GetDevicesByLocalpart(
func (d *Database) CreateDevice(
ctx context.Context, localpart string, deviceID *string, accessToken string,
displayName *string,
) (dev *authtypes.Device, returnErr error) {
) (dev *api.Device, returnErr error) {
if deviceID != nil {
returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error

View file

@ -21,8 +21,8 @@ import (
"time"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/gomatrixserverlib"
)
@ -125,7 +125,7 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN
func (s *devicesStatements) insertDevice(
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
displayName *string,
) (*authtypes.Device, error) {
) (*api.Device, error) {
createdTimeMS := time.Now().UnixNano() / 1000000
var sessionID int64
countStmt := sqlutil.TxStmt(txn, s.selectDevicesCountStmt)
@ -137,7 +137,7 @@ func (s *devicesStatements) insertDevice(
if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID); err != nil {
return nil, err
}
return &authtypes.Device{
return &api.Device{
ID: id,
UserID: userutil.MakeUserID(localpart, s.serverName),
AccessToken: accessToken,
@ -190,8 +190,8 @@ func (s *devicesStatements) updateDeviceName(
func (s *devicesStatements) selectDeviceByToken(
ctx context.Context, accessToken string,
) (*authtypes.Device, error) {
var dev authtypes.Device
) (*api.Device, error) {
var dev api.Device
var localpart string
stmt := s.selectDeviceByTokenStmt
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart)
@ -206,8 +206,8 @@ func (s *devicesStatements) selectDeviceByToken(
// localpart and deviceID
func (s *devicesStatements) selectDeviceByID(
ctx context.Context, localpart, deviceID string,
) (*authtypes.Device, error) {
var dev authtypes.Device
) (*api.Device, error) {
var dev api.Device
stmt := s.selectDeviceByIDStmt
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&dev.DisplayName)
if err == nil {
@ -219,8 +219,8 @@ func (s *devicesStatements) selectDeviceByID(
func (s *devicesStatements) selectDevicesByLocalpart(
ctx context.Context, localpart string,
) ([]authtypes.Device, error) {
devices := []authtypes.Device{}
) ([]api.Device, error) {
devices := []api.Device{}
rows, err := s.selectDevicesByLocalpartStmt.QueryContext(ctx, localpart)
@ -229,7 +229,7 @@ func (s *devicesStatements) selectDevicesByLocalpart(
}
for rows.Next() {
var dev authtypes.Device
var dev api.Device
var id, displayname sql.NullString
err = rows.Scan(&id, &displayname)
if err != nil {

View file

@ -20,8 +20,8 @@ import (
"database/sql"
"encoding/base64"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
_ "github.com/mattn/go-sqlite3"
@ -58,7 +58,7 @@ func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName)
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByAccessToken(
ctx context.Context, token string,
) (*authtypes.Device, error) {
) (*api.Device, error) {
return d.devices.selectDeviceByToken(ctx, token)
}
@ -66,14 +66,14 @@ func (d *Database) GetDeviceByAccessToken(
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByID(
ctx context.Context, localpart, deviceID string,
) (*authtypes.Device, error) {
) (*api.Device, error) {
return d.devices.selectDeviceByID(ctx, localpart, deviceID)
}
// GetDevicesByLocalpart returns the devices matching the given localpart.
func (d *Database) GetDevicesByLocalpart(
ctx context.Context, localpart string,
) ([]authtypes.Device, error) {
) ([]api.Device, error) {
return d.devices.selectDevicesByLocalpart(ctx, localpart)
}
@ -86,7 +86,7 @@ func (d *Database) GetDevicesByLocalpart(
func (d *Database) CreateDevice(
ctx context.Context, localpart string, deviceID *string, accessToken string,
displayName *string,
) (dev *authtypes.Device, returnErr error) {
) (dev *api.Device, returnErr error) {
if deviceID != nil {
returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error