mirror of
https://github.com/hoernschen/dendrite.git
synced 2025-07-30 04:52:46 +00:00
Handle inbound federation E2E key queries/claims (#1215)
* Handle inbound /keys/claim and /keys/query requests * Add display names to device key responses * Linting
This commit is contained in:
parent
1e71fd645e
commit
541a23f712
25 changed files with 321 additions and 35 deletions
|
@ -30,6 +30,7 @@ type UserInternalAPI interface {
|
|||
QueryAccessToken(ctx context.Context, req *QueryAccessTokenRequest, res *QueryAccessTokenResponse) error
|
||||
QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
|
||||
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
|
||||
QueryDeviceInfos(ctx context.Context, req *QueryDeviceInfosRequest, res *QueryDeviceInfosResponse) error
|
||||
}
|
||||
|
||||
// InputAccountDataRequest is the request for InputAccountData
|
||||
|
@ -44,6 +45,19 @@ type InputAccountDataRequest struct {
|
|||
type InputAccountDataResponse struct {
|
||||
}
|
||||
|
||||
// QueryDeviceInfosRequest is the request to QueryDeviceInfos
|
||||
type QueryDeviceInfosRequest struct {
|
||||
DeviceIDs []string
|
||||
}
|
||||
|
||||
// QueryDeviceInfosResponse is the response to QueryDeviceInfos
|
||||
type QueryDeviceInfosResponse struct {
|
||||
DeviceInfo map[string]struct {
|
||||
DisplayName string
|
||||
UserID string
|
||||
}
|
||||
}
|
||||
|
||||
// QueryAccessTokenRequest is the request for QueryAccessToken
|
||||
type QueryAccessTokenRequest struct {
|
||||
AccessToken string
|
||||
|
|
|
@ -125,6 +125,27 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil
|
|||
return nil
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) QueryDeviceInfos(ctx context.Context, req *api.QueryDeviceInfosRequest, res *api.QueryDeviceInfosResponse) error {
|
||||
devices, err := a.DeviceDB.GetDevicesByID(ctx, req.DeviceIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
res.DeviceInfo = make(map[string]struct {
|
||||
DisplayName string
|
||||
UserID string
|
||||
})
|
||||
for _, d := range devices {
|
||||
res.DeviceInfo[d.ID] = struct {
|
||||
DisplayName string
|
||||
UserID string
|
||||
}{
|
||||
DisplayName: d.DisplayName,
|
||||
UserID: d.UserID,
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevicesRequest, res *api.QueryDevicesResponse) error {
|
||||
local, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
|
||||
if err != nil {
|
||||
|
|
|
@ -35,6 +35,7 @@ const (
|
|||
QueryAccessTokenPath = "/userapi/queryAccessToken"
|
||||
QueryDevicesPath = "/userapi/queryDevices"
|
||||
QueryAccountDataPath = "/userapi/queryAccountData"
|
||||
QueryDeviceInfosPath = "/userapi/queryDeviceInfos"
|
||||
)
|
||||
|
||||
// NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API.
|
||||
|
@ -101,6 +102,18 @@ func (h *httpUserInternalAPI) QueryProfile(
|
|||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
|
||||
}
|
||||
|
||||
func (h *httpUserInternalAPI) QueryDeviceInfos(
|
||||
ctx context.Context,
|
||||
request *api.QueryDeviceInfosRequest,
|
||||
response *api.QueryDeviceInfosResponse,
|
||||
) error {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryDeviceInfos")
|
||||
defer span.Finish()
|
||||
|
||||
apiURL := h.apiURL + QueryDeviceInfosPath
|
||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
|
||||
}
|
||||
|
||||
func (h *httpUserInternalAPI) QueryAccessToken(
|
||||
ctx context.Context,
|
||||
request *api.QueryAccessTokenRequest,
|
||||
|
|
|
@ -24,6 +24,7 @@ import (
|
|||
"github.com/matrix-org/util"
|
||||
)
|
||||
|
||||
// nolint: gocyclo
|
||||
func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
|
||||
internalAPIMux.Handle(PerformAccountCreationPath,
|
||||
httputil.MakeInternalAPI("performAccountCreation", func(req *http.Request) util.JSONResponse {
|
||||
|
@ -103,4 +104,17 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
|
|||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||
}),
|
||||
)
|
||||
internalAPIMux.Handle(QueryDeviceInfosPath,
|
||||
httputil.MakeInternalAPI("queryDeviceInfos", func(req *http.Request) util.JSONResponse {
|
||||
request := api.QueryDeviceInfosRequest{}
|
||||
response := api.QueryDeviceInfosResponse{}
|
||||
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
if err := s.QueryDeviceInfos(req.Context(), &request, &response); err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
|
|
@ -24,6 +24,7 @@ type Database interface {
|
|||
GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error)
|
||||
GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error)
|
||||
GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error)
|
||||
GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error)
|
||||
// CreateDevice makes a new device associated with the given user ID localpart.
|
||||
// If there is already a device with the same device ID for this user, that access token will be revoked
|
||||
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
|
||||
|
|
|
@ -84,11 +84,15 @@ const deleteDevicesByLocalpartSQL = "" +
|
|||
const deleteDevicesSQL = "" +
|
||||
"DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)"
|
||||
|
||||
const selectDevicesByIDSQL = "" +
|
||||
"SELECT device_id, localpart, display_name FROM device_devices WHERE device_id = ANY($1)"
|
||||
|
||||
type devicesStatements struct {
|
||||
insertDeviceStmt *sql.Stmt
|
||||
selectDeviceByTokenStmt *sql.Stmt
|
||||
selectDeviceByIDStmt *sql.Stmt
|
||||
selectDevicesByLocalpartStmt *sql.Stmt
|
||||
selectDevicesByIDStmt *sql.Stmt
|
||||
updateDeviceNameStmt *sql.Stmt
|
||||
deleteDeviceStmt *sql.Stmt
|
||||
deleteDevicesByLocalpartStmt *sql.Stmt
|
||||
|
@ -125,6 +129,9 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN
|
|||
if s.deleteDevicesStmt, err = db.Prepare(deleteDevicesSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil {
|
||||
return
|
||||
}
|
||||
s.serverName = server
|
||||
return
|
||||
}
|
||||
|
@ -207,15 +214,42 @@ func (s *devicesStatements) selectDeviceByID(
|
|||
ctx context.Context, localpart, deviceID string,
|
||||
) (*api.Device, error) {
|
||||
var dev api.Device
|
||||
var displayName sql.NullString
|
||||
stmt := s.selectDeviceByIDStmt
|
||||
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&dev.DisplayName)
|
||||
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName)
|
||||
if err == nil {
|
||||
dev.ID = deviceID
|
||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||
if displayName.Valid {
|
||||
dev.DisplayName = displayName.String
|
||||
}
|
||||
}
|
||||
return &dev, err
|
||||
}
|
||||
|
||||
func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
|
||||
rows, err := s.selectDevicesByIDStmt.QueryContext(ctx, pq.StringArray(deviceIDs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByID: rows.close() failed")
|
||||
var devices []api.Device
|
||||
for rows.Next() {
|
||||
var dev api.Device
|
||||
var localpart string
|
||||
var displayName sql.NullString
|
||||
if err := rows.Scan(&dev.ID, &localpart, &displayName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if displayName.Valid {
|
||||
dev.DisplayName = displayName.String
|
||||
}
|
||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||
devices = append(devices, dev)
|
||||
}
|
||||
return devices, rows.Err()
|
||||
}
|
||||
|
||||
func (s *devicesStatements) selectDevicesByLocalpart(
|
||||
ctx context.Context, localpart string,
|
||||
) ([]api.Device, error) {
|
||||
|
|
|
@ -71,6 +71,10 @@ func (d *Database) GetDevicesByLocalpart(
|
|||
return d.devices.selectDevicesByLocalpart(ctx, localpart)
|
||||
}
|
||||
|
||||
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
|
||||
return d.devices.selectDevicesByID(ctx, deviceIDs)
|
||||
}
|
||||
|
||||
// CreateDevice makes a new device associated with the given user ID localpart.
|
||||
// If there is already a device with the same device ID for this user, that access token will be revoked
|
||||
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
|
||||
|
|
|
@ -20,6 +20,7 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
|
||||
|
@ -72,6 +73,9 @@ const deleteDevicesByLocalpartSQL = "" +
|
|||
const deleteDevicesSQL = "" +
|
||||
"DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)"
|
||||
|
||||
const selectDevicesByIDSQL = "" +
|
||||
"SELECT device_id, localpart, display_name FROM device_devices WHERE device_id IN ($1)"
|
||||
|
||||
type devicesStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
|
@ -79,6 +83,7 @@ type devicesStatements struct {
|
|||
selectDevicesCountStmt *sql.Stmt
|
||||
selectDeviceByTokenStmt *sql.Stmt
|
||||
selectDeviceByIDStmt *sql.Stmt
|
||||
selectDevicesByIDStmt *sql.Stmt
|
||||
selectDevicesByLocalpartStmt *sql.Stmt
|
||||
updateDeviceNameStmt *sql.Stmt
|
||||
deleteDeviceStmt *sql.Stmt
|
||||
|
@ -117,6 +122,9 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN
|
|||
if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil {
|
||||
return
|
||||
}
|
||||
s.serverName = server
|
||||
return
|
||||
}
|
||||
|
@ -224,11 +232,15 @@ func (s *devicesStatements) selectDeviceByID(
|
|||
ctx context.Context, localpart, deviceID string,
|
||||
) (*api.Device, error) {
|
||||
var dev api.Device
|
||||
var displayName sql.NullString
|
||||
stmt := s.selectDeviceByIDStmt
|
||||
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&dev.DisplayName)
|
||||
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName)
|
||||
if err == nil {
|
||||
dev.ID = deviceID
|
||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||
if displayName.Valid {
|
||||
dev.DisplayName = displayName.String
|
||||
}
|
||||
}
|
||||
return &dev, err
|
||||
}
|
||||
|
@ -263,3 +275,32 @@ func (s *devicesStatements) selectDevicesByLocalpart(
|
|||
|
||||
return devices, nil
|
||||
}
|
||||
|
||||
func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
|
||||
sqlQuery := strings.Replace(selectDevicesByIDSQL, "($1)", sqlutil.QueryVariadic(len(deviceIDs)), 1)
|
||||
iDeviceIDs := make([]interface{}, len(deviceIDs))
|
||||
for i := range deviceIDs {
|
||||
iDeviceIDs[i] = deviceIDs[i]
|
||||
}
|
||||
|
||||
rows, err := s.db.QueryContext(ctx, sqlQuery, iDeviceIDs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByID: rows.close() failed")
|
||||
var devices []api.Device
|
||||
for rows.Next() {
|
||||
var dev api.Device
|
||||
var localpart string
|
||||
var displayName sql.NullString
|
||||
if err := rows.Scan(&dev.ID, &localpart, &displayName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if displayName.Valid {
|
||||
dev.DisplayName = displayName.String
|
||||
}
|
||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||
devices = append(devices, dev)
|
||||
}
|
||||
return devices, rows.Err()
|
||||
}
|
||||
|
|
|
@ -77,6 +77,10 @@ func (d *Database) GetDevicesByLocalpart(
|
|||
return d.devices.selectDevicesByLocalpart(ctx, localpart)
|
||||
}
|
||||
|
||||
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
|
||||
return d.devices.selectDevicesByID(ctx, deviceIDs)
|
||||
}
|
||||
|
||||
// CreateDevice makes a new device associated with the given user ID localpart.
|
||||
// If there is already a device with the same device ID for this user, that access token will be revoked
|
||||
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue