Merge branch 'master' into neilalexander/servername

This commit is contained in:
Neil Alexander 2021-06-14 14:12:35 +01:00 committed by GitHub
commit 043759af60
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
18 changed files with 414 additions and 96 deletions

View file

@ -1,7 +1,7 @@
# Dendrite [![Build Status](https://badge.buildkite.com/4be40938ab19f2bbc4a6c6724517353ee3ec1422e279faf374.svg?branch=master)](https://buildkite.com/matrix-dot-org/dendrite) [![Dendrite](https://img.shields.io/matrix/dendrite:matrix.org.svg?label=%23dendrite%3Amatrix.org&logo=matrix&server_fqdn=matrix.org)](https://matrix.to/#/#dendrite:matrix.org) [![Dendrite Dev](https://img.shields.io/matrix/dendrite-dev:matrix.org.svg?label=%23dendrite-dev%3Amatrix.org&logo=matrix&server_fqdn=matrix.org)](https://matrix.to/#/#dendrite-dev:matrix.org)
Dendrite is a second-generation Matrix homeserver written in Go.
It intends to provide an **efficient**, **reliable** and **scalable** alternative to Synapse:
It intends to provide an **efficient**, **reliable** and **scalable** alternative to [Synapse](https://github.com/matrix-org/synapse):
- Efficient: A small memory footprint with better baseline performance than an out-of-the-box Synapse.
- Reliable: Implements the Matrix specification as written, using the
[same test suite](https://github.com/matrix-org/sytest) as Synapse as well as

View file

@ -32,14 +32,18 @@ INSERT OR IGNORE INTO appservice_counters (name, last_id) VALUES('txn_id', 1);
`
const selectTxnIDSQL = `
SELECT last_id FROM appservice_counters WHERE name='txn_id';
UPDATE appservice_counters SET last_id=last_id+1 WHERE name='txn_id';
SELECT last_id FROM appservice_counters WHERE name='txn_id'
`
const updateTxnIDSQL = `
UPDATE appservice_counters SET last_id=last_id+1 WHERE name='txn_id'
`
type txnStatements struct {
db *sql.DB
writer sqlutil.Writer
selectTxnIDStmt *sql.Stmt
updateTxnIDStmt *sql.Stmt
}
func (s *txnStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
@ -54,6 +58,10 @@ func (s *txnStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
return
}
if s.updateTxnIDStmt, err = db.Prepare(updateTxnIDSQL); err != nil {
return
}
return
}
@ -63,6 +71,11 @@ func (s *txnStatements) selectTxnID(
) (txnID int, err error) {
err = s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
err := s.selectTxnIDStmt.QueryRowContext(ctx).Scan(&txnID)
if err != nil {
return err
}
_, err = s.updateTxnIDStmt.ExecContext(ctx)
return err
})
return

View file

@ -37,19 +37,21 @@ runtime config should come from. The mounted folder must contain:
To generate keys:
```
go run github.com/matrix-org/dendrite/cmd/generate-keys \
--private-key=matrix_key.pem \
--tls-cert=server.crt \
--tls-key=server.key
docker run --rm --entrypoint="" \
-v $(pwd):/mnt \
matrixdotorg/dendrite-monolith:latest \
/usr/bin/generate-keys \
-private-key /mnt/matrix_key.pem \
-tls-cert /mnt/server.crt \
-tls-key /mnt/server.key
```
The key files will now exist in your current working directory, and can be mounted into place.
## Starting Dendrite as a monolith deployment
Create your config based on the `dendrite.yaml` configuration file in the `docker/config`
folder in the [Dendrite repository](https://github.com/matrix-org/dendrite). Additionally,
make the following changes to the configuration:
- Enable Naffka: `use_naffka: true`
folder in the [Dendrite repository](https://github.com/matrix-org/dendrite).
Once in place, start the PostgreSQL dependency:

View file

@ -10,7 +10,6 @@ import (
"io"
"io/ioutil"
"log"
"math"
"net"
"net/http"
"os"
@ -37,15 +36,14 @@ import (
userapiAPI "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
"go.uber.org/atomic"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
pineconeMulticast "github.com/matrix-org/pinecone/multicast"
"github.com/matrix-org/pinecone/router"
pineconeRouter "github.com/matrix-org/pinecone/router"
pineconeSessions "github.com/matrix-org/pinecone/sessions"
"github.com/matrix-org/pinecone/types"
pineconeTypes "github.com/matrix-org/pinecone/types"
_ "golang.org/x/mobile/bind"
)
@ -57,19 +55,19 @@ const (
)
type DendriteMonolith struct {
logger logrus.Logger
PineconeRouter *pineconeRouter.Router
PineconeMulticast *pineconeMulticast.Multicast
PineconeQUIC *pineconeSessions.Sessions
StorageDirectory string
CacheDirectory string
staticPeerURI string
staticPeerMutex sync.RWMutex
staticPeerAttempts atomic.Uint32
listener net.Listener
httpServer *http.Server
processContext *process.ProcessContext
userAPI userapiAPI.UserInternalAPI
logger logrus.Logger
PineconeRouter *pineconeRouter.Router
PineconeMulticast *pineconeMulticast.Multicast
PineconeQUIC *pineconeSessions.Sessions
StorageDirectory string
CacheDirectory string
staticPeerURI string
staticPeerMutex sync.RWMutex
staticPeerAttempt chan struct{}
listener net.Listener
httpServer *http.Server
processContext *process.ProcessContext
userAPI userapiAPI.UserInternalAPI
}
func (m *DendriteMonolith) BaseURL() string {
@ -99,7 +97,9 @@ func (m *DendriteMonolith) SetStaticPeer(uri string) {
m.staticPeerMutex.Unlock()
m.DisconnectType(pineconeRouter.PeerTypeRemote)
if uri != "" {
go m.staticPeerConnect()
go func() {
m.staticPeerAttempt <- struct{}{}
}()
}
}
@ -195,17 +195,27 @@ func (m *DendriteMonolith) RegisterDevice(localpart, deviceID string) (string, e
}
func (m *DendriteMonolith) staticPeerConnect() {
m.staticPeerMutex.RLock()
uri := m.staticPeerURI
m.staticPeerMutex.RUnlock()
if uri == "" {
return
attempt := func() {
if m.PineconeRouter.PeerCount(router.PeerTypeRemote) == 0 {
m.staticPeerMutex.RLock()
uri := m.staticPeerURI
m.staticPeerMutex.RUnlock()
if uri == "" {
return
}
if err := conn.ConnectToPeer(m.PineconeRouter, uri); err != nil {
logrus.WithError(err).Error("Failed to connect to static peer")
}
}
}
if err := conn.ConnectToPeer(m.PineconeRouter, uri); err != nil {
exp := time.Second * time.Duration(math.Exp2(float64(m.staticPeerAttempts.Inc())))
time.AfterFunc(exp, m.staticPeerConnect)
} else {
m.staticPeerAttempts.Store(0)
for {
select {
case <-m.processContext.Context().Done():
case <-m.staticPeerAttempt:
attempt()
case <-time.After(time.Second * 5):
attempt()
}
}
}
@ -248,13 +258,6 @@ func (m *DendriteMonolith) Start() {
m.PineconeQUIC = pineconeSessions.NewSessions(logger, m.PineconeRouter)
m.PineconeMulticast = pineconeMulticast.NewMulticast(logger, m.PineconeRouter)
m.PineconeRouter.SetDisconnectedCallback(func(port pineconeTypes.SwitchPortID, public pineconeTypes.PublicKey, peertype int, err error) {
if peertype == pineconeRouter.PeerTypeRemote {
m.staticPeerAttempts.Store(0)
time.AfterFunc(time.Second, m.staticPeerConnect)
}
})
prefix := hex.EncodeToString(pk)
cfg := &config.Dendrite{}
cfg.Defaults()
@ -359,8 +362,12 @@ func (m *DendriteMonolith) Start() {
},
Handler: h2c.NewHandler(pMux, h2s),
}
m.processContext = base.ProcessContext
m.staticPeerAttempt = make(chan struct{}, 1)
go m.staticPeerConnect()
go func() {
m.logger.Info("Listening on ", cfg.Global.ServerName)
m.logger.Fatal(m.httpServer.Serve(m.PineconeQUIC))

View file

@ -69,7 +69,7 @@ func GetAccountData(
return util.JSONResponse{
Code: http.StatusNotFound,
JSON: jsonerror.Forbidden("data not found"),
JSON: jsonerror.NotFound("data not found"),
}
}

View file

@ -18,13 +18,18 @@ import (
"context"
"flag"
"fmt"
"io"
"io/ioutil"
"os"
"strings"
"syscall"
"github.com/matrix-org/dendrite/setup"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/bcrypt"
"golang.org/x/term"
)
const usage = `Usage: %s
@ -33,7 +38,15 @@ Creates a new user account on the homeserver.
Example:
./create-account --config dendrite.yaml --username alice --password foobarbaz
# provide password by parameter
%s --config dendrite.yaml -username alice -password foobarbaz
# use password from file
%s --config dendrite.yaml -username alice -passwordfile my.pass
# ask user to provide password
%s --config dendrite.yaml -username alice -ask-pass
# read password from stdin
%s --config dendrite.yaml -username alice -passwordstdin < my.pass
cat my.pass | %s --config dendrite.yaml -username alice -passwordstdin
Arguments:
@ -42,11 +55,15 @@ Arguments:
var (
username = flag.String("username", "", "The username of the account to register (specify the localpart only, e.g. 'alice' for '@alice:domain.com')")
password = flag.String("password", "", "The password to associate with the account (optional, account will be password-less if not specified)")
pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)")
pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin")
askPass = flag.Bool("ask-pass", false, "Ask for the password to use")
)
func main() {
name := os.Args[0]
flag.Usage = func() {
fmt.Fprintf(os.Stderr, usage, os.Args[0])
_, _ = fmt.Fprintf(os.Stderr, usage, name, name, name, name, name, name)
flag.PrintDefaults()
}
cfg := setup.ParseFlags(true)
@ -56,6 +73,8 @@ func main() {
os.Exit(1)
}
pass := getPassword(password, pwdFile, pwdStdin, askPass, os.Stdin)
accountDB, err := accounts.NewDatabase(&config.DatabaseOptions{
ConnectionString: cfg.UserAPI.AccountDatabase.ConnectionString,
}, cfg.Global.ServerName, bcrypt.DefaultCost, cfg.UserAPI.OpenIDTokenLifetimeMS)
@ -63,10 +82,61 @@ func main() {
logrus.Fatalln("Failed to connect to the database:", err.Error())
}
_, err = accountDB.CreateAccount(context.Background(), *username, *password, "")
_, err = accountDB.CreateAccount(context.Background(), *username, pass, "")
if err != nil {
logrus.Fatalln("Failed to create the account:", err.Error())
}
logrus.Infoln("Created account", *username)
}
func getPassword(password, pwdFile *string, pwdStdin, askPass *bool, r io.Reader) string {
// no password option set, use empty password
if password == nil && pwdFile == nil && pwdStdin == nil && askPass == nil {
return ""
}
// password defined as parameter
if password != nil && *password != "" {
return *password
}
// read password from file
if pwdFile != nil && *pwdFile != "" {
pw, err := ioutil.ReadFile(*pwdFile)
if err != nil {
logrus.Fatalln("Unable to read password from file:", err)
}
return strings.TrimSpace(string(pw))
}
// read password from stdin
if pwdStdin != nil && *pwdStdin {
data, err := ioutil.ReadAll(r)
if err != nil {
logrus.Fatalln("Unable to read password from stdin:", err)
}
return strings.TrimSpace(string(data))
}
// ask the user to provide the password
if *askPass {
fmt.Print("Enter Password: ")
bytePassword, err := term.ReadPassword(syscall.Stdin)
if err != nil {
logrus.Fatalln("Unable to read password:", err)
}
fmt.Println()
fmt.Print("Confirm Password: ")
bytePassword2, err := term.ReadPassword(syscall.Stdin)
if err != nil {
logrus.Fatalln("Unable to read password:", err)
}
fmt.Println()
if strings.TrimSpace(string(bytePassword)) != strings.TrimSpace(string(bytePassword2)) {
logrus.Fatalln("Entered passwords don't match")
}
return strings.TrimSpace(string(bytePassword))
}
return ""
}

View file

@ -0,0 +1,62 @@
package main
import (
"bytes"
"io"
"testing"
)
func Test_getPassword(t *testing.T) {
type args struct {
password *string
pwdFile *string
pwdStdin *bool
askPass *bool
reader io.Reader
}
pass := "mySecretPass"
passwordFile := "testdata/my.pass"
passwordStdin := true
reader := &bytes.Buffer{}
_, err := reader.WriteString(pass)
if err != nil {
t.Errorf("unable to write to buffer: %+v", err)
}
tests := []struct {
name string
args args
want string
}{
{
name: "no password defined",
args: args{},
want: "",
},
{
name: "password defined",
args: args{password: &pass},
want: pass,
},
{
name: "pwdFile defined",
args: args{pwdFile: &passwordFile},
want: pass,
},
{
name: "read pass from stdin defined",
args: args{
pwdStdin: &passwordStdin,
reader: reader,
},
want: pass,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := getPassword(tt.args.password, tt.args.pwdFile, tt.args.pwdStdin, tt.args.askPass, tt.args.reader); got != tt.want {
t.Errorf("getPassword() = '%v', want '%v'", got, tt.want)
}
})
}
}

1
cmd/create-account/testdata/my.pass vendored Normal file
View file

@ -0,0 +1 @@
mySecretPass

View file

@ -23,7 +23,6 @@ import (
"fmt"
"io/ioutil"
"log"
"math"
"net"
"net/http"
"os"
@ -48,12 +47,11 @@ import (
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi"
"github.com/matrix-org/gomatrixserverlib"
"go.uber.org/atomic"
pineconeMulticast "github.com/matrix-org/pinecone/multicast"
"github.com/matrix-org/pinecone/router"
pineconeRouter "github.com/matrix-org/pinecone/router"
pineconeSessions "github.com/matrix-org/pinecone/sessions"
pineconeTypes "github.com/matrix-org/pinecone/types"
"github.com/sirupsen/logrus"
)
@ -123,27 +121,23 @@ func main() {
pMulticast := pineconeMulticast.NewMulticast(logger, pRouter)
pMulticast.Start()
var staticPeerAttempts atomic.Uint32
var connectToStaticPeer func()
connectToStaticPeer = func() {
uri := *instancePeer
if uri == "" {
return
connectToStaticPeer := func() {
attempt := func() {
if pRouter.PeerCount(router.PeerTypeRemote) == 0 {
uri := *instancePeer
if uri == "" {
return
}
if err := conn.ConnectToPeer(pRouter, uri); err != nil {
logrus.WithError(err).Error("Failed to connect to static peer")
}
}
}
if err := conn.ConnectToPeer(pRouter, uri); err != nil {
exp := time.Second * time.Duration(math.Exp2(float64(staticPeerAttempts.Inc())))
time.AfterFunc(exp, connectToStaticPeer)
} else {
staticPeerAttempts.Store(0)
for {
attempt()
time.Sleep(time.Second * 5)
}
}
pRouter.SetDisconnectedCallback(func(port pineconeTypes.SwitchPortID, public pineconeTypes.PublicKey, peertype int, err error) {
if peertype == pineconeRouter.PeerTypeRemote && err != nil {
staticPeerAttempts.Store(0)
time.AfterFunc(time.Second, connectToStaticPeer)
}
})
go connectToStaticPeer()
cfg := &config.Dendrite{}
cfg.Defaults()
@ -257,6 +251,7 @@ func main() {
Handler: pMux,
}
go connectToStaticPeer()
go func() {
pubkey := pRouter.PublicKey()
logrus.Info("Listening on ", hex.EncodeToString(pubkey[:]))

View file

@ -1,6 +1,6 @@
server {
listen 443 ssl; # IPv4
listen [::]:443; # IPv6
listen [::]:443 ssl; # IPv6
server_name my.hostname.com;
ssl_certificate /path/to/fullchain.pem;
@ -16,6 +16,9 @@ server {
}
location /.well-known/matrix/client {
# If your sever_name here doesn't match your matrix homeserver URL
# (e.g. hostname.com as server_name and matrix.hostname.com as homeserver URL)
# add_header Access-Control-Allow-Origin '*';
return 200 '{ "m.homeserver": { "base_url": "https://my.hostname.com" } }';
}

View file

@ -1,6 +1,6 @@
server {
listen 443 ssl; # IPv4
listen [::]:443; # IPv6
listen [::]:443 ssl; # IPv6
server_name my.hostname.com;
ssl_certificate /path/to/fullchain.pem;
@ -16,6 +16,9 @@ server {
}
location /.well-known/matrix/client {
# If your sever_name here doesn't match your matrix homeserver URL
# (e.g. hostname.com as server_name and matrix.hostname.com as homeserver URL)
# add_header Access-Control-Allow-Origin '*';
return 200 '{ "m.homeserver": { "base_url": "https://my.hostname.com" } }';
}

2
go.mod
View file

@ -25,7 +25,7 @@ require (
github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd
github.com/matrix-org/gomatrixserverlib v0.0.0-20210525110027-8cb7699aa64a
github.com/matrix-org/naffka v0.0.0-20201009174903-d26a3b9cb161
github.com/matrix-org/pinecone v0.0.0-20210602111459-5cb0e6aa1a6a
github.com/matrix-org/pinecone v0.0.0-20210614122540-33ce3bd0f3ac
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
github.com/mattn/go-sqlite3 v1.14.7-0.20210414154423-1157a4212dcb
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646

4
go.sum
View file

@ -706,8 +706,8 @@ github.com/matrix-org/gomatrixserverlib v0.0.0-20210525110027-8cb7699aa64a h1:pV
github.com/matrix-org/gomatrixserverlib v0.0.0-20210525110027-8cb7699aa64a/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU=
github.com/matrix-org/naffka v0.0.0-20201009174903-d26a3b9cb161 h1:h1XVh05pLoC+nJjP3GIpj5wUsuC8WdHP3He0RTkRJTs=
github.com/matrix-org/naffka v0.0.0-20201009174903-d26a3b9cb161/go.mod h1:sjyPyRxKM5uw1nD2cJ6O2OxI6GOqyVBfNXqKjBZTBZE=
github.com/matrix-org/pinecone v0.0.0-20210602111459-5cb0e6aa1a6a h1:BE/cfpyHO2ua1BK4Tibr+2oZCV3H1mC9G7g7Yvl1AmM=
github.com/matrix-org/pinecone v0.0.0-20210602111459-5cb0e6aa1a6a/go.mod h1:UQzJS6UVyVwfkr+RLrdvBB1vLyECqe3fLYNcbRxv8SA=
github.com/matrix-org/pinecone v0.0.0-20210614122540-33ce3bd0f3ac h1:qgEfJzulYUVDGh1PGzeGxYMGDtKSxMS+6eQG6E37pgM=
github.com/matrix-org/pinecone v0.0.0-20210614122540-33ce3bd0f3ac/go.mod h1:UQzJS6UVyVwfkr+RLrdvBB1vLyECqe3fLYNcbRxv8SA=
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk=
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=

View file

@ -20,6 +20,7 @@ import (
"encoding/json"
"time"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/api"
@ -47,7 +48,7 @@ const upsertKeysSQL = "" +
" DO UPDATE SET key_json = $6"
const selectKeysSQL = "" +
"SELECT key_id, algorithm, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2"
"SELECT concat(algorithm, ':', key_id) as algorithmwithid, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 AND concat(algorithm, ':', key_id) = ANY($3);"
const selectKeysCountSQL = "" +
"SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 GROUP BY algorithm"
@ -94,29 +95,22 @@ func NewPostgresOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
}
func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
rows, err := s.selectKeysStmt.QueryContext(ctx, userID, deviceID)
rows, err := s.selectKeysStmt.QueryContext(ctx, userID, deviceID, pq.Array(keyIDsWithAlgorithms))
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt: rows.close() failed")
wantSet := make(map[string]bool, len(keyIDsWithAlgorithms))
for _, ka := range keyIDsWithAlgorithms {
wantSet[ka] = true
}
result := make(map[string]json.RawMessage)
var (
algorithmWithID string
keyJSONStr string
)
for rows.Next() {
var keyID string
var algorithm string
var keyJSONStr string
if err := rows.Scan(&keyID, &algorithm, &keyJSONStr); err != nil {
if err := rows.Scan(&algorithmWithID, &keyJSONStr); err != nil {
return nil, err
}
keyIDWithAlgo := algorithm + ":" + keyID
if wantSet[keyIDWithAlgo] {
result[keyIDWithAlgo] = json.RawMessage(keyJSONStr)
}
result[algorithmWithID] = json.RawMessage(keyJSONStr)
}
return result, rows.Err()
}

View file

@ -147,7 +147,20 @@ func (r *uploadRequest) doUpload(
// r.storeFileAndMetadata(ctx, tmpDir, ...)
// before you return from doUpload else we will leak a temp file. We could make this nicer with a `WithTransaction` style of
// nested function to guarantee either storage or cleanup.
hash, bytesWritten, tmpDir, err := fileutils.WriteTempFile(ctx, reqReader, cfg.AbsBasePath)
// should not happen, but prevents any int overflows
if cfg.MaxFileSizeBytes != nil && *cfg.MaxFileSizeBytes+1 <= 0 {
r.Logger.WithFields(log.Fields{
"MaxFileSizeBytes": *cfg.MaxFileSizeBytes + 1,
}).Error("Error while transferring file, configured max_file_size_bytes overflows int64")
return &util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.Unknown("Failed to upload"),
}
}
lr := io.LimitReader(reqReader, int64(*cfg.MaxFileSizeBytes)+1)
hash, bytesWritten, tmpDir, err := fileutils.WriteTempFile(ctx, lr, cfg.AbsBasePath)
if err != nil {
r.Logger.WithError(err).WithFields(log.Fields{
"MaxFileSizeBytes": *cfg.MaxFileSizeBytes,

View file

@ -0,0 +1,132 @@
package routing
import (
"context"
"io"
"os"
"path/filepath"
"reflect"
"strings"
"testing"
"github.com/matrix-org/dendrite/mediaapi/fileutils"
"github.com/matrix-org/dendrite/mediaapi/storage"
"github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/util"
log "github.com/sirupsen/logrus"
)
func Test_uploadRequest_doUpload(t *testing.T) {
type fields struct {
MediaMetadata *types.MediaMetadata
Logger *log.Entry
}
type args struct {
ctx context.Context
reqReader io.Reader
cfg *config.MediaAPI
db storage.Database
activeThumbnailGeneration *types.ActiveThumbnailGeneration
}
wd, err := os.Getwd()
if err != nil {
t.Errorf("failed to get current working directory: %v", err)
}
maxSize := config.FileSizeBytes(8)
logger := log.New().WithField("mediaapi", "test")
testdataPath := filepath.Join(wd, "./testdata")
cfg := &config.MediaAPI{
MaxFileSizeBytes: &maxSize,
BasePath: config.Path(testdataPath),
AbsBasePath: config.Path(testdataPath),
DynamicThumbnails: false,
}
// create testdata folder and remove when done
_ = os.Mkdir(testdataPath, os.ModePerm)
defer fileutils.RemoveDir(types.Path(testdataPath), nil)
db, err := storage.Open(&config.DatabaseOptions{
ConnectionString: "file::memory:?cache=shared",
MaxOpenConnections: 100,
MaxIdleConnections: 2,
ConnMaxLifetimeSeconds: -1,
})
if err != nil {
t.Errorf("error opening mediaapi database: %v", err)
}
tests := []struct {
name string
fields fields
args args
want *util.JSONResponse
}{
{
name: "upload ok",
args: args{
ctx: context.Background(),
reqReader: strings.NewReader("test"),
cfg: cfg,
db: db,
},
fields: fields{
Logger: logger,
MediaMetadata: &types.MediaMetadata{
MediaID: "1337",
UploadName: "test ok",
},
},
want: nil,
},
{
name: "upload ok (exact size)",
args: args{
ctx: context.Background(),
reqReader: strings.NewReader("testtest"),
cfg: cfg,
db: db,
},
fields: fields{
Logger: logger,
MediaMetadata: &types.MediaMetadata{
MediaID: "1338",
UploadName: "test ok (exact size)",
},
},
want: nil,
},
{
name: "upload not ok",
args: args{
ctx: context.Background(),
reqReader: strings.NewReader("test test test"),
cfg: cfg,
db: db,
},
fields: fields{
Logger: logger,
MediaMetadata: &types.MediaMetadata{
MediaID: "1339",
UploadName: "test fail",
},
},
want: requestEntityTooLargeJSONResponse(maxSize),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := &uploadRequest{
MediaMetadata: tt.fields.MediaMetadata,
Logger: tt.fields.Logger,
}
if got := r.doUpload(tt.args.ctx, tt.args.reqReader, tt.args.cfg, tt.args.db, tt.args.activeThumbnailGeneration); !reflect.DeepEqual(got, tt.want) {
t.Errorf("doUpload() = %+v, want %+v", got, tt.want)
}
})
}
}

View file

@ -101,7 +101,7 @@ func (a *ApplicationService) IsInterestedInRoomID(
) bool {
if namespaceSlice, ok := a.NamespaceMap["rooms"]; ok {
for _, namespace := range namespaceSlice {
if namespace.RegexpObject.MatchString(roomID) {
if namespace.RegexpObject != nil && namespace.RegexpObject.MatchString(roomID) {
return true
}
}
@ -222,6 +222,10 @@ func setupRegexps(asAPI *AppServiceAPI, derived *Derived) (err error) {
case "aliases":
appendExclusiveNamespaceRegexs(&exclusiveAliasStrings, namespaceSlice)
}
if err = compileNamespaceRegexes(namespaceSlice); err != nil {
return fmt.Errorf("invalid regex in appservice %q, namespace %q: %w", appservice.ID, key, err)
}
}
}
@ -258,18 +262,31 @@ func setupRegexps(asAPI *AppServiceAPI, derived *Derived) (err error) {
func appendExclusiveNamespaceRegexs(
exclusiveStrings *[]string, namespaces []ApplicationServiceNamespace,
) {
for index, namespace := range namespaces {
for _, namespace := range namespaces {
if namespace.Exclusive {
// We append parenthesis to later separate each regex when we compile
// i.e. "app1.*", "app2.*" -> "(app1.*)|(app2.*)"
*exclusiveStrings = append(*exclusiveStrings, "("+namespace.Regex+")")
}
// Compile this regex into a Regexp object for later use
namespaces[index].RegexpObject, _ = regexp.Compile(namespace.Regex)
}
}
// compileNamespaceRegexes turns strings into regex objects and complains
// if some of there are bad
func compileNamespaceRegexes(namespaces []ApplicationServiceNamespace) (err error) {
for index, namespace := range namespaces {
// Compile this regex into a Regexp object for later use
r, err := regexp.Compile(namespace.Regex)
if err != nil {
return fmt.Errorf("regex at namespace %d: %w", index, err)
}
namespaces[index].RegexpObject = r
}
return nil
}
// checkErrors checks for any configuration errors amongst the loaded
// application services according to the application service spec.
func checkErrors(config *AppServiceAPI, derived *Derived) (err error) {

View file

@ -2,6 +2,7 @@ package config
import (
"fmt"
"math"
)
type MediaAPI struct {
@ -57,6 +58,11 @@ func (c *MediaAPI) Verify(configErrs *ConfigErrors, isMonolith bool) {
checkNotEmpty(configErrs, "media_api.database.connection_string", string(c.Database.ConnectionString))
checkNotEmpty(configErrs, "media_api.base_path", string(c.BasePath))
// allow "unlimited" file size
if c.MaxFileSizeBytes != nil && *c.MaxFileSizeBytes <= 0 {
unlimitedSize := FileSizeBytes(math.MaxInt64 - 1)
c.MaxFileSizeBytes = &unlimitedSize
}
checkPositive(configErrs, "media_api.max_file_size_bytes", int64(*c.MaxFileSizeBytes))
checkPositive(configErrs, "media_api.max_thumbnail_generators", int64(c.MaxThumbnailGenerators))