diff --git a/README.md b/README.md index 7d79bbff..8c793841 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # Dendrite [![Build Status](https://badge.buildkite.com/4be40938ab19f2bbc4a6c6724517353ee3ec1422e279faf374.svg?branch=master)](https://buildkite.com/matrix-dot-org/dendrite) [![Dendrite](https://img.shields.io/matrix/dendrite:matrix.org.svg?label=%23dendrite%3Amatrix.org&logo=matrix&server_fqdn=matrix.org)](https://matrix.to/#/#dendrite:matrix.org) [![Dendrite Dev](https://img.shields.io/matrix/dendrite-dev:matrix.org.svg?label=%23dendrite-dev%3Amatrix.org&logo=matrix&server_fqdn=matrix.org)](https://matrix.to/#/#dendrite-dev:matrix.org) Dendrite is a second-generation Matrix homeserver written in Go. -It intends to provide an **efficient**, **reliable** and **scalable** alternative to Synapse: +It intends to provide an **efficient**, **reliable** and **scalable** alternative to [Synapse](https://github.com/matrix-org/synapse): - Efficient: A small memory footprint with better baseline performance than an out-of-the-box Synapse. - Reliable: Implements the Matrix specification as written, using the [same test suite](https://github.com/matrix-org/sytest) as Synapse as well as diff --git a/appservice/storage/sqlite3/txn_id_counter_table.go b/appservice/storage/sqlite3/txn_id_counter_table.go index b2940e35..f2e902f9 100644 --- a/appservice/storage/sqlite3/txn_id_counter_table.go +++ b/appservice/storage/sqlite3/txn_id_counter_table.go @@ -32,14 +32,18 @@ INSERT OR IGNORE INTO appservice_counters (name, last_id) VALUES('txn_id', 1); ` const selectTxnIDSQL = ` - SELECT last_id FROM appservice_counters WHERE name='txn_id'; - UPDATE appservice_counters SET last_id=last_id+1 WHERE name='txn_id'; + SELECT last_id FROM appservice_counters WHERE name='txn_id' +` + +const updateTxnIDSQL = ` + UPDATE appservice_counters SET last_id=last_id+1 WHERE name='txn_id' ` type txnStatements struct { db *sql.DB writer sqlutil.Writer selectTxnIDStmt *sql.Stmt + updateTxnIDStmt *sql.Stmt } func (s *txnStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) { @@ -54,6 +58,10 @@ func (s *txnStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) { return } + if s.updateTxnIDStmt, err = db.Prepare(updateTxnIDSQL); err != nil { + return + } + return } @@ -63,6 +71,11 @@ func (s *txnStatements) selectTxnID( ) (txnID int, err error) { err = s.writer.Do(s.db, nil, func(txn *sql.Tx) error { err := s.selectTxnIDStmt.QueryRowContext(ctx).Scan(&txnID) + if err != nil { + return err + } + + _, err = s.updateTxnIDStmt.ExecContext(ctx) return err }) return diff --git a/build/docker/README.md b/build/docker/README.md index 818f92d0..6d3cd3db 100644 --- a/build/docker/README.md +++ b/build/docker/README.md @@ -37,19 +37,21 @@ runtime config should come from. The mounted folder must contain: To generate keys: ``` -go run github.com/matrix-org/dendrite/cmd/generate-keys \ - --private-key=matrix_key.pem \ - --tls-cert=server.crt \ - --tls-key=server.key +docker run --rm --entrypoint="" \ + -v $(pwd):/mnt \ + matrixdotorg/dendrite-monolith:latest \ + /usr/bin/generate-keys \ + -private-key /mnt/matrix_key.pem \ + -tls-cert /mnt/server.crt \ + -tls-key /mnt/server.key ``` +The key files will now exist in your current working directory, and can be mounted into place. + ## Starting Dendrite as a monolith deployment Create your config based on the `dendrite.yaml` configuration file in the `docker/config` -folder in the [Dendrite repository](https://github.com/matrix-org/dendrite). Additionally, -make the following changes to the configuration: - -- Enable Naffka: `use_naffka: true` +folder in the [Dendrite repository](https://github.com/matrix-org/dendrite). Once in place, start the PostgreSQL dependency: diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go index c15707e5..09af80f6 100644 --- a/build/gobind-pinecone/monolith.go +++ b/build/gobind-pinecone/monolith.go @@ -10,7 +10,6 @@ import ( "io" "io/ioutil" "log" - "math" "net" "net/http" "os" @@ -37,15 +36,14 @@ import ( userapiAPI "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" - "go.uber.org/atomic" "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" pineconeMulticast "github.com/matrix-org/pinecone/multicast" + "github.com/matrix-org/pinecone/router" pineconeRouter "github.com/matrix-org/pinecone/router" pineconeSessions "github.com/matrix-org/pinecone/sessions" "github.com/matrix-org/pinecone/types" - pineconeTypes "github.com/matrix-org/pinecone/types" _ "golang.org/x/mobile/bind" ) @@ -57,19 +55,19 @@ const ( ) type DendriteMonolith struct { - logger logrus.Logger - PineconeRouter *pineconeRouter.Router - PineconeMulticast *pineconeMulticast.Multicast - PineconeQUIC *pineconeSessions.Sessions - StorageDirectory string - CacheDirectory string - staticPeerURI string - staticPeerMutex sync.RWMutex - staticPeerAttempts atomic.Uint32 - listener net.Listener - httpServer *http.Server - processContext *process.ProcessContext - userAPI userapiAPI.UserInternalAPI + logger logrus.Logger + PineconeRouter *pineconeRouter.Router + PineconeMulticast *pineconeMulticast.Multicast + PineconeQUIC *pineconeSessions.Sessions + StorageDirectory string + CacheDirectory string + staticPeerURI string + staticPeerMutex sync.RWMutex + staticPeerAttempt chan struct{} + listener net.Listener + httpServer *http.Server + processContext *process.ProcessContext + userAPI userapiAPI.UserInternalAPI } func (m *DendriteMonolith) BaseURL() string { @@ -99,7 +97,9 @@ func (m *DendriteMonolith) SetStaticPeer(uri string) { m.staticPeerMutex.Unlock() m.DisconnectType(pineconeRouter.PeerTypeRemote) if uri != "" { - go m.staticPeerConnect() + go func() { + m.staticPeerAttempt <- struct{}{} + }() } } @@ -195,17 +195,27 @@ func (m *DendriteMonolith) RegisterDevice(localpart, deviceID string) (string, e } func (m *DendriteMonolith) staticPeerConnect() { - m.staticPeerMutex.RLock() - uri := m.staticPeerURI - m.staticPeerMutex.RUnlock() - if uri == "" { - return + attempt := func() { + if m.PineconeRouter.PeerCount(router.PeerTypeRemote) == 0 { + m.staticPeerMutex.RLock() + uri := m.staticPeerURI + m.staticPeerMutex.RUnlock() + if uri == "" { + return + } + if err := conn.ConnectToPeer(m.PineconeRouter, uri); err != nil { + logrus.WithError(err).Error("Failed to connect to static peer") + } + } } - if err := conn.ConnectToPeer(m.PineconeRouter, uri); err != nil { - exp := time.Second * time.Duration(math.Exp2(float64(m.staticPeerAttempts.Inc()))) - time.AfterFunc(exp, m.staticPeerConnect) - } else { - m.staticPeerAttempts.Store(0) + for { + select { + case <-m.processContext.Context().Done(): + case <-m.staticPeerAttempt: + attempt() + case <-time.After(time.Second * 5): + attempt() + } } } @@ -248,13 +258,6 @@ func (m *DendriteMonolith) Start() { m.PineconeQUIC = pineconeSessions.NewSessions(logger, m.PineconeRouter) m.PineconeMulticast = pineconeMulticast.NewMulticast(logger, m.PineconeRouter) - m.PineconeRouter.SetDisconnectedCallback(func(port pineconeTypes.SwitchPortID, public pineconeTypes.PublicKey, peertype int, err error) { - if peertype == pineconeRouter.PeerTypeRemote { - m.staticPeerAttempts.Store(0) - time.AfterFunc(time.Second, m.staticPeerConnect) - } - }) - prefix := hex.EncodeToString(pk) cfg := &config.Dendrite{} cfg.Defaults() @@ -359,8 +362,12 @@ func (m *DendriteMonolith) Start() { }, Handler: h2c.NewHandler(pMux, h2s), } + m.processContext = base.ProcessContext + m.staticPeerAttempt = make(chan struct{}, 1) + go m.staticPeerConnect() + go func() { m.logger.Info("Listening on ", cfg.Global.ServerName) m.logger.Fatal(m.httpServer.Serve(m.PineconeQUIC)) diff --git a/clientapi/routing/account_data.go b/clientapi/routing/account_data.go index 22e63513..03025f1d 100644 --- a/clientapi/routing/account_data.go +++ b/clientapi/routing/account_data.go @@ -69,7 +69,7 @@ func GetAccountData( return util.JSONResponse{ Code: http.StatusNotFound, - JSON: jsonerror.Forbidden("data not found"), + JSON: jsonerror.NotFound("data not found"), } } diff --git a/cmd/create-account/main.go b/cmd/create-account/main.go index 060b82f9..a1e254f8 100644 --- a/cmd/create-account/main.go +++ b/cmd/create-account/main.go @@ -18,13 +18,18 @@ import ( "context" "flag" "fmt" + "io" + "io/ioutil" "os" + "strings" + "syscall" "github.com/matrix-org/dendrite/setup" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/sirupsen/logrus" "golang.org/x/crypto/bcrypt" + "golang.org/x/term" ) const usage = `Usage: %s @@ -33,7 +38,15 @@ Creates a new user account on the homeserver. Example: - ./create-account --config dendrite.yaml --username alice --password foobarbaz + # provide password by parameter + %s --config dendrite.yaml -username alice -password foobarbaz + # use password from file + %s --config dendrite.yaml -username alice -passwordfile my.pass + # ask user to provide password + %s --config dendrite.yaml -username alice -ask-pass + # read password from stdin + %s --config dendrite.yaml -username alice -passwordstdin < my.pass + cat my.pass | %s --config dendrite.yaml -username alice -passwordstdin Arguments: @@ -42,11 +55,15 @@ Arguments: var ( username = flag.String("username", "", "The username of the account to register (specify the localpart only, e.g. 'alice' for '@alice:domain.com')") password = flag.String("password", "", "The password to associate with the account (optional, account will be password-less if not specified)") + pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)") + pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin") + askPass = flag.Bool("ask-pass", false, "Ask for the password to use") ) func main() { + name := os.Args[0] flag.Usage = func() { - fmt.Fprintf(os.Stderr, usage, os.Args[0]) + _, _ = fmt.Fprintf(os.Stderr, usage, name, name, name, name, name, name) flag.PrintDefaults() } cfg := setup.ParseFlags(true) @@ -56,6 +73,8 @@ func main() { os.Exit(1) } + pass := getPassword(password, pwdFile, pwdStdin, askPass, os.Stdin) + accountDB, err := accounts.NewDatabase(&config.DatabaseOptions{ ConnectionString: cfg.UserAPI.AccountDatabase.ConnectionString, }, cfg.Global.ServerName, bcrypt.DefaultCost, cfg.UserAPI.OpenIDTokenLifetimeMS) @@ -63,10 +82,61 @@ func main() { logrus.Fatalln("Failed to connect to the database:", err.Error()) } - _, err = accountDB.CreateAccount(context.Background(), *username, *password, "") + _, err = accountDB.CreateAccount(context.Background(), *username, pass, "") if err != nil { logrus.Fatalln("Failed to create the account:", err.Error()) } logrus.Infoln("Created account", *username) } + +func getPassword(password, pwdFile *string, pwdStdin, askPass *bool, r io.Reader) string { + // no password option set, use empty password + if password == nil && pwdFile == nil && pwdStdin == nil && askPass == nil { + return "" + } + // password defined as parameter + if password != nil && *password != "" { + return *password + } + + // read password from file + if pwdFile != nil && *pwdFile != "" { + pw, err := ioutil.ReadFile(*pwdFile) + if err != nil { + logrus.Fatalln("Unable to read password from file:", err) + } + return strings.TrimSpace(string(pw)) + } + + // read password from stdin + if pwdStdin != nil && *pwdStdin { + data, err := ioutil.ReadAll(r) + if err != nil { + logrus.Fatalln("Unable to read password from stdin:", err) + } + return strings.TrimSpace(string(data)) + } + + // ask the user to provide the password + if *askPass { + fmt.Print("Enter Password: ") + bytePassword, err := term.ReadPassword(syscall.Stdin) + if err != nil { + logrus.Fatalln("Unable to read password:", err) + } + fmt.Println() + fmt.Print("Confirm Password: ") + bytePassword2, err := term.ReadPassword(syscall.Stdin) + if err != nil { + logrus.Fatalln("Unable to read password:", err) + } + fmt.Println() + if strings.TrimSpace(string(bytePassword)) != strings.TrimSpace(string(bytePassword2)) { + logrus.Fatalln("Entered passwords don't match") + } + return strings.TrimSpace(string(bytePassword)) + } + + return "" +} diff --git a/cmd/create-account/main_test.go b/cmd/create-account/main_test.go new file mode 100644 index 00000000..d06eafe4 --- /dev/null +++ b/cmd/create-account/main_test.go @@ -0,0 +1,62 @@ +package main + +import ( + "bytes" + "io" + "testing" +) + +func Test_getPassword(t *testing.T) { + type args struct { + password *string + pwdFile *string + pwdStdin *bool + askPass *bool + reader io.Reader + } + + pass := "mySecretPass" + passwordFile := "testdata/my.pass" + passwordStdin := true + reader := &bytes.Buffer{} + _, err := reader.WriteString(pass) + if err != nil { + t.Errorf("unable to write to buffer: %+v", err) + } + tests := []struct { + name string + args args + want string + }{ + { + name: "no password defined", + args: args{}, + want: "", + }, + { + name: "password defined", + args: args{password: &pass}, + want: pass, + }, + { + name: "pwdFile defined", + args: args{pwdFile: &passwordFile}, + want: pass, + }, + { + name: "read pass from stdin defined", + args: args{ + pwdStdin: &passwordStdin, + reader: reader, + }, + want: pass, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := getPassword(tt.args.password, tt.args.pwdFile, tt.args.pwdStdin, tt.args.askPass, tt.args.reader); got != tt.want { + t.Errorf("getPassword() = '%v', want '%v'", got, tt.want) + } + }) + } +} diff --git a/cmd/create-account/testdata/my.pass b/cmd/create-account/testdata/my.pass new file mode 100644 index 00000000..c1f7156f --- /dev/null +++ b/cmd/create-account/testdata/my.pass @@ -0,0 +1 @@ +mySecretPass \ No newline at end of file diff --git a/cmd/dendrite-demo-pinecone/main.go b/cmd/dendrite-demo-pinecone/main.go index 7851fdb1..f40f9190 100644 --- a/cmd/dendrite-demo-pinecone/main.go +++ b/cmd/dendrite-demo-pinecone/main.go @@ -23,7 +23,6 @@ import ( "fmt" "io/ioutil" "log" - "math" "net" "net/http" "os" @@ -48,12 +47,11 @@ import ( "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi" "github.com/matrix-org/gomatrixserverlib" - "go.uber.org/atomic" pineconeMulticast "github.com/matrix-org/pinecone/multicast" + "github.com/matrix-org/pinecone/router" pineconeRouter "github.com/matrix-org/pinecone/router" pineconeSessions "github.com/matrix-org/pinecone/sessions" - pineconeTypes "github.com/matrix-org/pinecone/types" "github.com/sirupsen/logrus" ) @@ -123,27 +121,23 @@ func main() { pMulticast := pineconeMulticast.NewMulticast(logger, pRouter) pMulticast.Start() - var staticPeerAttempts atomic.Uint32 - var connectToStaticPeer func() - connectToStaticPeer = func() { - uri := *instancePeer - if uri == "" { - return + connectToStaticPeer := func() { + attempt := func() { + if pRouter.PeerCount(router.PeerTypeRemote) == 0 { + uri := *instancePeer + if uri == "" { + return + } + if err := conn.ConnectToPeer(pRouter, uri); err != nil { + logrus.WithError(err).Error("Failed to connect to static peer") + } + } } - if err := conn.ConnectToPeer(pRouter, uri); err != nil { - exp := time.Second * time.Duration(math.Exp2(float64(staticPeerAttempts.Inc()))) - time.AfterFunc(exp, connectToStaticPeer) - } else { - staticPeerAttempts.Store(0) + for { + attempt() + time.Sleep(time.Second * 5) } } - pRouter.SetDisconnectedCallback(func(port pineconeTypes.SwitchPortID, public pineconeTypes.PublicKey, peertype int, err error) { - if peertype == pineconeRouter.PeerTypeRemote && err != nil { - staticPeerAttempts.Store(0) - time.AfterFunc(time.Second, connectToStaticPeer) - } - }) - go connectToStaticPeer() cfg := &config.Dendrite{} cfg.Defaults() @@ -257,6 +251,7 @@ func main() { Handler: pMux, } + go connectToStaticPeer() go func() { pubkey := pRouter.PublicKey() logrus.Info("Listening on ", hex.EncodeToString(pubkey[:])) diff --git a/docs/nginx/monolith-sample.conf b/docs/nginx/monolith-sample.conf index 350e8348..0344aa96 100644 --- a/docs/nginx/monolith-sample.conf +++ b/docs/nginx/monolith-sample.conf @@ -1,6 +1,6 @@ server { listen 443 ssl; # IPv4 - listen [::]:443; # IPv6 + listen [::]:443 ssl; # IPv6 server_name my.hostname.com; ssl_certificate /path/to/fullchain.pem; @@ -16,6 +16,9 @@ server { } location /.well-known/matrix/client { + # If your sever_name here doesn't match your matrix homeserver URL + # (e.g. hostname.com as server_name and matrix.hostname.com as homeserver URL) + # add_header Access-Control-Allow-Origin '*'; return 200 '{ "m.homeserver": { "base_url": "https://my.hostname.com" } }'; } diff --git a/docs/nginx/polylith-sample.conf b/docs/nginx/polylith-sample.conf index d0d3c98d..274d7565 100644 --- a/docs/nginx/polylith-sample.conf +++ b/docs/nginx/polylith-sample.conf @@ -1,6 +1,6 @@ server { listen 443 ssl; # IPv4 - listen [::]:443; # IPv6 + listen [::]:443 ssl; # IPv6 server_name my.hostname.com; ssl_certificate /path/to/fullchain.pem; @@ -16,6 +16,9 @@ server { } location /.well-known/matrix/client { + # If your sever_name here doesn't match your matrix homeserver URL + # (e.g. hostname.com as server_name and matrix.hostname.com as homeserver URL) + # add_header Access-Control-Allow-Origin '*'; return 200 '{ "m.homeserver": { "base_url": "https://my.hostname.com" } }'; } diff --git a/go.mod b/go.mod index 7273da64..90bccf4f 100644 --- a/go.mod +++ b/go.mod @@ -25,7 +25,7 @@ require ( github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd github.com/matrix-org/gomatrixserverlib v0.0.0-20210525110027-8cb7699aa64a github.com/matrix-org/naffka v0.0.0-20201009174903-d26a3b9cb161 - github.com/matrix-org/pinecone v0.0.0-20210602111459-5cb0e6aa1a6a + github.com/matrix-org/pinecone v0.0.0-20210614122540-33ce3bd0f3ac github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.7-0.20210414154423-1157a4212dcb github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 diff --git a/go.sum b/go.sum index eba9a60b..b178ebe4 100644 --- a/go.sum +++ b/go.sum @@ -706,8 +706,8 @@ github.com/matrix-org/gomatrixserverlib v0.0.0-20210525110027-8cb7699aa64a h1:pV github.com/matrix-org/gomatrixserverlib v0.0.0-20210525110027-8cb7699aa64a/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU= github.com/matrix-org/naffka v0.0.0-20201009174903-d26a3b9cb161 h1:h1XVh05pLoC+nJjP3GIpj5wUsuC8WdHP3He0RTkRJTs= github.com/matrix-org/naffka v0.0.0-20201009174903-d26a3b9cb161/go.mod h1:sjyPyRxKM5uw1nD2cJ6O2OxI6GOqyVBfNXqKjBZTBZE= -github.com/matrix-org/pinecone v0.0.0-20210602111459-5cb0e6aa1a6a h1:BE/cfpyHO2ua1BK4Tibr+2oZCV3H1mC9G7g7Yvl1AmM= -github.com/matrix-org/pinecone v0.0.0-20210602111459-5cb0e6aa1a6a/go.mod h1:UQzJS6UVyVwfkr+RLrdvBB1vLyECqe3fLYNcbRxv8SA= +github.com/matrix-org/pinecone v0.0.0-20210614122540-33ce3bd0f3ac h1:qgEfJzulYUVDGh1PGzeGxYMGDtKSxMS+6eQG6E37pgM= +github.com/matrix-org/pinecone v0.0.0-20210614122540-33ce3bd0f3ac/go.mod h1:UQzJS6UVyVwfkr+RLrdvBB1vLyECqe3fLYNcbRxv8SA= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= diff --git a/keyserver/storage/postgres/one_time_keys_table.go b/keyserver/storage/postgres/one_time_keys_table.go index 6e32838b..cc397ba8 100644 --- a/keyserver/storage/postgres/one_time_keys_table.go +++ b/keyserver/storage/postgres/one_time_keys_table.go @@ -20,6 +20,7 @@ import ( "encoding/json" "time" + "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/keyserver/api" @@ -47,7 +48,7 @@ const upsertKeysSQL = "" + " DO UPDATE SET key_json = $6" const selectKeysSQL = "" + - "SELECT key_id, algorithm, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2" + "SELECT concat(algorithm, ':', key_id) as algorithmwithid, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 AND concat(algorithm, ':', key_id) = ANY($3);" const selectKeysCountSQL = "" + "SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 GROUP BY algorithm" @@ -94,29 +95,22 @@ func NewPostgresOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) { } func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { - rows, err := s.selectKeysStmt.QueryContext(ctx, userID, deviceID) + rows, err := s.selectKeysStmt.QueryContext(ctx, userID, deviceID, pq.Array(keyIDsWithAlgorithms)) if err != nil { return nil, err } defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt: rows.close() failed") - wantSet := make(map[string]bool, len(keyIDsWithAlgorithms)) - for _, ka := range keyIDsWithAlgorithms { - wantSet[ka] = true - } - result := make(map[string]json.RawMessage) + var ( + algorithmWithID string + keyJSONStr string + ) for rows.Next() { - var keyID string - var algorithm string - var keyJSONStr string - if err := rows.Scan(&keyID, &algorithm, &keyJSONStr); err != nil { + if err := rows.Scan(&algorithmWithID, &keyJSONStr); err != nil { return nil, err } - keyIDWithAlgo := algorithm + ":" + keyID - if wantSet[keyIDWithAlgo] { - result[keyIDWithAlgo] = json.RawMessage(keyJSONStr) - } + result[algorithmWithID] = json.RawMessage(keyJSONStr) } return result, rows.Err() } diff --git a/mediaapi/routing/upload.go b/mediaapi/routing/upload.go index f1dd231d..bc0d206b 100644 --- a/mediaapi/routing/upload.go +++ b/mediaapi/routing/upload.go @@ -147,7 +147,20 @@ func (r *uploadRequest) doUpload( // r.storeFileAndMetadata(ctx, tmpDir, ...) // before you return from doUpload else we will leak a temp file. We could make this nicer with a `WithTransaction` style of // nested function to guarantee either storage or cleanup. - hash, bytesWritten, tmpDir, err := fileutils.WriteTempFile(ctx, reqReader, cfg.AbsBasePath) + + // should not happen, but prevents any int overflows + if cfg.MaxFileSizeBytes != nil && *cfg.MaxFileSizeBytes+1 <= 0 { + r.Logger.WithFields(log.Fields{ + "MaxFileSizeBytes": *cfg.MaxFileSizeBytes + 1, + }).Error("Error while transferring file, configured max_file_size_bytes overflows int64") + return &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.Unknown("Failed to upload"), + } + } + + lr := io.LimitReader(reqReader, int64(*cfg.MaxFileSizeBytes)+1) + hash, bytesWritten, tmpDir, err := fileutils.WriteTempFile(ctx, lr, cfg.AbsBasePath) if err != nil { r.Logger.WithError(err).WithFields(log.Fields{ "MaxFileSizeBytes": *cfg.MaxFileSizeBytes, diff --git a/mediaapi/routing/upload_test.go b/mediaapi/routing/upload_test.go new file mode 100644 index 00000000..032437b5 --- /dev/null +++ b/mediaapi/routing/upload_test.go @@ -0,0 +1,132 @@ +package routing + +import ( + "context" + "io" + "os" + "path/filepath" + "reflect" + "strings" + "testing" + + "github.com/matrix-org/dendrite/mediaapi/fileutils" + "github.com/matrix-org/dendrite/mediaapi/storage" + "github.com/matrix-org/dendrite/mediaapi/types" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/util" + log "github.com/sirupsen/logrus" +) + +func Test_uploadRequest_doUpload(t *testing.T) { + type fields struct { + MediaMetadata *types.MediaMetadata + Logger *log.Entry + } + type args struct { + ctx context.Context + reqReader io.Reader + cfg *config.MediaAPI + db storage.Database + activeThumbnailGeneration *types.ActiveThumbnailGeneration + } + + wd, err := os.Getwd() + if err != nil { + t.Errorf("failed to get current working directory: %v", err) + } + + maxSize := config.FileSizeBytes(8) + logger := log.New().WithField("mediaapi", "test") + testdataPath := filepath.Join(wd, "./testdata") + + cfg := &config.MediaAPI{ + MaxFileSizeBytes: &maxSize, + BasePath: config.Path(testdataPath), + AbsBasePath: config.Path(testdataPath), + DynamicThumbnails: false, + } + + // create testdata folder and remove when done + _ = os.Mkdir(testdataPath, os.ModePerm) + defer fileutils.RemoveDir(types.Path(testdataPath), nil) + + db, err := storage.Open(&config.DatabaseOptions{ + ConnectionString: "file::memory:?cache=shared", + MaxOpenConnections: 100, + MaxIdleConnections: 2, + ConnMaxLifetimeSeconds: -1, + }) + if err != nil { + t.Errorf("error opening mediaapi database: %v", err) + } + + tests := []struct { + name string + fields fields + args args + want *util.JSONResponse + }{ + { + name: "upload ok", + args: args{ + ctx: context.Background(), + reqReader: strings.NewReader("test"), + cfg: cfg, + db: db, + }, + fields: fields{ + Logger: logger, + MediaMetadata: &types.MediaMetadata{ + MediaID: "1337", + UploadName: "test ok", + }, + }, + want: nil, + }, + { + name: "upload ok (exact size)", + args: args{ + ctx: context.Background(), + reqReader: strings.NewReader("testtest"), + cfg: cfg, + db: db, + }, + fields: fields{ + Logger: logger, + MediaMetadata: &types.MediaMetadata{ + MediaID: "1338", + UploadName: "test ok (exact size)", + }, + }, + want: nil, + }, + { + name: "upload not ok", + args: args{ + ctx: context.Background(), + reqReader: strings.NewReader("test test test"), + cfg: cfg, + db: db, + }, + fields: fields{ + Logger: logger, + MediaMetadata: &types.MediaMetadata{ + MediaID: "1339", + UploadName: "test fail", + }, + }, + want: requestEntityTooLargeJSONResponse(maxSize), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &uploadRequest{ + MediaMetadata: tt.fields.MediaMetadata, + Logger: tt.fields.Logger, + } + if got := r.doUpload(tt.args.ctx, tt.args.reqReader, tt.args.cfg, tt.args.db, tt.args.activeThumbnailGeneration); !reflect.DeepEqual(got, tt.want) { + t.Errorf("doUpload() = %+v, want %+v", got, tt.want) + } + }) + } +} diff --git a/setup/config/config_appservice.go b/setup/config/config_appservice.go index 4a14672e..9bd8a1b5 100644 --- a/setup/config/config_appservice.go +++ b/setup/config/config_appservice.go @@ -101,7 +101,7 @@ func (a *ApplicationService) IsInterestedInRoomID( ) bool { if namespaceSlice, ok := a.NamespaceMap["rooms"]; ok { for _, namespace := range namespaceSlice { - if namespace.RegexpObject.MatchString(roomID) { + if namespace.RegexpObject != nil && namespace.RegexpObject.MatchString(roomID) { return true } } @@ -222,6 +222,10 @@ func setupRegexps(asAPI *AppServiceAPI, derived *Derived) (err error) { case "aliases": appendExclusiveNamespaceRegexs(&exclusiveAliasStrings, namespaceSlice) } + + if err = compileNamespaceRegexes(namespaceSlice); err != nil { + return fmt.Errorf("invalid regex in appservice %q, namespace %q: %w", appservice.ID, key, err) + } } } @@ -258,18 +262,31 @@ func setupRegexps(asAPI *AppServiceAPI, derived *Derived) (err error) { func appendExclusiveNamespaceRegexs( exclusiveStrings *[]string, namespaces []ApplicationServiceNamespace, ) { - for index, namespace := range namespaces { + for _, namespace := range namespaces { if namespace.Exclusive { // We append parenthesis to later separate each regex when we compile // i.e. "app1.*", "app2.*" -> "(app1.*)|(app2.*)" *exclusiveStrings = append(*exclusiveStrings, "("+namespace.Regex+")") } - - // Compile this regex into a Regexp object for later use - namespaces[index].RegexpObject, _ = regexp.Compile(namespace.Regex) } } +// compileNamespaceRegexes turns strings into regex objects and complains +// if some of there are bad +func compileNamespaceRegexes(namespaces []ApplicationServiceNamespace) (err error) { + for index, namespace := range namespaces { + // Compile this regex into a Regexp object for later use + r, err := regexp.Compile(namespace.Regex) + if err != nil { + return fmt.Errorf("regex at namespace %d: %w", index, err) + } + + namespaces[index].RegexpObject = r + } + + return nil +} + // checkErrors checks for any configuration errors amongst the loaded // application services according to the application service spec. func checkErrors(config *AppServiceAPI, derived *Derived) (err error) { diff --git a/setup/config/config_mediaapi.go b/setup/config/config_mediaapi.go index 660a508d..0943a39e 100644 --- a/setup/config/config_mediaapi.go +++ b/setup/config/config_mediaapi.go @@ -2,6 +2,7 @@ package config import ( "fmt" + "math" ) type MediaAPI struct { @@ -57,6 +58,11 @@ func (c *MediaAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { checkNotEmpty(configErrs, "media_api.database.connection_string", string(c.Database.ConnectionString)) checkNotEmpty(configErrs, "media_api.base_path", string(c.BasePath)) + // allow "unlimited" file size + if c.MaxFileSizeBytes != nil && *c.MaxFileSizeBytes <= 0 { + unlimitedSize := FileSizeBytes(math.MaxInt64 - 1) + c.MaxFileSizeBytes = &unlimitedSize + } checkPositive(configErrs, "media_api.max_file_size_bytes", int64(*c.MaxFileSizeBytes)) checkPositive(configErrs, "media_api.max_thumbnail_generators", int64(c.MaxThumbnailGenerators))