Handle inbound federation E2E key queries/claims (#1215)

* Handle inbound /keys/claim and /keys/query requests

* Add display names to device key responses

* Linting
This commit is contained in:
Kegsay 2020-07-22 17:04:57 +01:00 committed by GitHub
parent 1e71fd645e
commit 541a23f712
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
25 changed files with 321 additions and 35 deletions

View file

@ -24,6 +24,7 @@ type Database interface {
GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error)
GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error)
GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error)
GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error)
// CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked
// and replaced with the given accessToken. If the given accessToken is already in use for another device,

View file

@ -84,11 +84,15 @@ const deleteDevicesByLocalpartSQL = "" +
const deleteDevicesSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)"
const selectDevicesByIDSQL = "" +
"SELECT device_id, localpart, display_name FROM device_devices WHERE device_id = ANY($1)"
type devicesStatements struct {
insertDeviceStmt *sql.Stmt
selectDeviceByTokenStmt *sql.Stmt
selectDeviceByIDStmt *sql.Stmt
selectDevicesByLocalpartStmt *sql.Stmt
selectDevicesByIDStmt *sql.Stmt
updateDeviceNameStmt *sql.Stmt
deleteDeviceStmt *sql.Stmt
deleteDevicesByLocalpartStmt *sql.Stmt
@ -125,6 +129,9 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN
if s.deleteDevicesStmt, err = db.Prepare(deleteDevicesSQL); err != nil {
return
}
if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil {
return
}
s.serverName = server
return
}
@ -207,15 +214,42 @@ func (s *devicesStatements) selectDeviceByID(
ctx context.Context, localpart, deviceID string,
) (*api.Device, error) {
var dev api.Device
var displayName sql.NullString
stmt := s.selectDeviceByIDStmt
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&dev.DisplayName)
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName)
if err == nil {
dev.ID = deviceID
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
if displayName.Valid {
dev.DisplayName = displayName.String
}
}
return &dev, err
}
func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
rows, err := s.selectDevicesByIDStmt.QueryContext(ctx, pq.StringArray(deviceIDs))
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByID: rows.close() failed")
var devices []api.Device
for rows.Next() {
var dev api.Device
var localpart string
var displayName sql.NullString
if err := rows.Scan(&dev.ID, &localpart, &displayName); err != nil {
return nil, err
}
if displayName.Valid {
dev.DisplayName = displayName.String
}
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
devices = append(devices, dev)
}
return devices, rows.Err()
}
func (s *devicesStatements) selectDevicesByLocalpart(
ctx context.Context, localpart string,
) ([]api.Device, error) {

View file

@ -71,6 +71,10 @@ func (d *Database) GetDevicesByLocalpart(
return d.devices.selectDevicesByLocalpart(ctx, localpart)
}
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
return d.devices.selectDevicesByID(ctx, deviceIDs)
}
// CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked
// and replaced with the given accessToken. If the given accessToken is already in use for another device,

View file

@ -20,6 +20,7 @@ import (
"strings"
"time"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
@ -72,6 +73,9 @@ const deleteDevicesByLocalpartSQL = "" +
const deleteDevicesSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)"
const selectDevicesByIDSQL = "" +
"SELECT device_id, localpart, display_name FROM device_devices WHERE device_id IN ($1)"
type devicesStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
@ -79,6 +83,7 @@ type devicesStatements struct {
selectDevicesCountStmt *sql.Stmt
selectDeviceByTokenStmt *sql.Stmt
selectDeviceByIDStmt *sql.Stmt
selectDevicesByIDStmt *sql.Stmt
selectDevicesByLocalpartStmt *sql.Stmt
updateDeviceNameStmt *sql.Stmt
deleteDeviceStmt *sql.Stmt
@ -117,6 +122,9 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN
if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil {
return
}
if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil {
return
}
s.serverName = server
return
}
@ -224,11 +232,15 @@ func (s *devicesStatements) selectDeviceByID(
ctx context.Context, localpart, deviceID string,
) (*api.Device, error) {
var dev api.Device
var displayName sql.NullString
stmt := s.selectDeviceByIDStmt
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&dev.DisplayName)
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName)
if err == nil {
dev.ID = deviceID
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
if displayName.Valid {
dev.DisplayName = displayName.String
}
}
return &dev, err
}
@ -263,3 +275,32 @@ func (s *devicesStatements) selectDevicesByLocalpart(
return devices, nil
}
func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
sqlQuery := strings.Replace(selectDevicesByIDSQL, "($1)", sqlutil.QueryVariadic(len(deviceIDs)), 1)
iDeviceIDs := make([]interface{}, len(deviceIDs))
for i := range deviceIDs {
iDeviceIDs[i] = deviceIDs[i]
}
rows, err := s.db.QueryContext(ctx, sqlQuery, iDeviceIDs...)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByID: rows.close() failed")
var devices []api.Device
for rows.Next() {
var dev api.Device
var localpart string
var displayName sql.NullString
if err := rows.Scan(&dev.ID, &localpart, &displayName); err != nil {
return nil, err
}
if displayName.Valid {
dev.DisplayName = displayName.String
}
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
devices = append(devices, dev)
}
return devices, rows.Err()
}

View file

@ -77,6 +77,10 @@ func (d *Database) GetDevicesByLocalpart(
return d.devices.selectDevicesByLocalpart(ctx, localpart)
}
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
return d.devices.selectDevicesByID(ctx, deviceIDs)
}
// CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked
// and replaced with the given accessToken. If the given accessToken is already in use for another device,