mirror of
https://github.com/hoernschen/dendrite.git
synced 2025-07-29 12:42:46 +00:00
Merge branch 'master' into add-nats-support
This commit is contained in:
commit
e2e1a966e1
41 changed files with 959 additions and 714 deletions
|
@ -1,53 +0,0 @@
|
|||
// Copyright 2019 Google LLC
|
||||
//
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file or at
|
||||
// https://developers.google.com/open-source/licenses/bsd
|
||||
//
|
||||
// Original code from https://github.com/FiloSottile/age/blob/bbab440e198a4d67ba78591176c7853e62d29e04/internal/age/ssh.go
|
||||
|
||||
package convert
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"crypto/sha512"
|
||||
"math/big"
|
||||
|
||||
"golang.org/x/crypto/curve25519"
|
||||
)
|
||||
|
||||
var curve25519P, _ = new(big.Int).SetString("57896044618658097711785492504343953926634992332820282019728792003956564819949", 10)
|
||||
|
||||
func Ed25519PrivateKeyToCurve25519(pk ed25519.PrivateKey) []byte {
|
||||
h := sha512.New()
|
||||
_, _ = h.Write(pk.Seed())
|
||||
out := h.Sum(nil)
|
||||
return out[:curve25519.ScalarSize]
|
||||
}
|
||||
|
||||
func Ed25519PublicKeyToCurve25519(pk ed25519.PublicKey) []byte {
|
||||
// ed25519.PublicKey is a little endian representation of the y-coordinate,
|
||||
// with the most significant bit set based on the sign of the x-coordinate.
|
||||
bigEndianY := make([]byte, ed25519.PublicKeySize)
|
||||
for i, b := range pk {
|
||||
bigEndianY[ed25519.PublicKeySize-i-1] = b
|
||||
}
|
||||
bigEndianY[0] &= 0b0111_1111
|
||||
|
||||
// The Montgomery u-coordinate is derived through the bilinear map
|
||||
// u = (1 + y) / (1 - y)
|
||||
// See https://blog.filippo.io/using-ed25519-keys-for-encryption.
|
||||
y := new(big.Int).SetBytes(bigEndianY)
|
||||
denom := big.NewInt(1)
|
||||
denom.ModInverse(denom.Sub(denom, y), curve25519P) // 1 / (1 - y)
|
||||
u := y.Mul(y.Add(y, big.NewInt(1)), denom)
|
||||
u.Mod(u, curve25519P)
|
||||
|
||||
out := make([]byte, curve25519.PointSize)
|
||||
uBytes := u.Bytes()
|
||||
for i, b := range uBytes {
|
||||
out[len(uBytes)-i-1] = b
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
|
@ -1,51 +0,0 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package convert
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/ed25519"
|
||||
"encoding/hex"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/crypto/curve25519"
|
||||
)
|
||||
|
||||
func TestKeyConversion(t *testing.T) {
|
||||
edPub, edPriv, err := ed25519.GenerateKey(nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("Signing public:", hex.EncodeToString(edPub))
|
||||
t.Log("Signing private:", hex.EncodeToString(edPriv))
|
||||
|
||||
cuPriv := Ed25519PrivateKeyToCurve25519(edPriv)
|
||||
t.Log("Encryption private:", hex.EncodeToString(cuPriv))
|
||||
|
||||
cuPub := Ed25519PublicKeyToCurve25519(edPub)
|
||||
t.Log("Converted encryption public:", hex.EncodeToString(cuPub))
|
||||
|
||||
var realPub, realPriv [32]byte
|
||||
copy(realPriv[:32], cuPriv[:32])
|
||||
curve25519.ScalarBaseMult(&realPub, &realPriv)
|
||||
t.Log("Scalar-multed encryption public:", hex.EncodeToString(realPub[:]))
|
||||
|
||||
if !bytes.Equal(realPriv[:], cuPriv[:]) {
|
||||
t.Fatal("Private keys should be equal (this means the test is broken)")
|
||||
}
|
||||
if !bytes.Equal(realPub[:], cuPub[:]) {
|
||||
t.Fatal("Public keys should be equal")
|
||||
}
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
// +build riotweb
|
||||
// +build elementweb
|
||||
|
||||
package embed
|
||||
|
||||
|
@ -12,8 +12,8 @@ import (
|
|||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// From within the Riot Web directory:
|
||||
// go run github.com/mjibson/esc -o /path/to/dendrite/internal/embed/fs_riotweb.go -private -pkg embed .
|
||||
// From within the Element Web directory:
|
||||
// go run github.com/mjibson/esc -o /path/to/dendrite/internal/embed/fs_elementweb.go -private -pkg embed .
|
||||
|
||||
var cssFile = regexp.MustCompile("\\.css$")
|
||||
var jsFile = regexp.MustCompile("\\.js$")
|
||||
|
@ -68,16 +68,16 @@ func Embed(rootMux *mux.Router, listenPort int, serverName string) {
|
|||
}
|
||||
js, _ := sjson.SetBytes(buf, "default_server_config.m\\.homeserver.base_url", url)
|
||||
js, _ = sjson.SetBytes(js, "default_server_config.m\\.homeserver.server_name", serverName)
|
||||
js, _ = sjson.SetBytes(js, "brand", fmt.Sprintf("Riot %s", serverName))
|
||||
js, _ = sjson.SetBytes(js, "brand", fmt.Sprintf("Element %s", serverName))
|
||||
js, _ = sjson.SetBytes(js, "disable_guests", true)
|
||||
js, _ = sjson.SetBytes(js, "disable_3pid_login", true)
|
||||
js, _ = sjson.DeleteBytes(js, "welcomeUserId")
|
||||
_, _ = w.Write(js)
|
||||
})
|
||||
|
||||
fmt.Println("*-------------------------------*")
|
||||
fmt.Println("| This build includes Riot Web! |")
|
||||
fmt.Println("*-------------------------------*")
|
||||
fmt.Println("*----------------------------------*")
|
||||
fmt.Println("| This build includes Element Web! |")
|
||||
fmt.Println("*----------------------------------*")
|
||||
fmt.Println("Point your browser to:", url)
|
||||
fmt.Println()
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
// +build !riotweb
|
||||
// +build !elementweb
|
||||
|
||||
package embed
|
||||
|
||||
|
|
|
@ -56,21 +56,23 @@ func main() {
|
|||
flag.Parse()
|
||||
internal.SetupPprof()
|
||||
|
||||
ygg, err := yggconn.Setup(*instanceName, ".")
|
||||
ygg, err := yggconn.Setup(*instanceName, ".", *instancePeer)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
ygg.SetMulticastEnabled(true)
|
||||
if instancePeer != nil && *instancePeer != "" {
|
||||
if err = ygg.SetStaticPeer(*instancePeer); err != nil {
|
||||
logrus.WithError(err).Error("Failed to set static peer")
|
||||
/*
|
||||
ygg.SetMulticastEnabled(true)
|
||||
if instancePeer != nil && *instancePeer != "" {
|
||||
if err = ygg.SetStaticPeer(*instancePeer); err != nil {
|
||||
logrus.WithError(err).Error("Failed to set static peer")
|
||||
}
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
cfg := &config.Dendrite{}
|
||||
cfg.Defaults()
|
||||
cfg.Global.ServerName = gomatrixserverlib.ServerName(ygg.DerivedServerName())
|
||||
cfg.Global.PrivateKey = ygg.SigningPrivateKey()
|
||||
cfg.Global.PrivateKey = ygg.PrivateKey()
|
||||
cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID)
|
||||
cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", *instanceName))
|
||||
cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", *instanceName))
|
||||
|
@ -116,18 +118,6 @@ func main() {
|
|||
base, federation, rsAPI, keyRing, true,
|
||||
)
|
||||
|
||||
ygg.SetSessionFunc(func(address string) {
|
||||
req := &api.PerformServersAliveRequest{
|
||||
Servers: []gomatrixserverlib.ServerName{
|
||||
gomatrixserverlib.ServerName(address),
|
||||
},
|
||||
}
|
||||
res := &api.PerformServersAliveResponse{}
|
||||
if err := fsAPI.PerformServersAlive(context.TODO(), req, res); err != nil {
|
||||
logrus.WithError(err).Error("Failed to send wake-up message to newly connected node")
|
||||
}
|
||||
})
|
||||
|
||||
rsComponent.SetFederationSenderAPI(fsAPI)
|
||||
|
||||
monolith := setup.Monolith{
|
||||
|
|
|
@ -51,7 +51,6 @@ func (n *Node) CreateFederationClient(
|
|||
ResponseHeaderTimeout: 10 * time.Second,
|
||||
IdleConnTimeout: 30 * time.Second,
|
||||
DialContext: n.DialerContext,
|
||||
TLSClientConfig: n.tlsConfig,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
|
|
@ -17,7 +17,6 @@ package yggconn
|
|||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/tls"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
@ -26,60 +25,48 @@ import (
|
|||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/convert"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"go.uber.org/atomic"
|
||||
"github.com/neilalexander/utp"
|
||||
|
||||
ironwoodtypes "github.com/Arceliar/ironwood/types"
|
||||
yggdrasilconfig "github.com/yggdrasil-network/yggdrasil-go/src/config"
|
||||
yggdrasilcore "github.com/yggdrasil-network/yggdrasil-go/src/core"
|
||||
yggdrasildefaults "github.com/yggdrasil-network/yggdrasil-go/src/defaults"
|
||||
yggdrasilmulticast "github.com/yggdrasil-network/yggdrasil-go/src/multicast"
|
||||
"github.com/yggdrasil-network/yggdrasil-go/src/yggdrasil"
|
||||
|
||||
gologme "github.com/gologme/log"
|
||||
)
|
||||
|
||||
type Node struct {
|
||||
core *yggdrasil.Core
|
||||
config *yggdrasilconfig.NodeConfig
|
||||
state *yggdrasilconfig.NodeState
|
||||
multicast *yggdrasilmulticast.Multicast
|
||||
log *gologme.Logger
|
||||
listener quic.Listener
|
||||
tlsConfig *tls.Config
|
||||
quicConfig *quic.Config
|
||||
sessions sync.Map // string -> *session
|
||||
sessionCount atomic.Uint32
|
||||
sessionFunc func(address string)
|
||||
coords sync.Map // string -> yggdrasil.Coords
|
||||
incoming chan QUICStream
|
||||
NewSession func(remote gomatrixserverlib.ServerName)
|
||||
core *yggdrasilcore.Core
|
||||
config *yggdrasilconfig.NodeConfig
|
||||
multicast *yggdrasilmulticast.Multicast
|
||||
log *gologme.Logger
|
||||
listener quic.Listener
|
||||
utpSocket *utp.Socket
|
||||
incoming chan net.Conn
|
||||
}
|
||||
|
||||
func (n *Node) Dialer(_, address string) (net.Conn, error) {
|
||||
func (n *Node) DialerContext(ctx context.Context, _, address string) (net.Conn, error) {
|
||||
tokens := strings.Split(address, ":")
|
||||
raw, err := hex.DecodeString(tokens[0])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("hex.DecodeString: %w", err)
|
||||
}
|
||||
converted := convert.Ed25519PublicKeyToCurve25519(ed25519.PublicKey(raw))
|
||||
convhex := hex.EncodeToString(converted)
|
||||
return n.Dial("curve25519", convhex)
|
||||
pk := make(ironwoodtypes.Addr, ed25519.PublicKeySize)
|
||||
copy(pk, raw[:])
|
||||
return n.utpSocket.DialAddrContext(ctx, pk)
|
||||
}
|
||||
|
||||
func (n *Node) DialerContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return n.Dialer(network, address)
|
||||
}
|
||||
|
||||
func Setup(instanceName, storageDirectory string) (*Node, error) {
|
||||
func Setup(instanceName, storageDirectory, peerURI string) (*Node, error) {
|
||||
n := &Node{
|
||||
core: &yggdrasil.Core{},
|
||||
config: yggdrasilconfig.GenerateConfig(),
|
||||
core: &yggdrasilcore.Core{},
|
||||
config: yggdrasildefaults.GenerateConfig(),
|
||||
multicast: &yggdrasilmulticast.Multicast{},
|
||||
log: gologme.New(os.Stdout, "YGG ", log.Flags()),
|
||||
incoming: make(chan QUICStream),
|
||||
incoming: make(chan net.Conn),
|
||||
}
|
||||
|
||||
yggfile := fmt.Sprintf("%s/%s-yggdrasil.conf", storageDirectory, instanceName)
|
||||
|
@ -93,24 +80,11 @@ func Setup(instanceName, storageDirectory string) (*Node, error) {
|
|||
}
|
||||
}
|
||||
|
||||
n.core.SetCoordChangeCallback(func(old, new yggdrasil.Coords) {
|
||||
fmt.Println("COORDINATE CHANGE!")
|
||||
fmt.Println("Old:", old)
|
||||
fmt.Println("New:", new)
|
||||
n.sessions.Range(func(k, v interface{}) bool {
|
||||
if s, ok := v.(*session); ok {
|
||||
fmt.Println("Killing session", k)
|
||||
s.kill()
|
||||
}
|
||||
return true
|
||||
})
|
||||
})
|
||||
|
||||
n.config.Peers = []string{}
|
||||
if peerURI != "" {
|
||||
n.config.Peers = append(n.config.Peers, peerURI)
|
||||
}
|
||||
n.config.AdminListen = "none"
|
||||
n.config.MulticastInterfaces = []string{}
|
||||
n.config.EncryptionPrivateKey = hex.EncodeToString(n.EncryptionPrivateKey())
|
||||
n.config.EncryptionPublicKey = hex.EncodeToString(n.EncryptionPublicKey())
|
||||
|
||||
j, err := json.MarshalIndent(n.config, "", " ")
|
||||
if err != nil {
|
||||
|
@ -123,34 +97,22 @@ func Setup(instanceName, storageDirectory string) (*Node, error) {
|
|||
n.log.EnableLevel("error")
|
||||
n.log.EnableLevel("warn")
|
||||
n.log.EnableLevel("info")
|
||||
n.state, err = n.core.Start(n.config, n.log)
|
||||
if err = n.core.Start(n.config, n.log); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
n.utpSocket, err = utp.NewSocketFromPacketConnNoClose(n.core)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err = n.multicast.Init(n.core, n.state, n.log, nil); err != nil {
|
||||
if err = n.multicast.Init(n.core, n.config, n.log, nil); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err = n.multicast.Start(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
n.tlsConfig = n.generateTLSConfig()
|
||||
n.quicConfig = &quic.Config{
|
||||
MaxIncomingStreams: 0,
|
||||
MaxIncomingUniStreams: 0,
|
||||
KeepAlive: true,
|
||||
MaxIdleTimeout: time.Minute * 30,
|
||||
HandshakeTimeout: time.Second * 15,
|
||||
}
|
||||
copy(n.quicConfig.StatelessResetKey, n.EncryptionPublicKey())
|
||||
|
||||
n.log.Println("Public curve25519:", n.core.EncryptionPublicKey())
|
||||
n.log.Println("Public ed25519:", n.core.SigningPublicKey())
|
||||
|
||||
go func() {
|
||||
time.Sleep(time.Second)
|
||||
n.listenFromYgg()
|
||||
}()
|
||||
n.log.Println("Public key:", n.core.PublicKey())
|
||||
go n.listenFromYgg()
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
@ -163,64 +125,33 @@ func (n *Node) Stop() {
|
|||
}
|
||||
|
||||
func (n *Node) DerivedServerName() string {
|
||||
return hex.EncodeToString(n.SigningPublicKey())
|
||||
return hex.EncodeToString(n.PublicKey())
|
||||
}
|
||||
|
||||
func (n *Node) DerivedSessionName() string {
|
||||
return hex.EncodeToString(n.EncryptionPublicKey())
|
||||
func (n *Node) PrivateKey() ed25519.PrivateKey {
|
||||
sk := make(ed25519.PrivateKey, ed25519.PrivateKeySize)
|
||||
sb, err := hex.DecodeString(n.config.PrivateKey)
|
||||
if err == nil {
|
||||
copy(sk, sb[:])
|
||||
} else {
|
||||
panic(err)
|
||||
}
|
||||
return sk
|
||||
}
|
||||
|
||||
func (n *Node) EncryptionPublicKey() []byte {
|
||||
edkey := n.SigningPublicKey()
|
||||
return convert.Ed25519PublicKeyToCurve25519(edkey)
|
||||
}
|
||||
|
||||
func (n *Node) EncryptionPrivateKey() []byte {
|
||||
edkey := n.SigningPrivateKey()
|
||||
return convert.Ed25519PrivateKeyToCurve25519(edkey)
|
||||
}
|
||||
|
||||
func (n *Node) SigningPublicKey() ed25519.PublicKey {
|
||||
pubBytes, _ := hex.DecodeString(n.config.SigningPublicKey)
|
||||
return ed25519.PublicKey(pubBytes)
|
||||
}
|
||||
|
||||
func (n *Node) SigningPrivateKey() ed25519.PrivateKey {
|
||||
privBytes, _ := hex.DecodeString(n.config.SigningPrivateKey)
|
||||
return ed25519.PrivateKey(privBytes)
|
||||
}
|
||||
|
||||
func (n *Node) SetSessionFunc(f func(address string)) {
|
||||
n.sessionFunc = f
|
||||
func (n *Node) PublicKey() ed25519.PublicKey {
|
||||
return n.core.PublicKey()
|
||||
}
|
||||
|
||||
func (n *Node) PeerCount() int {
|
||||
return len(n.core.GetPeers()) - 1
|
||||
}
|
||||
|
||||
func (n *Node) SessionCount() int {
|
||||
return int(n.sessionCount.Load())
|
||||
return len(n.core.GetPeers())
|
||||
}
|
||||
|
||||
func (n *Node) KnownNodes() []gomatrixserverlib.ServerName {
|
||||
nodemap := map[string]struct{}{
|
||||
//"b5ae50589e50991dd9dd7d59c5c5f7a4521e8da5b603b7f57076272abc58b374": {},
|
||||
nodemap := map[string]struct{}{}
|
||||
for _, peer := range n.core.GetPeers() {
|
||||
nodemap[hex.EncodeToString(peer.Key)] = struct{}{}
|
||||
}
|
||||
for _, peer := range n.core.GetSwitchPeers() {
|
||||
nodemap[hex.EncodeToString(peer.SigPublicKey[:])] = struct{}{}
|
||||
}
|
||||
n.sessions.Range(func(_, v interface{}) bool {
|
||||
session, ok := v.(quic.Session)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
if len(session.ConnectionState().PeerCertificates) != 1 {
|
||||
return true
|
||||
}
|
||||
subjectName := session.ConnectionState().PeerCertificates[0].Subject.CommonName
|
||||
nodemap[subjectName] = struct{}{}
|
||||
return true
|
||||
})
|
||||
var nodes []gomatrixserverlib.ServerName
|
||||
for node := range nodemap {
|
||||
nodes = append(nodes, gomatrixserverlib.ServerName(node))
|
||||
|
@ -229,53 +160,22 @@ func (n *Node) KnownNodes() []gomatrixserverlib.ServerName {
|
|||
}
|
||||
|
||||
func (n *Node) SetMulticastEnabled(enabled bool) {
|
||||
if enabled {
|
||||
n.config.MulticastInterfaces = []string{".*"}
|
||||
} else {
|
||||
n.config.MulticastInterfaces = []string{}
|
||||
}
|
||||
n.multicast.UpdateConfig(n.config)
|
||||
if !enabled {
|
||||
n.DisconnectMulticastPeers()
|
||||
}
|
||||
// TODO: There's no dynamic reconfiguration in Yggdrasil v0.4
|
||||
// so we need a solution for this.
|
||||
}
|
||||
|
||||
func (n *Node) DisconnectMulticastPeers() {
|
||||
for _, sp := range n.core.GetSwitchPeers() {
|
||||
if !strings.HasPrefix(sp.Endpoint, "fe80") {
|
||||
continue
|
||||
}
|
||||
if err := n.core.DisconnectPeer(sp.Port); err != nil {
|
||||
n.log.Printf("Failed to disconnect port %d: %s", sp.Port, err)
|
||||
}
|
||||
}
|
||||
// TODO: There's no dynamic reconfiguration in Yggdrasil v0.4
|
||||
// so we need a solution for this.
|
||||
}
|
||||
|
||||
func (n *Node) DisconnectNonMulticastPeers() {
|
||||
for _, sp := range n.core.GetSwitchPeers() {
|
||||
if strings.HasPrefix(sp.Endpoint, "fe80") {
|
||||
continue
|
||||
}
|
||||
if err := n.core.DisconnectPeer(sp.Port); err != nil {
|
||||
n.log.Printf("Failed to disconnect port %d: %s", sp.Port, err)
|
||||
}
|
||||
}
|
||||
// TODO: There's no dynamic reconfiguration in Yggdrasil v0.4
|
||||
// so we need a solution for this.
|
||||
}
|
||||
|
||||
func (n *Node) SetStaticPeer(uri string) error {
|
||||
n.config.Peers = []string{}
|
||||
n.core.UpdateConfig(n.config)
|
||||
n.DisconnectNonMulticastPeers()
|
||||
if uri != "" {
|
||||
n.log.Infoln("Adding static peer", uri)
|
||||
if err := n.core.AddPeer(uri, ""); err != nil {
|
||||
n.log.Warnln("Adding static peer failed:", err)
|
||||
return err
|
||||
}
|
||||
if err := n.core.CallPeer(uri, ""); err != nil {
|
||||
n.log.Warnln("Calling static peer failed:", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
// TODO: There's no dynamic reconfiguration in Yggdrasil v0.4
|
||||
// so we need a solution for this.
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -16,94 +16,17 @@ package yggconn
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/hex"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"github.com/yggdrasil-network/yggdrasil-go/src/crypto"
|
||||
"github.com/yggdrasil-network/yggdrasil-go/src/yggdrasil"
|
||||
)
|
||||
|
||||
type session struct {
|
||||
node *Node
|
||||
session quic.Session
|
||||
address string
|
||||
context context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func (n *Node) newSession(sess quic.Session, address string) *session {
|
||||
ctx, cancel := context.WithCancel(context.TODO())
|
||||
return &session{
|
||||
node: n,
|
||||
session: sess,
|
||||
address: address,
|
||||
context: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *session) kill() {
|
||||
s.cancel()
|
||||
}
|
||||
|
||||
func (n *Node) listenFromYgg() {
|
||||
var err error
|
||||
n.listener, err = quic.Listen(
|
||||
n.core, // yggdrasil.PacketConn
|
||||
n.tlsConfig, // TLS config
|
||||
n.quicConfig, // QUIC config
|
||||
)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
for {
|
||||
n.log.Infoln("Waiting to accept QUIC sessions")
|
||||
session, err := n.listener.Accept(context.TODO())
|
||||
conn, err := n.utpSocket.Accept()
|
||||
if err != nil {
|
||||
n.log.Println("n.listener.Accept:", err)
|
||||
n.log.Println("n.utpSocket.Accept:", err)
|
||||
return
|
||||
}
|
||||
if len(session.ConnectionState().PeerCertificates) != 1 {
|
||||
_ = session.CloseWithError(0, "expected a peer certificate")
|
||||
continue
|
||||
}
|
||||
address := session.ConnectionState().PeerCertificates[0].DNSNames[0]
|
||||
n.log.Infoln("Accepted connection from", address)
|
||||
go n.newSession(session, address).listenFromQUIC()
|
||||
go n.sessionFunc(address)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *session) listenFromQUIC() {
|
||||
if existing, ok := s.node.sessions.Load(s.address); ok {
|
||||
if existingSession, ok := existing.(*session); ok {
|
||||
fmt.Println("Killing existing session to replace", s.address)
|
||||
existingSession.kill()
|
||||
}
|
||||
}
|
||||
s.node.sessionCount.Inc()
|
||||
s.node.sessions.Store(s.address, s)
|
||||
defer s.node.sessions.Delete(s.address)
|
||||
defer s.node.sessionCount.Dec()
|
||||
for {
|
||||
st, err := s.session.AcceptStream(s.context)
|
||||
if err != nil {
|
||||
s.node.log.Println("session.AcceptStream:", err)
|
||||
return
|
||||
}
|
||||
s.node.incoming <- QUICStream{st, s.session}
|
||||
n.incoming <- conn
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -129,155 +52,5 @@ func (n *Node) Dial(network, address string) (net.Conn, error) {
|
|||
|
||||
// Implements http.Transport.DialContext
|
||||
func (n *Node) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
s, ok1 := n.sessions.Load(address)
|
||||
session, ok2 := s.(*session)
|
||||
if !ok1 || !ok2 {
|
||||
// First of all, check if we think we know the coords of this
|
||||
// node. If we do then we'll try to dial to it directly. This
|
||||
// will either succeed or fail.
|
||||
if v, ok := n.coords.Load(address); ok {
|
||||
coords, ok := v.(yggdrasil.Coords)
|
||||
if !ok {
|
||||
n.coords.Delete(address)
|
||||
return nil, errors.New("should have found yggdrasil.Coords but didn't")
|
||||
}
|
||||
n.log.Infof("Coords %s for %q cached, trying to dial", coords.String(), address)
|
||||
var err error
|
||||
// We think we know the coords. Try to dial the node.
|
||||
if session, err = n.tryDial(address, coords); err != nil {
|
||||
// We thought we knew the coords but it didn't result
|
||||
// in a successful dial. Nuke them from the cache.
|
||||
n.coords.Delete(address)
|
||||
n.log.Infof("Cached coords %s for %q failed", coords.String(), address)
|
||||
}
|
||||
}
|
||||
|
||||
// We either don't know the coords for the node, or we failed
|
||||
// to dial it before, in which case try to resolve the coords.
|
||||
if _, ok := n.coords.Load(address); !ok {
|
||||
var coords yggdrasil.Coords
|
||||
var err error
|
||||
|
||||
// First look and see if the node is something that we already
|
||||
// know about from our direct switch peers.
|
||||
for _, peer := range n.core.GetSwitchPeers() {
|
||||
if peer.PublicKey.String() == address {
|
||||
coords = peer.Coords
|
||||
n.log.Infof("%q is a direct peer, coords are %s", address, coords.String())
|
||||
n.coords.Store(address, coords)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// If it isn' a node that we know directly then try to search
|
||||
// the network.
|
||||
if coords == nil {
|
||||
n.log.Infof("Searching for coords for %q", address)
|
||||
dest, derr := hex.DecodeString(address)
|
||||
if derr != nil {
|
||||
return nil, derr
|
||||
}
|
||||
if len(dest) != crypto.BoxPubKeyLen {
|
||||
return nil, errors.New("invalid key length supplied")
|
||||
}
|
||||
var pubKey crypto.BoxPubKey
|
||||
copy(pubKey[:], dest)
|
||||
nodeID := crypto.GetNodeID(&pubKey)
|
||||
nodeMask := &crypto.NodeID{}
|
||||
for i := range nodeMask {
|
||||
nodeMask[i] = 0xFF
|
||||
}
|
||||
|
||||
fmt.Println("Resolving coords")
|
||||
coords, err = n.core.Resolve(nodeID, nodeMask)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("n.core.Resolve: %w", err)
|
||||
}
|
||||
fmt.Println("Found coords:", coords)
|
||||
n.coords.Store(address, coords)
|
||||
}
|
||||
|
||||
// We now know the coords in theory. Let's try dialling the
|
||||
// node again.
|
||||
if session, err = n.tryDial(address, coords); err != nil {
|
||||
return nil, fmt.Errorf("n.tryDial: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if session == nil {
|
||||
return nil, fmt.Errorf("should have found session but didn't")
|
||||
}
|
||||
|
||||
st, err := session.session.OpenStream()
|
||||
if err != nil {
|
||||
n.log.Println("session.OpenStream:", err)
|
||||
_ = session.session.CloseWithError(0, "expected to be able to open session")
|
||||
return nil, err
|
||||
}
|
||||
return QUICStream{st, session.session}, nil
|
||||
}
|
||||
|
||||
func (n *Node) tryDial(address string, coords yggdrasil.Coords) (*session, error) {
|
||||
quicSession, err := quic.Dial(
|
||||
n.core, // yggdrasil.PacketConn
|
||||
coords, // dial address
|
||||
address, // dial SNI
|
||||
n.tlsConfig, // TLS config
|
||||
n.quicConfig, // QUIC config
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(quicSession.ConnectionState().PeerCertificates) != 1 {
|
||||
_ = quicSession.CloseWithError(0, "expected a peer certificate")
|
||||
return nil, errors.New("didn't receive a peer certificate")
|
||||
}
|
||||
if len(quicSession.ConnectionState().PeerCertificates[0].DNSNames) != 1 {
|
||||
_ = quicSession.CloseWithError(0, "expected a DNS name")
|
||||
return nil, errors.New("didn't receive a DNS name")
|
||||
}
|
||||
if gotAddress := quicSession.ConnectionState().PeerCertificates[0].DNSNames[0]; address != gotAddress {
|
||||
_ = quicSession.CloseWithError(0, "you aren't the host I was hoping for")
|
||||
return nil, fmt.Errorf("expected %q but dialled %q", address, gotAddress)
|
||||
}
|
||||
session := n.newSession(quicSession, address)
|
||||
go session.listenFromQUIC()
|
||||
go n.sessionFunc(address)
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (n *Node) generateTLSConfig() *tls.Config {
|
||||
key, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
template := x509.Certificate{
|
||||
Subject: pkix.Name{
|
||||
CommonName: n.DerivedServerName(),
|
||||
},
|
||||
SerialNumber: big.NewInt(1),
|
||||
NotAfter: time.Now().Add(time.Hour * 24 * 365),
|
||||
DNSNames: []string{n.DerivedSessionName()},
|
||||
}
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})
|
||||
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
|
||||
|
||||
tlsCert, err := tls.X509KeyPair(certPEM, keyPEM)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return &tls.Config{
|
||||
Certificates: []tls.Certificate{tlsCert},
|
||||
NextProtos: []string{"quic-matrix-ygg"},
|
||||
InsecureSkipVerify: true,
|
||||
ClientAuth: tls.RequireAnyClientCert,
|
||||
GetClientCertificate: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
||||
return &tlsCert, nil
|
||||
},
|
||||
}
|
||||
return n.utpSocket.DialContext(ctx, network, address)
|
||||
}
|
||||
|
|
|
@ -1,20 +0,0 @@
|
|||
package yggconn
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
)
|
||||
|
||||
type QUICStream struct {
|
||||
quic.Stream
|
||||
session quic.Session
|
||||
}
|
||||
|
||||
func (s QUICStream) LocalAddr() net.Addr {
|
||||
return s.session.LocalAddr()
|
||||
}
|
||||
|
||||
func (s QUICStream) RemoteAddr() net.Addr {
|
||||
return s.session.RemoteAddr()
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue