Pass a context when downloading remote media (#251)

This commit is contained in:
Mark Haines 2017-09-21 16:20:10 +01:00 committed by GitHub
parent fef290c47e
commit ce019738ff
6 changed files with 65 additions and 20 deletions

View file

@ -24,6 +24,7 @@ import (
"github.com/matrix-org/dendrite/common/config" "github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/dendrite/mediaapi/routing" "github.com/matrix-org/dendrite/mediaapi/routing"
"github.com/matrix-org/dendrite/mediaapi/storage" "github.com/matrix-org/dendrite/mediaapi/storage"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/Sirupsen/logrus" log "github.com/Sirupsen/logrus"
) )
@ -51,10 +52,12 @@ func main() {
log.WithError(err).Panic("Failed to open database") log.WithError(err).Panic("Failed to open database")
} }
client := gomatrixserverlib.NewClient()
log.Info("Starting media API server on ", cfg.Listen.MediaAPI) log.Info("Starting media API server on ", cfg.Listen.MediaAPI)
api := mux.NewRouter() api := mux.NewRouter()
routing.Setup(api, cfg, db) routing.Setup(api, cfg, db, client)
common.SetupHTTPAPI(http.DefaultServeMux, api) common.SetupHTTPAPI(http.DefaultServeMux, api)
log.Fatal(http.ListenAndServe(string(cfg.Listen.MediaAPI), nil)) log.Fatal(http.ListenAndServe(string(cfg.Listen.MediaAPI), nil))

View file

@ -325,7 +325,7 @@ func (m *monolith) setupAPIs() {
) )
mediaapi_routing.Setup( mediaapi_routing.Setup(
m.api, m.cfg, m.mediaAPIDB, m.api, m.cfg, m.mediaAPIDB, &m.federation.Client,
) )
syncapi_routing.Setup(m.api, syncapi_sync.NewRequestPool( syncapi_routing.Setup(m.api, syncapi_sync.NewRequestPool(

View file

@ -31,7 +31,12 @@ import (
const pathPrefixR0 = "/_matrix/media/v1" const pathPrefixR0 = "/_matrix/media/v1"
// Setup registers the media API HTTP handlers // Setup registers the media API HTTP handlers
func Setup(apiMux *mux.Router, cfg *config.Dendrite, db *storage.Database) { func Setup(
apiMux *mux.Router,
cfg *config.Dendrite,
db *storage.Database,
client *gomatrixserverlib.Client,
) {
r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter() r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter()
activeThumbnailGeneration := &types.ActiveThumbnailGeneration{ activeThumbnailGeneration := &types.ActiveThumbnailGeneration{
@ -47,14 +52,21 @@ func Setup(apiMux *mux.Router, cfg *config.Dendrite, db *storage.Database) {
MXCToResult: map[string]*types.RemoteRequestResult{}, MXCToResult: map[string]*types.RemoteRequestResult{},
} }
r0mux.Handle("/download/{serverName}/{mediaId}", r0mux.Handle("/download/{serverName}/{mediaId}",
makeDownloadAPI("download", cfg, db, activeRemoteRequests, activeThumbnailGeneration), makeDownloadAPI("download", cfg, db, client, activeRemoteRequests, activeThumbnailGeneration),
).Methods("GET") ).Methods("GET")
r0mux.Handle("/thumbnail/{serverName}/{mediaId}", r0mux.Handle("/thumbnail/{serverName}/{mediaId}",
makeDownloadAPI("thumbnail", cfg, db, activeRemoteRequests, activeThumbnailGeneration), makeDownloadAPI("thumbnail", cfg, db, client, activeRemoteRequests, activeThumbnailGeneration),
).Methods("GET") ).Methods("GET")
} }
func makeDownloadAPI(name string, cfg *config.Dendrite, db *storage.Database, activeRemoteRequests *types.ActiveRemoteRequests, activeThumbnailGeneration *types.ActiveThumbnailGeneration) http.HandlerFunc { func makeDownloadAPI(
name string,
cfg *config.Dendrite,
db *storage.Database,
client *gomatrixserverlib.Client,
activeRemoteRequests *types.ActiveRemoteRequests,
activeThumbnailGeneration *types.ActiveThumbnailGeneration,
) http.HandlerFunc {
return prometheus.InstrumentHandler(name, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { return prometheus.InstrumentHandler(name, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
req = util.RequestWithLogging(req) req = util.RequestWithLogging(req)
@ -64,6 +76,17 @@ func makeDownloadAPI(name string, cfg *config.Dendrite, db *storage.Database, ac
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
vars := mux.Vars(req) vars := mux.Vars(req)
writers.Download(w, req, gomatrixserverlib.ServerName(vars["serverName"]), types.MediaID(vars["mediaId"]), cfg, db, activeRemoteRequests, activeThumbnailGeneration, name == "thumbnail") writers.Download(
w,
req,
gomatrixserverlib.ServerName(vars["serverName"]),
types.MediaID(vars["mediaId"]),
cfg,
db,
client,
activeRemoteRequests,
activeThumbnailGeneration,
name == "thumbnail",
)
})) }))
} }

View file

@ -68,6 +68,7 @@ func Download(
mediaID types.MediaID, mediaID types.MediaID,
cfg *config.Dendrite, cfg *config.Dendrite,
db *storage.Database, db *storage.Database,
client *gomatrixserverlib.Client,
activeRemoteRequests *types.ActiveRemoteRequests, activeRemoteRequests *types.ActiveRemoteRequests,
activeThumbnailGeneration *types.ActiveThumbnailGeneration, activeThumbnailGeneration *types.ActiveThumbnailGeneration,
isThumbnailRequest bool, isThumbnailRequest bool,
@ -120,7 +121,8 @@ func Download(
} }
metadata, err := dReq.doDownload( metadata, err := dReq.doDownload(
req.Context(), w, cfg, db, activeRemoteRequests, activeThumbnailGeneration, req.Context(), w, cfg, db, client,
activeRemoteRequests, activeThumbnailGeneration,
) )
if err != nil { if err != nil {
// TODO: Handle the fact we might have started writing the response // TODO: Handle the fact we might have started writing the response
@ -199,6 +201,7 @@ func (r *downloadRequest) doDownload(
w http.ResponseWriter, w http.ResponseWriter,
cfg *config.Dendrite, cfg *config.Dendrite,
db *storage.Database, db *storage.Database,
client *gomatrixserverlib.Client,
activeRemoteRequests *types.ActiveRemoteRequests, activeRemoteRequests *types.ActiveRemoteRequests,
activeThumbnailGeneration *types.ActiveThumbnailGeneration, activeThumbnailGeneration *types.ActiveThumbnailGeneration,
) (*types.MediaMetadata, error) { ) (*types.MediaMetadata, error) {
@ -216,7 +219,7 @@ func (r *downloadRequest) doDownload(
} }
// If we do not have a record and the origin is remote, we need to fetch it and respond with that file // If we do not have a record and the origin is remote, we need to fetch it and respond with that file
resErr := r.getRemoteFile( resErr := r.getRemoteFile(
ctx, cfg, db, activeRemoteRequests, activeThumbnailGeneration, ctx, client, cfg, db, activeRemoteRequests, activeThumbnailGeneration,
) )
if resErr != nil { if resErr != nil {
return nil, resErr return nil, resErr
@ -442,6 +445,7 @@ func (r *downloadRequest) generateThumbnail(
// Note: The named errorResponse return variable is used in a deferred broadcast of the metadata and error response to waiting goroutines. // Note: The named errorResponse return variable is used in a deferred broadcast of the metadata and error response to waiting goroutines.
func (r *downloadRequest) getRemoteFile( func (r *downloadRequest) getRemoteFile(
ctx context.Context, ctx context.Context,
client *gomatrixserverlib.Client,
cfg *config.Dendrite, cfg *config.Dendrite,
db *storage.Database, db *storage.Database,
activeRemoteRequests *types.ActiveRemoteRequests, activeRemoteRequests *types.ActiveRemoteRequests,
@ -477,7 +481,8 @@ func (r *downloadRequest) getRemoteFile(
if mediaMetadata == nil { if mediaMetadata == nil {
// If we do not have a record, we need to fetch the remote file first and then respond from the local file // If we do not have a record, we need to fetch the remote file first and then respond from the local file
err := r.fetchRemoteFileAndStoreMetadata( err := r.fetchRemoteFileAndStoreMetadata(
ctx, cfg.Media.AbsBasePath, *cfg.Media.MaxFileSizeBytes, db, ctx, client,
cfg.Media.AbsBasePath, *cfg.Media.MaxFileSizeBytes, db,
cfg.Media.ThumbnailSizes, activeThumbnailGeneration, cfg.Media.ThumbnailSizes, activeThumbnailGeneration,
cfg.Media.MaxThumbnailGenerators, cfg.Media.MaxThumbnailGenerators,
) )
@ -541,6 +546,7 @@ func (r *downloadRequest) broadcastMediaMetadata(activeRemoteRequests *types.Act
// fetchRemoteFileAndStoreMetadata fetches the file from the remote server and stores its metadata in the database // fetchRemoteFileAndStoreMetadata fetches the file from the remote server and stores its metadata in the database
func (r *downloadRequest) fetchRemoteFileAndStoreMetadata( func (r *downloadRequest) fetchRemoteFileAndStoreMetadata(
ctx context.Context, ctx context.Context,
client *gomatrixserverlib.Client,
absBasePath config.Path, absBasePath config.Path,
maxFileSizeBytes config.FileSizeBytes, maxFileSizeBytes config.FileSizeBytes,
db *storage.Database, db *storage.Database,
@ -548,7 +554,9 @@ func (r *downloadRequest) fetchRemoteFileAndStoreMetadata(
activeThumbnailGeneration *types.ActiveThumbnailGeneration, activeThumbnailGeneration *types.ActiveThumbnailGeneration,
maxThumbnailGenerators int, maxThumbnailGenerators int,
) error { ) error {
finalPath, duplicate, err := r.fetchRemoteFile(absBasePath, maxFileSizeBytes) finalPath, duplicate, err := r.fetchRemoteFile(
ctx, client, absBasePath, maxFileSizeBytes,
)
if err != nil { if err != nil {
return err return err
} }
@ -597,11 +605,16 @@ func (r *downloadRequest) fetchRemoteFileAndStoreMetadata(
return nil return nil
} }
func (r *downloadRequest) fetchRemoteFile(absBasePath config.Path, maxFileSizeBytes config.FileSizeBytes) (types.Path, bool, error) { func (r *downloadRequest) fetchRemoteFile(
ctx context.Context,
client *gomatrixserverlib.Client,
absBasePath config.Path,
maxFileSizeBytes config.FileSizeBytes,
) (types.Path, bool, error) {
r.Logger.Info("Fetching remote file") r.Logger.Info("Fetching remote file")
// create request for remote file // create request for remote file
resp, err := r.createRemoteRequest() resp, err := r.createRemoteRequest(ctx, client)
if err != nil { if err != nil {
return "", false, err return "", false, err
} }
@ -664,10 +677,10 @@ func (r *downloadRequest) fetchRemoteFile(absBasePath config.Path, maxFileSizeBy
return types.Path(finalPath), duplicate, nil return types.Path(finalPath), duplicate, nil
} }
func (r *downloadRequest) createRemoteRequest() (*http.Response, error) { func (r *downloadRequest) createRemoteRequest(
matrixClient := gomatrixserverlib.NewClient() ctx context.Context, matrixClient *gomatrixserverlib.Client,
) (*http.Response, error) {
resp, err := matrixClient.CreateMediaDownloadRequest(r.MediaMetadata.Origin, string(r.MediaMetadata.MediaID)) resp, err := matrixClient.CreateMediaDownloadRequest(ctx, r.MediaMetadata.Origin, string(r.MediaMetadata.MediaID))
if err != nil { if err != nil {
return nil, fmt.Errorf("file with media ID %q could not be downloaded from %q", r.MediaMetadata.MediaID, r.MediaMetadata.Origin) return nil, fmt.Errorf("file with media ID %q could not be downloaded from %q", r.MediaMetadata.MediaID, r.MediaMetadata.Origin)
} }

2
vendor/manifest vendored
View file

@ -116,7 +116,7 @@
{ {
"importpath": "github.com/matrix-org/gomatrixserverlib", "importpath": "github.com/matrix-org/gomatrixserverlib",
"repository": "https://github.com/matrix-org/gomatrixserverlib", "repository": "https://github.com/matrix-org/gomatrixserverlib",
"revision": "ec5a0d21b03ed4d3bd955ecc9f7a69936f64391e", "revision": "40b35e1c997fc7e35342aeb39187ff6bf3e10b2e",
"branch": "master" "branch": "master"
}, },
{ {

View file

@ -236,9 +236,15 @@ func (fc *Client) LookupServerKeys( // nolint: gocyclo
} }
// CreateMediaDownloadRequest creates a request for media on a homeserver and returns the http.Response or an error // CreateMediaDownloadRequest creates a request for media on a homeserver and returns the http.Response or an error
func (fc *Client) CreateMediaDownloadRequest(matrixServer ServerName, mediaID string) (*http.Response, error) { func (fc *Client) CreateMediaDownloadRequest(
ctx context.Context, matrixServer ServerName, mediaID string,
) (*http.Response, error) {
requestURL := "matrix://" + string(matrixServer) + "/_matrix/media/v1/download/" + string(matrixServer) + "/" + mediaID requestURL := "matrix://" + string(matrixServer) + "/_matrix/media/v1/download/" + string(matrixServer) + "/" + mediaID
resp, err := fc.client.Get(requestURL) req, err := http.NewRequest("GET", requestURL, nil)
if err != nil {
return nil, err
}
resp, err := fc.client.Do(req.WithContext(ctx))
if err != nil { if err != nil {
return nil, err return nil, err
} }