mirror of
https://github.com/hoernschen/dendrite.git
synced 2024-12-26 15:08:28 +00:00
Merge branch 'master' into neilalexander/servername
This commit is contained in:
commit
043759af60
18 changed files with 414 additions and 96 deletions
|
@ -1,7 +1,7 @@
|
|||
# Dendrite [![Build Status](https://badge.buildkite.com/4be40938ab19f2bbc4a6c6724517353ee3ec1422e279faf374.svg?branch=master)](https://buildkite.com/matrix-dot-org/dendrite) [![Dendrite](https://img.shields.io/matrix/dendrite:matrix.org.svg?label=%23dendrite%3Amatrix.org&logo=matrix&server_fqdn=matrix.org)](https://matrix.to/#/#dendrite:matrix.org) [![Dendrite Dev](https://img.shields.io/matrix/dendrite-dev:matrix.org.svg?label=%23dendrite-dev%3Amatrix.org&logo=matrix&server_fqdn=matrix.org)](https://matrix.to/#/#dendrite-dev:matrix.org)
|
||||
|
||||
Dendrite is a second-generation Matrix homeserver written in Go.
|
||||
It intends to provide an **efficient**, **reliable** and **scalable** alternative to Synapse:
|
||||
It intends to provide an **efficient**, **reliable** and **scalable** alternative to [Synapse](https://github.com/matrix-org/synapse):
|
||||
- Efficient: A small memory footprint with better baseline performance than an out-of-the-box Synapse.
|
||||
- Reliable: Implements the Matrix specification as written, using the
|
||||
[same test suite](https://github.com/matrix-org/sytest) as Synapse as well as
|
||||
|
|
|
@ -32,14 +32,18 @@ INSERT OR IGNORE INTO appservice_counters (name, last_id) VALUES('txn_id', 1);
|
|||
`
|
||||
|
||||
const selectTxnIDSQL = `
|
||||
SELECT last_id FROM appservice_counters WHERE name='txn_id';
|
||||
UPDATE appservice_counters SET last_id=last_id+1 WHERE name='txn_id';
|
||||
SELECT last_id FROM appservice_counters WHERE name='txn_id'
|
||||
`
|
||||
|
||||
const updateTxnIDSQL = `
|
||||
UPDATE appservice_counters SET last_id=last_id+1 WHERE name='txn_id'
|
||||
`
|
||||
|
||||
type txnStatements struct {
|
||||
db *sql.DB
|
||||
writer sqlutil.Writer
|
||||
selectTxnIDStmt *sql.Stmt
|
||||
updateTxnIDStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func (s *txnStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
|
||||
|
@ -54,6 +58,10 @@ func (s *txnStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
|
|||
return
|
||||
}
|
||||
|
||||
if s.updateTxnIDStmt, err = db.Prepare(updateTxnIDSQL); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -63,6 +71,11 @@ func (s *txnStatements) selectTxnID(
|
|||
) (txnID int, err error) {
|
||||
err = s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
||||
err := s.selectTxnIDStmt.QueryRowContext(ctx).Scan(&txnID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = s.updateTxnIDStmt.ExecContext(ctx)
|
||||
return err
|
||||
})
|
||||
return
|
||||
|
|
|
@ -37,19 +37,21 @@ runtime config should come from. The mounted folder must contain:
|
|||
To generate keys:
|
||||
|
||||
```
|
||||
go run github.com/matrix-org/dendrite/cmd/generate-keys \
|
||||
--private-key=matrix_key.pem \
|
||||
--tls-cert=server.crt \
|
||||
--tls-key=server.key
|
||||
docker run --rm --entrypoint="" \
|
||||
-v $(pwd):/mnt \
|
||||
matrixdotorg/dendrite-monolith:latest \
|
||||
/usr/bin/generate-keys \
|
||||
-private-key /mnt/matrix_key.pem \
|
||||
-tls-cert /mnt/server.crt \
|
||||
-tls-key /mnt/server.key
|
||||
```
|
||||
|
||||
The key files will now exist in your current working directory, and can be mounted into place.
|
||||
|
||||
## Starting Dendrite as a monolith deployment
|
||||
|
||||
Create your config based on the `dendrite.yaml` configuration file in the `docker/config`
|
||||
folder in the [Dendrite repository](https://github.com/matrix-org/dendrite). Additionally,
|
||||
make the following changes to the configuration:
|
||||
|
||||
- Enable Naffka: `use_naffka: true`
|
||||
folder in the [Dendrite repository](https://github.com/matrix-org/dendrite).
|
||||
|
||||
Once in place, start the PostgreSQL dependency:
|
||||
|
||||
|
|
|
@ -10,7 +10,6 @@ import (
|
|||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
|
@ -37,15 +36,14 @@ import (
|
|||
userapiAPI "github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/sirupsen/logrus"
|
||||
"go.uber.org/atomic"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/h2c"
|
||||
|
||||
pineconeMulticast "github.com/matrix-org/pinecone/multicast"
|
||||
"github.com/matrix-org/pinecone/router"
|
||||
pineconeRouter "github.com/matrix-org/pinecone/router"
|
||||
pineconeSessions "github.com/matrix-org/pinecone/sessions"
|
||||
"github.com/matrix-org/pinecone/types"
|
||||
pineconeTypes "github.com/matrix-org/pinecone/types"
|
||||
|
||||
_ "golang.org/x/mobile/bind"
|
||||
)
|
||||
|
@ -57,19 +55,19 @@ const (
|
|||
)
|
||||
|
||||
type DendriteMonolith struct {
|
||||
logger logrus.Logger
|
||||
PineconeRouter *pineconeRouter.Router
|
||||
PineconeMulticast *pineconeMulticast.Multicast
|
||||
PineconeQUIC *pineconeSessions.Sessions
|
||||
StorageDirectory string
|
||||
CacheDirectory string
|
||||
staticPeerURI string
|
||||
staticPeerMutex sync.RWMutex
|
||||
staticPeerAttempts atomic.Uint32
|
||||
listener net.Listener
|
||||
httpServer *http.Server
|
||||
processContext *process.ProcessContext
|
||||
userAPI userapiAPI.UserInternalAPI
|
||||
logger logrus.Logger
|
||||
PineconeRouter *pineconeRouter.Router
|
||||
PineconeMulticast *pineconeMulticast.Multicast
|
||||
PineconeQUIC *pineconeSessions.Sessions
|
||||
StorageDirectory string
|
||||
CacheDirectory string
|
||||
staticPeerURI string
|
||||
staticPeerMutex sync.RWMutex
|
||||
staticPeerAttempt chan struct{}
|
||||
listener net.Listener
|
||||
httpServer *http.Server
|
||||
processContext *process.ProcessContext
|
||||
userAPI userapiAPI.UserInternalAPI
|
||||
}
|
||||
|
||||
func (m *DendriteMonolith) BaseURL() string {
|
||||
|
@ -99,7 +97,9 @@ func (m *DendriteMonolith) SetStaticPeer(uri string) {
|
|||
m.staticPeerMutex.Unlock()
|
||||
m.DisconnectType(pineconeRouter.PeerTypeRemote)
|
||||
if uri != "" {
|
||||
go m.staticPeerConnect()
|
||||
go func() {
|
||||
m.staticPeerAttempt <- struct{}{}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -195,17 +195,27 @@ func (m *DendriteMonolith) RegisterDevice(localpart, deviceID string) (string, e
|
|||
}
|
||||
|
||||
func (m *DendriteMonolith) staticPeerConnect() {
|
||||
m.staticPeerMutex.RLock()
|
||||
uri := m.staticPeerURI
|
||||
m.staticPeerMutex.RUnlock()
|
||||
if uri == "" {
|
||||
return
|
||||
attempt := func() {
|
||||
if m.PineconeRouter.PeerCount(router.PeerTypeRemote) == 0 {
|
||||
m.staticPeerMutex.RLock()
|
||||
uri := m.staticPeerURI
|
||||
m.staticPeerMutex.RUnlock()
|
||||
if uri == "" {
|
||||
return
|
||||
}
|
||||
if err := conn.ConnectToPeer(m.PineconeRouter, uri); err != nil {
|
||||
logrus.WithError(err).Error("Failed to connect to static peer")
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := conn.ConnectToPeer(m.PineconeRouter, uri); err != nil {
|
||||
exp := time.Second * time.Duration(math.Exp2(float64(m.staticPeerAttempts.Inc())))
|
||||
time.AfterFunc(exp, m.staticPeerConnect)
|
||||
} else {
|
||||
m.staticPeerAttempts.Store(0)
|
||||
for {
|
||||
select {
|
||||
case <-m.processContext.Context().Done():
|
||||
case <-m.staticPeerAttempt:
|
||||
attempt()
|
||||
case <-time.After(time.Second * 5):
|
||||
attempt()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -248,13 +258,6 @@ func (m *DendriteMonolith) Start() {
|
|||
m.PineconeQUIC = pineconeSessions.NewSessions(logger, m.PineconeRouter)
|
||||
m.PineconeMulticast = pineconeMulticast.NewMulticast(logger, m.PineconeRouter)
|
||||
|
||||
m.PineconeRouter.SetDisconnectedCallback(func(port pineconeTypes.SwitchPortID, public pineconeTypes.PublicKey, peertype int, err error) {
|
||||
if peertype == pineconeRouter.PeerTypeRemote {
|
||||
m.staticPeerAttempts.Store(0)
|
||||
time.AfterFunc(time.Second, m.staticPeerConnect)
|
||||
}
|
||||
})
|
||||
|
||||
prefix := hex.EncodeToString(pk)
|
||||
cfg := &config.Dendrite{}
|
||||
cfg.Defaults()
|
||||
|
@ -359,8 +362,12 @@ func (m *DendriteMonolith) Start() {
|
|||
},
|
||||
Handler: h2c.NewHandler(pMux, h2s),
|
||||
}
|
||||
|
||||
m.processContext = base.ProcessContext
|
||||
|
||||
m.staticPeerAttempt = make(chan struct{}, 1)
|
||||
go m.staticPeerConnect()
|
||||
|
||||
go func() {
|
||||
m.logger.Info("Listening on ", cfg.Global.ServerName)
|
||||
m.logger.Fatal(m.httpServer.Serve(m.PineconeQUIC))
|
||||
|
|
|
@ -69,7 +69,7 @@ func GetAccountData(
|
|||
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusNotFound,
|
||||
JSON: jsonerror.Forbidden("data not found"),
|
||||
JSON: jsonerror.NotFound("data not found"),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -18,13 +18,18 @@ import (
|
|||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/matrix-org/dendrite/setup"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/accounts"
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
const usage = `Usage: %s
|
||||
|
@ -33,7 +38,15 @@ Creates a new user account on the homeserver.
|
|||
|
||||
Example:
|
||||
|
||||
./create-account --config dendrite.yaml --username alice --password foobarbaz
|
||||
# provide password by parameter
|
||||
%s --config dendrite.yaml -username alice -password foobarbaz
|
||||
# use password from file
|
||||
%s --config dendrite.yaml -username alice -passwordfile my.pass
|
||||
# ask user to provide password
|
||||
%s --config dendrite.yaml -username alice -ask-pass
|
||||
# read password from stdin
|
||||
%s --config dendrite.yaml -username alice -passwordstdin < my.pass
|
||||
cat my.pass | %s --config dendrite.yaml -username alice -passwordstdin
|
||||
|
||||
Arguments:
|
||||
|
||||
|
@ -42,11 +55,15 @@ Arguments:
|
|||
var (
|
||||
username = flag.String("username", "", "The username of the account to register (specify the localpart only, e.g. 'alice' for '@alice:domain.com')")
|
||||
password = flag.String("password", "", "The password to associate with the account (optional, account will be password-less if not specified)")
|
||||
pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)")
|
||||
pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin")
|
||||
askPass = flag.Bool("ask-pass", false, "Ask for the password to use")
|
||||
)
|
||||
|
||||
func main() {
|
||||
name := os.Args[0]
|
||||
flag.Usage = func() {
|
||||
fmt.Fprintf(os.Stderr, usage, os.Args[0])
|
||||
_, _ = fmt.Fprintf(os.Stderr, usage, name, name, name, name, name, name)
|
||||
flag.PrintDefaults()
|
||||
}
|
||||
cfg := setup.ParseFlags(true)
|
||||
|
@ -56,6 +73,8 @@ func main() {
|
|||
os.Exit(1)
|
||||
}
|
||||
|
||||
pass := getPassword(password, pwdFile, pwdStdin, askPass, os.Stdin)
|
||||
|
||||
accountDB, err := accounts.NewDatabase(&config.DatabaseOptions{
|
||||
ConnectionString: cfg.UserAPI.AccountDatabase.ConnectionString,
|
||||
}, cfg.Global.ServerName, bcrypt.DefaultCost, cfg.UserAPI.OpenIDTokenLifetimeMS)
|
||||
|
@ -63,10 +82,61 @@ func main() {
|
|||
logrus.Fatalln("Failed to connect to the database:", err.Error())
|
||||
}
|
||||
|
||||
_, err = accountDB.CreateAccount(context.Background(), *username, *password, "")
|
||||
_, err = accountDB.CreateAccount(context.Background(), *username, pass, "")
|
||||
if err != nil {
|
||||
logrus.Fatalln("Failed to create the account:", err.Error())
|
||||
}
|
||||
|
||||
logrus.Infoln("Created account", *username)
|
||||
}
|
||||
|
||||
func getPassword(password, pwdFile *string, pwdStdin, askPass *bool, r io.Reader) string {
|
||||
// no password option set, use empty password
|
||||
if password == nil && pwdFile == nil && pwdStdin == nil && askPass == nil {
|
||||
return ""
|
||||
}
|
||||
// password defined as parameter
|
||||
if password != nil && *password != "" {
|
||||
return *password
|
||||
}
|
||||
|
||||
// read password from file
|
||||
if pwdFile != nil && *pwdFile != "" {
|
||||
pw, err := ioutil.ReadFile(*pwdFile)
|
||||
if err != nil {
|
||||
logrus.Fatalln("Unable to read password from file:", err)
|
||||
}
|
||||
return strings.TrimSpace(string(pw))
|
||||
}
|
||||
|
||||
// read password from stdin
|
||||
if pwdStdin != nil && *pwdStdin {
|
||||
data, err := ioutil.ReadAll(r)
|
||||
if err != nil {
|
||||
logrus.Fatalln("Unable to read password from stdin:", err)
|
||||
}
|
||||
return strings.TrimSpace(string(data))
|
||||
}
|
||||
|
||||
// ask the user to provide the password
|
||||
if *askPass {
|
||||
fmt.Print("Enter Password: ")
|
||||
bytePassword, err := term.ReadPassword(syscall.Stdin)
|
||||
if err != nil {
|
||||
logrus.Fatalln("Unable to read password:", err)
|
||||
}
|
||||
fmt.Println()
|
||||
fmt.Print("Confirm Password: ")
|
||||
bytePassword2, err := term.ReadPassword(syscall.Stdin)
|
||||
if err != nil {
|
||||
logrus.Fatalln("Unable to read password:", err)
|
||||
}
|
||||
fmt.Println()
|
||||
if strings.TrimSpace(string(bytePassword)) != strings.TrimSpace(string(bytePassword2)) {
|
||||
logrus.Fatalln("Entered passwords don't match")
|
||||
}
|
||||
return strings.TrimSpace(string(bytePassword))
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
|
62
cmd/create-account/main_test.go
Normal file
62
cmd/create-account/main_test.go
Normal file
|
@ -0,0 +1,62 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_getPassword(t *testing.T) {
|
||||
type args struct {
|
||||
password *string
|
||||
pwdFile *string
|
||||
pwdStdin *bool
|
||||
askPass *bool
|
||||
reader io.Reader
|
||||
}
|
||||
|
||||
pass := "mySecretPass"
|
||||
passwordFile := "testdata/my.pass"
|
||||
passwordStdin := true
|
||||
reader := &bytes.Buffer{}
|
||||
_, err := reader.WriteString(pass)
|
||||
if err != nil {
|
||||
t.Errorf("unable to write to buffer: %+v", err)
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "no password defined",
|
||||
args: args{},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "password defined",
|
||||
args: args{password: &pass},
|
||||
want: pass,
|
||||
},
|
||||
{
|
||||
name: "pwdFile defined",
|
||||
args: args{pwdFile: &passwordFile},
|
||||
want: pass,
|
||||
},
|
||||
{
|
||||
name: "read pass from stdin defined",
|
||||
args: args{
|
||||
pwdStdin: &passwordStdin,
|
||||
reader: reader,
|
||||
},
|
||||
want: pass,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := getPassword(tt.args.password, tt.args.pwdFile, tt.args.pwdStdin, tt.args.askPass, tt.args.reader); got != tt.want {
|
||||
t.Errorf("getPassword() = '%v', want '%v'", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
1
cmd/create-account/testdata/my.pass
vendored
Normal file
1
cmd/create-account/testdata/my.pass
vendored
Normal file
|
@ -0,0 +1 @@
|
|||
mySecretPass
|
|
@ -23,7 +23,6 @@ import (
|
|||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
|
@ -48,12 +47,11 @@ import (
|
|||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/userapi"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"go.uber.org/atomic"
|
||||
|
||||
pineconeMulticast "github.com/matrix-org/pinecone/multicast"
|
||||
"github.com/matrix-org/pinecone/router"
|
||||
pineconeRouter "github.com/matrix-org/pinecone/router"
|
||||
pineconeSessions "github.com/matrix-org/pinecone/sessions"
|
||||
pineconeTypes "github.com/matrix-org/pinecone/types"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
@ -123,27 +121,23 @@ func main() {
|
|||
pMulticast := pineconeMulticast.NewMulticast(logger, pRouter)
|
||||
pMulticast.Start()
|
||||
|
||||
var staticPeerAttempts atomic.Uint32
|
||||
var connectToStaticPeer func()
|
||||
connectToStaticPeer = func() {
|
||||
uri := *instancePeer
|
||||
if uri == "" {
|
||||
return
|
||||
connectToStaticPeer := func() {
|
||||
attempt := func() {
|
||||
if pRouter.PeerCount(router.PeerTypeRemote) == 0 {
|
||||
uri := *instancePeer
|
||||
if uri == "" {
|
||||
return
|
||||
}
|
||||
if err := conn.ConnectToPeer(pRouter, uri); err != nil {
|
||||
logrus.WithError(err).Error("Failed to connect to static peer")
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := conn.ConnectToPeer(pRouter, uri); err != nil {
|
||||
exp := time.Second * time.Duration(math.Exp2(float64(staticPeerAttempts.Inc())))
|
||||
time.AfterFunc(exp, connectToStaticPeer)
|
||||
} else {
|
||||
staticPeerAttempts.Store(0)
|
||||
for {
|
||||
attempt()
|
||||
time.Sleep(time.Second * 5)
|
||||
}
|
||||
}
|
||||
pRouter.SetDisconnectedCallback(func(port pineconeTypes.SwitchPortID, public pineconeTypes.PublicKey, peertype int, err error) {
|
||||
if peertype == pineconeRouter.PeerTypeRemote && err != nil {
|
||||
staticPeerAttempts.Store(0)
|
||||
time.AfterFunc(time.Second, connectToStaticPeer)
|
||||
}
|
||||
})
|
||||
go connectToStaticPeer()
|
||||
|
||||
cfg := &config.Dendrite{}
|
||||
cfg.Defaults()
|
||||
|
@ -257,6 +251,7 @@ func main() {
|
|||
Handler: pMux,
|
||||
}
|
||||
|
||||
go connectToStaticPeer()
|
||||
go func() {
|
||||
pubkey := pRouter.PublicKey()
|
||||
logrus.Info("Listening on ", hex.EncodeToString(pubkey[:]))
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
server {
|
||||
listen 443 ssl; # IPv4
|
||||
listen [::]:443; # IPv6
|
||||
listen [::]:443 ssl; # IPv6
|
||||
server_name my.hostname.com;
|
||||
|
||||
ssl_certificate /path/to/fullchain.pem;
|
||||
|
@ -16,6 +16,9 @@ server {
|
|||
}
|
||||
|
||||
location /.well-known/matrix/client {
|
||||
# If your sever_name here doesn't match your matrix homeserver URL
|
||||
# (e.g. hostname.com as server_name and matrix.hostname.com as homeserver URL)
|
||||
# add_header Access-Control-Allow-Origin '*';
|
||||
return 200 '{ "m.homeserver": { "base_url": "https://my.hostname.com" } }';
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
server {
|
||||
listen 443 ssl; # IPv4
|
||||
listen [::]:443; # IPv6
|
||||
listen [::]:443 ssl; # IPv6
|
||||
server_name my.hostname.com;
|
||||
|
||||
ssl_certificate /path/to/fullchain.pem;
|
||||
|
@ -16,6 +16,9 @@ server {
|
|||
}
|
||||
|
||||
location /.well-known/matrix/client {
|
||||
# If your sever_name here doesn't match your matrix homeserver URL
|
||||
# (e.g. hostname.com as server_name and matrix.hostname.com as homeserver URL)
|
||||
# add_header Access-Control-Allow-Origin '*';
|
||||
return 200 '{ "m.homeserver": { "base_url": "https://my.hostname.com" } }';
|
||||
}
|
||||
|
||||
|
|
2
go.mod
2
go.mod
|
@ -25,7 +25,7 @@ require (
|
|||
github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd
|
||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20210525110027-8cb7699aa64a
|
||||
github.com/matrix-org/naffka v0.0.0-20201009174903-d26a3b9cb161
|
||||
github.com/matrix-org/pinecone v0.0.0-20210602111459-5cb0e6aa1a6a
|
||||
github.com/matrix-org/pinecone v0.0.0-20210614122540-33ce3bd0f3ac
|
||||
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
|
||||
github.com/mattn/go-sqlite3 v1.14.7-0.20210414154423-1157a4212dcb
|
||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
|
||||
|
|
4
go.sum
4
go.sum
|
@ -706,8 +706,8 @@ github.com/matrix-org/gomatrixserverlib v0.0.0-20210525110027-8cb7699aa64a h1:pV
|
|||
github.com/matrix-org/gomatrixserverlib v0.0.0-20210525110027-8cb7699aa64a/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU=
|
||||
github.com/matrix-org/naffka v0.0.0-20201009174903-d26a3b9cb161 h1:h1XVh05pLoC+nJjP3GIpj5wUsuC8WdHP3He0RTkRJTs=
|
||||
github.com/matrix-org/naffka v0.0.0-20201009174903-d26a3b9cb161/go.mod h1:sjyPyRxKM5uw1nD2cJ6O2OxI6GOqyVBfNXqKjBZTBZE=
|
||||
github.com/matrix-org/pinecone v0.0.0-20210602111459-5cb0e6aa1a6a h1:BE/cfpyHO2ua1BK4Tibr+2oZCV3H1mC9G7g7Yvl1AmM=
|
||||
github.com/matrix-org/pinecone v0.0.0-20210602111459-5cb0e6aa1a6a/go.mod h1:UQzJS6UVyVwfkr+RLrdvBB1vLyECqe3fLYNcbRxv8SA=
|
||||
github.com/matrix-org/pinecone v0.0.0-20210614122540-33ce3bd0f3ac h1:qgEfJzulYUVDGh1PGzeGxYMGDtKSxMS+6eQG6E37pgM=
|
||||
github.com/matrix-org/pinecone v0.0.0-20210614122540-33ce3bd0f3ac/go.mod h1:UQzJS6UVyVwfkr+RLrdvBB1vLyECqe3fLYNcbRxv8SA=
|
||||
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=
|
||||
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk=
|
||||
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U=
|
||||
|
|
|
@ -20,6 +20,7 @@ import (
|
|||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/keyserver/api"
|
||||
|
@ -47,7 +48,7 @@ const upsertKeysSQL = "" +
|
|||
" DO UPDATE SET key_json = $6"
|
||||
|
||||
const selectKeysSQL = "" +
|
||||
"SELECT key_id, algorithm, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2"
|
||||
"SELECT concat(algorithm, ':', key_id) as algorithmwithid, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 AND concat(algorithm, ':', key_id) = ANY($3);"
|
||||
|
||||
const selectKeysCountSQL = "" +
|
||||
"SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 GROUP BY algorithm"
|
||||
|
@ -94,29 +95,22 @@ func NewPostgresOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
|
|||
}
|
||||
|
||||
func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
|
||||
rows, err := s.selectKeysStmt.QueryContext(ctx, userID, deviceID)
|
||||
rows, err := s.selectKeysStmt.QueryContext(ctx, userID, deviceID, pq.Array(keyIDsWithAlgorithms))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt: rows.close() failed")
|
||||
|
||||
wantSet := make(map[string]bool, len(keyIDsWithAlgorithms))
|
||||
for _, ka := range keyIDsWithAlgorithms {
|
||||
wantSet[ka] = true
|
||||
}
|
||||
|
||||
result := make(map[string]json.RawMessage)
|
||||
var (
|
||||
algorithmWithID string
|
||||
keyJSONStr string
|
||||
)
|
||||
for rows.Next() {
|
||||
var keyID string
|
||||
var algorithm string
|
||||
var keyJSONStr string
|
||||
if err := rows.Scan(&keyID, &algorithm, &keyJSONStr); err != nil {
|
||||
if err := rows.Scan(&algorithmWithID, &keyJSONStr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
keyIDWithAlgo := algorithm + ":" + keyID
|
||||
if wantSet[keyIDWithAlgo] {
|
||||
result[keyIDWithAlgo] = json.RawMessage(keyJSONStr)
|
||||
}
|
||||
result[algorithmWithID] = json.RawMessage(keyJSONStr)
|
||||
}
|
||||
return result, rows.Err()
|
||||
}
|
||||
|
|
|
@ -147,7 +147,20 @@ func (r *uploadRequest) doUpload(
|
|||
// r.storeFileAndMetadata(ctx, tmpDir, ...)
|
||||
// before you return from doUpload else we will leak a temp file. We could make this nicer with a `WithTransaction` style of
|
||||
// nested function to guarantee either storage or cleanup.
|
||||
hash, bytesWritten, tmpDir, err := fileutils.WriteTempFile(ctx, reqReader, cfg.AbsBasePath)
|
||||
|
||||
// should not happen, but prevents any int overflows
|
||||
if cfg.MaxFileSizeBytes != nil && *cfg.MaxFileSizeBytes+1 <= 0 {
|
||||
r.Logger.WithFields(log.Fields{
|
||||
"MaxFileSizeBytes": *cfg.MaxFileSizeBytes + 1,
|
||||
}).Error("Error while transferring file, configured max_file_size_bytes overflows int64")
|
||||
return &util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.Unknown("Failed to upload"),
|
||||
}
|
||||
}
|
||||
|
||||
lr := io.LimitReader(reqReader, int64(*cfg.MaxFileSizeBytes)+1)
|
||||
hash, bytesWritten, tmpDir, err := fileutils.WriteTempFile(ctx, lr, cfg.AbsBasePath)
|
||||
if err != nil {
|
||||
r.Logger.WithError(err).WithFields(log.Fields{
|
||||
"MaxFileSizeBytes": *cfg.MaxFileSizeBytes,
|
||||
|
|
132
mediaapi/routing/upload_test.go
Normal file
132
mediaapi/routing/upload_test.go
Normal file
|
@ -0,0 +1,132 @@
|
|||
package routing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/dendrite/mediaapi/fileutils"
|
||||
"github.com/matrix-org/dendrite/mediaapi/storage"
|
||||
"github.com/matrix-org/dendrite/mediaapi/types"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func Test_uploadRequest_doUpload(t *testing.T) {
|
||||
type fields struct {
|
||||
MediaMetadata *types.MediaMetadata
|
||||
Logger *log.Entry
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
reqReader io.Reader
|
||||
cfg *config.MediaAPI
|
||||
db storage.Database
|
||||
activeThumbnailGeneration *types.ActiveThumbnailGeneration
|
||||
}
|
||||
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
t.Errorf("failed to get current working directory: %v", err)
|
||||
}
|
||||
|
||||
maxSize := config.FileSizeBytes(8)
|
||||
logger := log.New().WithField("mediaapi", "test")
|
||||
testdataPath := filepath.Join(wd, "./testdata")
|
||||
|
||||
cfg := &config.MediaAPI{
|
||||
MaxFileSizeBytes: &maxSize,
|
||||
BasePath: config.Path(testdataPath),
|
||||
AbsBasePath: config.Path(testdataPath),
|
||||
DynamicThumbnails: false,
|
||||
}
|
||||
|
||||
// create testdata folder and remove when done
|
||||
_ = os.Mkdir(testdataPath, os.ModePerm)
|
||||
defer fileutils.RemoveDir(types.Path(testdataPath), nil)
|
||||
|
||||
db, err := storage.Open(&config.DatabaseOptions{
|
||||
ConnectionString: "file::memory:?cache=shared",
|
||||
MaxOpenConnections: 100,
|
||||
MaxIdleConnections: 2,
|
||||
ConnMaxLifetimeSeconds: -1,
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("error opening mediaapi database: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want *util.JSONResponse
|
||||
}{
|
||||
{
|
||||
name: "upload ok",
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
reqReader: strings.NewReader("test"),
|
||||
cfg: cfg,
|
||||
db: db,
|
||||
},
|
||||
fields: fields{
|
||||
Logger: logger,
|
||||
MediaMetadata: &types.MediaMetadata{
|
||||
MediaID: "1337",
|
||||
UploadName: "test ok",
|
||||
},
|
||||
},
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "upload ok (exact size)",
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
reqReader: strings.NewReader("testtest"),
|
||||
cfg: cfg,
|
||||
db: db,
|
||||
},
|
||||
fields: fields{
|
||||
Logger: logger,
|
||||
MediaMetadata: &types.MediaMetadata{
|
||||
MediaID: "1338",
|
||||
UploadName: "test ok (exact size)",
|
||||
},
|
||||
},
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "upload not ok",
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
reqReader: strings.NewReader("test test test"),
|
||||
cfg: cfg,
|
||||
db: db,
|
||||
},
|
||||
fields: fields{
|
||||
Logger: logger,
|
||||
MediaMetadata: &types.MediaMetadata{
|
||||
MediaID: "1339",
|
||||
UploadName: "test fail",
|
||||
},
|
||||
},
|
||||
want: requestEntityTooLargeJSONResponse(maxSize),
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := &uploadRequest{
|
||||
MediaMetadata: tt.fields.MediaMetadata,
|
||||
Logger: tt.fields.Logger,
|
||||
}
|
||||
if got := r.doUpload(tt.args.ctx, tt.args.reqReader, tt.args.cfg, tt.args.db, tt.args.activeThumbnailGeneration); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("doUpload() = %+v, want %+v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -101,7 +101,7 @@ func (a *ApplicationService) IsInterestedInRoomID(
|
|||
) bool {
|
||||
if namespaceSlice, ok := a.NamespaceMap["rooms"]; ok {
|
||||
for _, namespace := range namespaceSlice {
|
||||
if namespace.RegexpObject.MatchString(roomID) {
|
||||
if namespace.RegexpObject != nil && namespace.RegexpObject.MatchString(roomID) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
@ -222,6 +222,10 @@ func setupRegexps(asAPI *AppServiceAPI, derived *Derived) (err error) {
|
|||
case "aliases":
|
||||
appendExclusiveNamespaceRegexs(&exclusiveAliasStrings, namespaceSlice)
|
||||
}
|
||||
|
||||
if err = compileNamespaceRegexes(namespaceSlice); err != nil {
|
||||
return fmt.Errorf("invalid regex in appservice %q, namespace %q: %w", appservice.ID, key, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -258,18 +262,31 @@ func setupRegexps(asAPI *AppServiceAPI, derived *Derived) (err error) {
|
|||
func appendExclusiveNamespaceRegexs(
|
||||
exclusiveStrings *[]string, namespaces []ApplicationServiceNamespace,
|
||||
) {
|
||||
for index, namespace := range namespaces {
|
||||
for _, namespace := range namespaces {
|
||||
if namespace.Exclusive {
|
||||
// We append parenthesis to later separate each regex when we compile
|
||||
// i.e. "app1.*", "app2.*" -> "(app1.*)|(app2.*)"
|
||||
*exclusiveStrings = append(*exclusiveStrings, "("+namespace.Regex+")")
|
||||
}
|
||||
|
||||
// Compile this regex into a Regexp object for later use
|
||||
namespaces[index].RegexpObject, _ = regexp.Compile(namespace.Regex)
|
||||
}
|
||||
}
|
||||
|
||||
// compileNamespaceRegexes turns strings into regex objects and complains
|
||||
// if some of there are bad
|
||||
func compileNamespaceRegexes(namespaces []ApplicationServiceNamespace) (err error) {
|
||||
for index, namespace := range namespaces {
|
||||
// Compile this regex into a Regexp object for later use
|
||||
r, err := regexp.Compile(namespace.Regex)
|
||||
if err != nil {
|
||||
return fmt.Errorf("regex at namespace %d: %w", index, err)
|
||||
}
|
||||
|
||||
namespaces[index].RegexpObject = r
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkErrors checks for any configuration errors amongst the loaded
|
||||
// application services according to the application service spec.
|
||||
func checkErrors(config *AppServiceAPI, derived *Derived) (err error) {
|
||||
|
|
|
@ -2,6 +2,7 @@ package config
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
)
|
||||
|
||||
type MediaAPI struct {
|
||||
|
@ -57,6 +58,11 @@ func (c *MediaAPI) Verify(configErrs *ConfigErrors, isMonolith bool) {
|
|||
checkNotEmpty(configErrs, "media_api.database.connection_string", string(c.Database.ConnectionString))
|
||||
|
||||
checkNotEmpty(configErrs, "media_api.base_path", string(c.BasePath))
|
||||
// allow "unlimited" file size
|
||||
if c.MaxFileSizeBytes != nil && *c.MaxFileSizeBytes <= 0 {
|
||||
unlimitedSize := FileSizeBytes(math.MaxInt64 - 1)
|
||||
c.MaxFileSizeBytes = &unlimitedSize
|
||||
}
|
||||
checkPositive(configErrs, "media_api.max_file_size_bytes", int64(*c.MaxFileSizeBytes))
|
||||
checkPositive(configErrs, "media_api.max_thumbnail_generators", int64(c.MaxThumbnailGenerators))
|
||||
|
||||
|
|
Loading…
Reference in a new issue