- Make sure we always cleanup the temp directory on error.
- Complain about it having an error prone API shape.
This commit is contained in:
Kegsay 2020-08-26 15:38:34 +01:00 committed by GitHub
parent 29d6481842
commit 3802efe301
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 36 additions and 16 deletions

View file

@ -16,6 +16,7 @@ package fileutils
import (
"bufio"
"context"
"crypto/sha256"
"encoding/base64"
"fmt"
@ -27,6 +28,7 @@ import (
"github.com/matrix-org/dendrite/internal/config"
"github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/util"
log "github.com/sirupsen/logrus"
)
@ -104,15 +106,23 @@ func RemoveDir(dir types.Path, logger *log.Entry) {
}
}
// WriteTempFile writes to a new temporary file
func WriteTempFile(reqReader io.Reader, maxFileSizeBytes config.FileSizeBytes, absBasePath config.Path) (hash types.Base64Hash, size types.FileSizeBytes, path types.Path, err error) {
// WriteTempFile writes to a new temporary file.
// The file is deleted if there was an error while writing.
func WriteTempFile(
ctx context.Context, reqReader io.Reader, maxFileSizeBytes config.FileSizeBytes, absBasePath config.Path,
) (hash types.Base64Hash, size types.FileSizeBytes, path types.Path, err error) {
size = -1
logger := util.GetLogger(ctx)
tmpFileWriter, tmpFile, tmpDir, err := createTempFileWriter(absBasePath)
if err != nil {
return
}
defer (func() { err = tmpFile.Close() })()
defer func() {
err2 := tmpFile.Close()
if err == nil {
err = err2
}
}()
// The amount of data read is limited to maxFileSizeBytes. At this point, if there is more data it will be truncated.
limitedReader := io.LimitReader(reqReader, int64(maxFileSizeBytes))
@ -123,11 +133,13 @@ func WriteTempFile(reqReader io.Reader, maxFileSizeBytes config.FileSizeBytes, a
teeReader := io.TeeReader(limitedReader, hasher)
bytesWritten, err := io.Copy(tmpFileWriter, teeReader)
if err != nil && err != io.EOF {
RemoveDir(tmpDir, logger)
return
}
err = tmpFileWriter.Flush()
if err != nil {
RemoveDir(tmpDir, logger)
return
}