mirror of
https://github.com/hoernschen/dendrite.git
synced 2024-12-27 23:48:27 +00:00
Merge branch 'master' into matthew/peeking
This commit is contained in:
commit
8712ea337a
79 changed files with 4073 additions and 1928 deletions
9
build.sh
9
build.sh
|
@ -3,6 +3,11 @@
|
||||||
# Put installed packages into ./bin
|
# Put installed packages into ./bin
|
||||||
export GOBIN=$PWD/`dirname $0`/bin
|
export GOBIN=$PWD/`dirname $0`/bin
|
||||||
|
|
||||||
go install -v $PWD/`dirname $0`/cmd/...
|
export BRANCH=`(git symbolic-ref --short HEAD | cut -d'/' -f 3 )|| ""`
|
||||||
|
export BUILD=`git rev-parse --short HEAD || ""`
|
||||||
|
|
||||||
GOOS=js GOARCH=wasm go build -o main.wasm ./cmd/dendritejs
|
export FLAGS="-X github.com/matrix-org/dendrite/internal.branch=$BRANCH -X github.com/matrix-org/dendrite/internal.build=$BUILD"
|
||||||
|
|
||||||
|
go install -trimpath -ldflags "$FLAGS" -v $PWD/`dirname $0`/cmd/...
|
||||||
|
|
||||||
|
GOOS=js GOARCH=wasm go build -trimpath -ldflags "$FLAGS" -o main.wasm ./cmd/dendritejs
|
||||||
|
|
|
@ -120,7 +120,7 @@ func (m *DendriteMonolith) Start() {
|
||||||
keyAPI.SetUserAPI(userAPI)
|
keyAPI.SetUserAPI(userAPI)
|
||||||
|
|
||||||
rsAPI := roomserver.NewInternalAPI(
|
rsAPI := roomserver.NewInternalAPI(
|
||||||
base, keyRing, federation,
|
base, keyRing,
|
||||||
)
|
)
|
||||||
|
|
||||||
eduInputAPI := eduserver.NewInternalAPI(
|
eduInputAPI := eduserver.NewInternalAPI(
|
||||||
|
|
|
@ -12,8 +12,7 @@ COPY . .
|
||||||
RUN go build ./cmd/dendrite-monolith-server
|
RUN go build ./cmd/dendrite-monolith-server
|
||||||
RUN go build ./cmd/generate-keys
|
RUN go build ./cmd/generate-keys
|
||||||
RUN go build ./cmd/generate-config
|
RUN go build ./cmd/generate-config
|
||||||
RUN ./generate-config > dendrite.yaml
|
RUN ./generate-config --ci > dendrite.yaml
|
||||||
RUN sed -i "s/disable_tls_validation: false/disable_tls_validation: true/g" dendrite.yaml
|
|
||||||
RUN ./generate-keys --private-key matrix_key.pem --tls-cert server.crt --tls-key server.key
|
RUN ./generate-keys --private-key matrix_key.pem --tls-cert server.crt --tls-key server.key
|
||||||
|
|
||||||
ENV SERVER_NAME=localhost
|
ENV SERVER_NAME=localhost
|
||||||
|
|
|
@ -10,10 +10,10 @@ cd `dirname $0`/../..
|
||||||
docker build -t complement-dendrite -f build/scripts/Complement.Dockerfile .
|
docker build -t complement-dendrite -f build/scripts/Complement.Dockerfile .
|
||||||
|
|
||||||
# Download Complement
|
# Download Complement
|
||||||
wget https://github.com/matrix-org/complement/archive/master.tar.gz
|
wget -N https://github.com/matrix-org/complement/archive/master.tar.gz
|
||||||
tar -xzf master.tar.gz
|
tar -xzf master.tar.gz
|
||||||
|
|
||||||
# Run the tests!
|
# Run the tests!
|
||||||
cd complement-master
|
cd complement-master
|
||||||
COMPLEMENT_BASE_IMAGE=complement-dendrite:latest go test -v ./tests
|
COMPLEMENT_BASE_IMAGE=complement-dendrite:latest go test -v -count=1 ./tests
|
||||||
|
|
||||||
|
|
|
@ -342,8 +342,7 @@ func createRoom(
|
||||||
}
|
}
|
||||||
|
|
||||||
// send events to the room server
|
// send events to the room server
|
||||||
_, err = roomserverAPI.SendEvents(req.Context(), rsAPI, builtEvents, cfg.Matrix.ServerName, nil)
|
if err = roomserverAPI.SendEvents(req.Context(), rsAPI, builtEvents, cfg.Matrix.ServerName, nil); err != nil {
|
||||||
if err != nil {
|
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed")
|
util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
|
|
|
@ -41,15 +41,13 @@ type flows struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type flow struct {
|
type flow struct {
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
Stages []string `json:"stages"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func passwordLogin() flows {
|
func passwordLogin() flows {
|
||||||
f := flows{}
|
f := flows{}
|
||||||
s := flow{
|
s := flow{
|
||||||
Type: "m.login.password",
|
Type: "m.login.password",
|
||||||
Stages: []string{"m.login.password"},
|
|
||||||
}
|
}
|
||||||
f.Flows = append(f.Flows, s)
|
f.Flows = append(f.Flows, s)
|
||||||
return f
|
return f
|
||||||
|
|
|
@ -75,13 +75,12 @@ func sendMembership(ctx context.Context, accountDB accounts.Database, device *us
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = roomserverAPI.SendEvents(
|
if err = roomserverAPI.SendEvents(
|
||||||
ctx, rsAPI,
|
ctx, rsAPI,
|
||||||
[]gomatrixserverlib.HeaderedEvent{event.Event.Headered(roomVer)},
|
[]gomatrixserverlib.HeaderedEvent{event.Event.Headered(roomVer)},
|
||||||
cfg.Matrix.ServerName,
|
cfg.Matrix.ServerName,
|
||||||
nil,
|
nil,
|
||||||
)
|
); err != nil {
|
||||||
if err != nil {
|
|
||||||
util.GetLogger(ctx).WithError(err).Error("SendEvents failed")
|
util.GetLogger(ctx).WithError(err).Error("SendEvents failed")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
|
@ -270,7 +269,7 @@ func buildMembershipEvent(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return eventutil.BuildEvent(ctx, &builder, cfg.Matrix, evTime, rsAPI, nil)
|
return eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, evTime, rsAPI, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// loadProfile lookups the profile of a given user from the database and returns
|
// loadProfile lookups the profile of a given user from the database and returns
|
||||||
|
|
|
@ -171,7 +171,7 @@ func SetAvatarURL(
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := api.SendEvents(req.Context(), rsAPI, events, cfg.Matrix.ServerName, nil); err != nil {
|
if err := api.SendEvents(req.Context(), rsAPI, events, cfg.Matrix.ServerName, nil); err != nil {
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed")
|
util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
|
@ -289,7 +289,7 @@ func SetDisplayName(
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := api.SendEvents(req.Context(), rsAPI, events, cfg.Matrix.ServerName, nil); err != nil {
|
if err := api.SendEvents(req.Context(), rsAPI, events, cfg.Matrix.ServerName, nil); err != nil {
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed")
|
util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
|
@ -375,7 +375,7 @@ func buildMembershipEvents(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
event, err := eventutil.BuildEvent(ctx, &builder, cfg.Matrix, evTime, rsAPI, nil)
|
event, err := eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, evTime, rsAPI, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
99
clientapi/routing/rate_limiting.go
Normal file
99
clientapi/routing/rate_limiting.go
Normal file
|
@ -0,0 +1,99 @@
|
||||||
|
package routing
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||||
|
"github.com/matrix-org/dendrite/internal/config"
|
||||||
|
"github.com/matrix-org/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
type rateLimits struct {
|
||||||
|
limits map[string]chan struct{}
|
||||||
|
limitsMutex sync.RWMutex
|
||||||
|
enabled bool
|
||||||
|
requestThreshold int64
|
||||||
|
cooloffDuration time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRateLimits(cfg *config.RateLimiting) *rateLimits {
|
||||||
|
l := &rateLimits{
|
||||||
|
limits: make(map[string]chan struct{}),
|
||||||
|
enabled: cfg.Enabled,
|
||||||
|
requestThreshold: cfg.Threshold,
|
||||||
|
cooloffDuration: time.Duration(cfg.CooloffMS) * time.Millisecond,
|
||||||
|
}
|
||||||
|
if l.enabled {
|
||||||
|
go l.clean()
|
||||||
|
}
|
||||||
|
return l
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *rateLimits) clean() {
|
||||||
|
for {
|
||||||
|
// On a 30 second interval, we'll take an exclusive write
|
||||||
|
// lock of the entire map and see if any of the channels are
|
||||||
|
// empty. If they are then we will close and delete them,
|
||||||
|
// freeing up memory.
|
||||||
|
time.Sleep(time.Second * 30)
|
||||||
|
l.limitsMutex.Lock()
|
||||||
|
for k, c := range l.limits {
|
||||||
|
if len(c) == 0 {
|
||||||
|
close(c)
|
||||||
|
delete(l.limits, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
l.limitsMutex.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *rateLimits) rateLimit(req *http.Request) *util.JSONResponse {
|
||||||
|
// If rate limiting is disabled then do nothing.
|
||||||
|
if !l.enabled {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Lock the map long enough to check for rate limiting. We hold it
|
||||||
|
// for longer here than we really need to but it makes sure that we
|
||||||
|
// also don't conflict with the cleaner goroutine which might clean
|
||||||
|
// up a channel after we have retrieved it otherwise.
|
||||||
|
l.limitsMutex.RLock()
|
||||||
|
defer l.limitsMutex.RUnlock()
|
||||||
|
|
||||||
|
// First of all, work out if X-Forwarded-For was sent to us. If not
|
||||||
|
// then we'll just use the IP address of the caller.
|
||||||
|
caller := req.RemoteAddr
|
||||||
|
if forwardedFor := req.Header.Get("X-Forwarded-For"); forwardedFor != "" {
|
||||||
|
caller = forwardedFor
|
||||||
|
}
|
||||||
|
|
||||||
|
// Look up the caller's channel, if they have one. If they don't then
|
||||||
|
// let's create one.
|
||||||
|
rateLimit, ok := l.limits[caller]
|
||||||
|
if !ok {
|
||||||
|
l.limits[caller] = make(chan struct{}, l.requestThreshold)
|
||||||
|
rateLimit = l.limits[caller]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the user has got free resource slots for this request.
|
||||||
|
// If they don't then we'll return an error.
|
||||||
|
select {
|
||||||
|
case rateLimit <- struct{}{}:
|
||||||
|
default:
|
||||||
|
// We hit the rate limit. Tell the client to back off.
|
||||||
|
return &util.JSONResponse{
|
||||||
|
Code: http.StatusTooManyRequests,
|
||||||
|
JSON: jsonerror.LimitExceeded("You are sending too many requests too quickly!", l.cooloffDuration.Milliseconds()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// After the time interval, drain a resource from the rate limiting
|
||||||
|
// channel. This will free up space in the channel for new requests.
|
||||||
|
go func() {
|
||||||
|
<-time.After(l.cooloffDuration)
|
||||||
|
<-rateLimit
|
||||||
|
}()
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -115,15 +115,14 @@ func SendRedaction(
|
||||||
}
|
}
|
||||||
|
|
||||||
var queryRes api.QueryLatestEventsAndStateResponse
|
var queryRes api.QueryLatestEventsAndStateResponse
|
||||||
e, err := eventutil.BuildEvent(req.Context(), &builder, cfg.Matrix, time.Now(), rsAPI, &queryRes)
|
e, err := eventutil.QueryAndBuildEvent(req.Context(), &builder, cfg.Matrix, time.Now(), rsAPI, &queryRes)
|
||||||
if err == eventutil.ErrRoomNoExists {
|
if err == eventutil.ErrRoomNoExists {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusNotFound,
|
Code: http.StatusNotFound,
|
||||||
JSON: jsonerror.NotFound("Room does not exist"),
|
JSON: jsonerror.NotFound("Room does not exist"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_, err = roomserverAPI.SendEvents(context.Background(), rsAPI, []gomatrixserverlib.HeaderedEvent{*e}, cfg.Matrix.ServerName, nil)
|
if err = roomserverAPI.SendEvents(context.Background(), rsAPI, []gomatrixserverlib.HeaderedEvent{*e}, cfg.Matrix.ServerName, nil); err != nil {
|
||||||
if err != nil {
|
|
||||||
util.GetLogger(req.Context()).WithError(err).Errorf("failed to SendEvents")
|
util.GetLogger(req.Context()).WithError(err).Errorf("failed to SendEvents")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
|
|
|
@ -60,6 +60,7 @@ func Setup(
|
||||||
keyAPI keyserverAPI.KeyInternalAPI,
|
keyAPI keyserverAPI.KeyInternalAPI,
|
||||||
extRoomsProvider api.ExtraPublicRoomsProvider,
|
extRoomsProvider api.ExtraPublicRoomsProvider,
|
||||||
) {
|
) {
|
||||||
|
rateLimits := newRateLimits(&cfg.RateLimiting)
|
||||||
userInteractiveAuth := auth.NewUserInteractive(accountDB.GetAccountByPassword, cfg)
|
userInteractiveAuth := auth.NewUserInteractive(accountDB.GetAccountByPassword, cfg)
|
||||||
|
|
||||||
publicAPIMux.Handle("/versions",
|
publicAPIMux.Handle("/versions",
|
||||||
|
@ -92,6 +93,9 @@ func Setup(
|
||||||
).Methods(http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
r0mux.Handle("/join/{roomIDOrAlias}",
|
r0mux.Handle("/join/{roomIDOrAlias}",
|
||||||
httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
if r := rateLimits.rateLimit(req); r != nil {
|
||||||
|
return *r
|
||||||
|
}
|
||||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
|
@ -119,6 +123,9 @@ func Setup(
|
||||||
).Methods(http.MethodGet, http.MethodOptions)
|
).Methods(http.MethodGet, http.MethodOptions)
|
||||||
r0mux.Handle("/rooms/{roomID}/join",
|
r0mux.Handle("/rooms/{roomID}/join",
|
||||||
httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
if r := rateLimits.rateLimit(req); r != nil {
|
||||||
|
return *r
|
||||||
|
}
|
||||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
|
@ -130,6 +137,9 @@ func Setup(
|
||||||
).Methods(http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
r0mux.Handle("/rooms/{roomID}/leave",
|
r0mux.Handle("/rooms/{roomID}/leave",
|
||||||
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
if r := rateLimits.rateLimit(req); r != nil {
|
||||||
|
return *r
|
||||||
|
}
|
||||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
|
@ -150,6 +160,9 @@ func Setup(
|
||||||
).Methods(http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodPost, http.MethodOptions)
|
||||||
r0mux.Handle("/rooms/{roomID}/invite",
|
r0mux.Handle("/rooms/{roomID}/invite",
|
||||||
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
if r := rateLimits.rateLimit(req); r != nil {
|
||||||
|
return *r
|
||||||
|
}
|
||||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
|
@ -264,14 +277,23 @@ func Setup(
|
||||||
).Methods(http.MethodPut, http.MethodOptions)
|
).Methods(http.MethodPut, http.MethodOptions)
|
||||||
|
|
||||||
r0mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse {
|
r0mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse {
|
||||||
|
if r := rateLimits.rateLimit(req); r != nil {
|
||||||
|
return *r
|
||||||
|
}
|
||||||
return Register(req, userAPI, accountDB, cfg)
|
return Register(req, userAPI, accountDB, cfg)
|
||||||
})).Methods(http.MethodPost, http.MethodOptions)
|
})).Methods(http.MethodPost, http.MethodOptions)
|
||||||
|
|
||||||
v1mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse {
|
v1mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse {
|
||||||
|
if r := rateLimits.rateLimit(req); r != nil {
|
||||||
|
return *r
|
||||||
|
}
|
||||||
return LegacyRegister(req, userAPI, cfg)
|
return LegacyRegister(req, userAPI, cfg)
|
||||||
})).Methods(http.MethodPost, http.MethodOptions)
|
})).Methods(http.MethodPost, http.MethodOptions)
|
||||||
|
|
||||||
r0mux.Handle("/register/available", httputil.MakeExternalAPI("registerAvailable", func(req *http.Request) util.JSONResponse {
|
r0mux.Handle("/register/available", httputil.MakeExternalAPI("registerAvailable", func(req *http.Request) util.JSONResponse {
|
||||||
|
if r := rateLimits.rateLimit(req); r != nil {
|
||||||
|
return *r
|
||||||
|
}
|
||||||
return RegisterAvailable(req, cfg, accountDB)
|
return RegisterAvailable(req, cfg, accountDB)
|
||||||
})).Methods(http.MethodGet, http.MethodOptions)
|
})).Methods(http.MethodGet, http.MethodOptions)
|
||||||
|
|
||||||
|
@ -343,6 +365,9 @@ func Setup(
|
||||||
|
|
||||||
r0mux.Handle("/rooms/{roomID}/typing/{userID}",
|
r0mux.Handle("/rooms/{roomID}/typing/{userID}",
|
||||||
httputil.MakeAuthAPI("rooms_typing", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI("rooms_typing", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
if r := rateLimits.rateLimit(req); r != nil {
|
||||||
|
return *r
|
||||||
|
}
|
||||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
|
@ -396,6 +421,9 @@ func Setup(
|
||||||
|
|
||||||
r0mux.Handle("/account/whoami",
|
r0mux.Handle("/account/whoami",
|
||||||
httputil.MakeAuthAPI("whoami", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI("whoami", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
if r := rateLimits.rateLimit(req); r != nil {
|
||||||
|
return *r
|
||||||
|
}
|
||||||
return Whoami(req, device)
|
return Whoami(req, device)
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodGet, http.MethodOptions)
|
).Methods(http.MethodGet, http.MethodOptions)
|
||||||
|
@ -404,6 +432,9 @@ func Setup(
|
||||||
|
|
||||||
r0mux.Handle("/login",
|
r0mux.Handle("/login",
|
||||||
httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse {
|
httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse {
|
||||||
|
if r := rateLimits.rateLimit(req); r != nil {
|
||||||
|
return *r
|
||||||
|
}
|
||||||
return Login(req, accountDB, userAPI, cfg)
|
return Login(req, accountDB, userAPI, cfg)
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
|
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
|
||||||
|
@ -458,6 +489,9 @@ func Setup(
|
||||||
|
|
||||||
r0mux.Handle("/profile/{userID}/avatar_url",
|
r0mux.Handle("/profile/{userID}/avatar_url",
|
||||||
httputil.MakeAuthAPI("profile_avatar_url", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI("profile_avatar_url", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
if r := rateLimits.rateLimit(req); r != nil {
|
||||||
|
return *r
|
||||||
|
}
|
||||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
|
@ -480,6 +514,9 @@ func Setup(
|
||||||
|
|
||||||
r0mux.Handle("/profile/{userID}/displayname",
|
r0mux.Handle("/profile/{userID}/displayname",
|
||||||
httputil.MakeAuthAPI("profile_displayname", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI("profile_displayname", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
if r := rateLimits.rateLimit(req); r != nil {
|
||||||
|
return *r
|
||||||
|
}
|
||||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
|
@ -517,6 +554,9 @@ func Setup(
|
||||||
// Riot logs get flooded unless this is handled
|
// Riot logs get flooded unless this is handled
|
||||||
r0mux.Handle("/presence/{userID}/status",
|
r0mux.Handle("/presence/{userID}/status",
|
||||||
httputil.MakeExternalAPI("presence", func(req *http.Request) util.JSONResponse {
|
httputil.MakeExternalAPI("presence", func(req *http.Request) util.JSONResponse {
|
||||||
|
if r := rateLimits.rateLimit(req); r != nil {
|
||||||
|
return *r
|
||||||
|
}
|
||||||
// TODO: Set presence (probably the responsibility of a presence server not clientapi)
|
// TODO: Set presence (probably the responsibility of a presence server not clientapi)
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusOK,
|
Code: http.StatusOK,
|
||||||
|
@ -527,6 +567,9 @@ func Setup(
|
||||||
|
|
||||||
r0mux.Handle("/voip/turnServer",
|
r0mux.Handle("/voip/turnServer",
|
||||||
httputil.MakeAuthAPI("turn_server", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI("turn_server", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
if r := rateLimits.rateLimit(req); r != nil {
|
||||||
|
return *r
|
||||||
|
}
|
||||||
return RequestTurnServer(req, device, cfg)
|
return RequestTurnServer(req, device, cfg)
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodGet, http.MethodOptions)
|
).Methods(http.MethodGet, http.MethodOptions)
|
||||||
|
@ -593,6 +636,9 @@ func Setup(
|
||||||
|
|
||||||
r0mux.Handle("/user_directory/search",
|
r0mux.Handle("/user_directory/search",
|
||||||
httputil.MakeAuthAPI("userdirectory_search", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI("userdirectory_search", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
if r := rateLimits.rateLimit(req); r != nil {
|
||||||
|
return *r
|
||||||
|
}
|
||||||
postContent := struct {
|
postContent := struct {
|
||||||
SearchString string `json:"search_term"`
|
SearchString string `json:"search_term"`
|
||||||
Limit int `json:"limit"`
|
Limit int `json:"limit"`
|
||||||
|
@ -634,6 +680,9 @@ func Setup(
|
||||||
|
|
||||||
r0mux.Handle("/rooms/{roomID}/read_markers",
|
r0mux.Handle("/rooms/{roomID}/read_markers",
|
||||||
httputil.MakeExternalAPI("rooms_read_markers", func(req *http.Request) util.JSONResponse {
|
httputil.MakeExternalAPI("rooms_read_markers", func(req *http.Request) util.JSONResponse {
|
||||||
|
if r := rateLimits.rateLimit(req); r != nil {
|
||||||
|
return *r
|
||||||
|
}
|
||||||
// TODO: return the read_markers.
|
// TODO: return the read_markers.
|
||||||
return util.JSONResponse{Code: http.StatusOK, JSON: struct{}{}}
|
return util.JSONResponse{Code: http.StatusOK, JSON: struct{}{}}
|
||||||
}),
|
}),
|
||||||
|
@ -732,6 +781,9 @@ func Setup(
|
||||||
|
|
||||||
r0mux.Handle("/capabilities",
|
r0mux.Handle("/capabilities",
|
||||||
httputil.MakeAuthAPI("capabilities", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
httputil.MakeAuthAPI("capabilities", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||||
|
if r := rateLimits.rateLimit(req); r != nil {
|
||||||
|
return *r
|
||||||
|
}
|
||||||
return GetCapabilities(req, rsAPI)
|
return GetCapabilities(req, rsAPI)
|
||||||
}),
|
}),
|
||||||
).Methods(http.MethodGet)
|
).Methods(http.MethodGet)
|
||||||
|
|
|
@ -90,27 +90,26 @@ func SendEvent(
|
||||||
|
|
||||||
// pass the new event to the roomserver and receive the correct event ID
|
// pass the new event to the roomserver and receive the correct event ID
|
||||||
// event ID in case of duplicate transaction is discarded
|
// event ID in case of duplicate transaction is discarded
|
||||||
eventID, err := api.SendEvents(
|
if err := api.SendEvents(
|
||||||
req.Context(), rsAPI,
|
req.Context(), rsAPI,
|
||||||
[]gomatrixserverlib.HeaderedEvent{
|
[]gomatrixserverlib.HeaderedEvent{
|
||||||
e.Headered(verRes.RoomVersion),
|
e.Headered(verRes.RoomVersion),
|
||||||
},
|
},
|
||||||
cfg.Matrix.ServerName,
|
cfg.Matrix.ServerName,
|
||||||
txnAndSessionID,
|
txnAndSessionID,
|
||||||
)
|
); err != nil {
|
||||||
if err != nil {
|
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed")
|
util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
util.GetLogger(req.Context()).WithFields(logrus.Fields{
|
util.GetLogger(req.Context()).WithFields(logrus.Fields{
|
||||||
"event_id": eventID,
|
"event_id": e.EventID(),
|
||||||
"room_id": roomID,
|
"room_id": roomID,
|
||||||
"room_version": verRes.RoomVersion,
|
"room_version": verRes.RoomVersion,
|
||||||
}).Info("Sent event to roomserver")
|
}).Info("Sent event to roomserver")
|
||||||
|
|
||||||
res := util.JSONResponse{
|
res := util.JSONResponse{
|
||||||
Code: http.StatusOK,
|
Code: http.StatusOK,
|
||||||
JSON: sendEventResponse{eventID},
|
JSON: sendEventResponse{e.EventID()},
|
||||||
}
|
}
|
||||||
// Add response to transactionsCache
|
// Add response to transactionsCache
|
||||||
if txnID != nil {
|
if txnID != nil {
|
||||||
|
@ -158,7 +157,7 @@ func generateSendEvent(
|
||||||
}
|
}
|
||||||
|
|
||||||
var queryRes api.QueryLatestEventsAndStateResponse
|
var queryRes api.QueryLatestEventsAndStateResponse
|
||||||
e, err := eventutil.BuildEvent(req.Context(), &builder, cfg.Matrix, evTime, rsAPI, &queryRes)
|
e, err := eventutil.QueryAndBuildEvent(req.Context(), &builder, cfg.Matrix, evTime, rsAPI, &queryRes)
|
||||||
if err == eventutil.ErrRoomNoExists {
|
if err == eventutil.ErrRoomNoExists {
|
||||||
return nil, &util.JSONResponse{
|
return nil, &util.JSONResponse{
|
||||||
Code: http.StatusNotFound,
|
Code: http.StatusNotFound,
|
||||||
|
|
|
@ -354,12 +354,12 @@ func emit3PIDInviteEvent(
|
||||||
}
|
}
|
||||||
|
|
||||||
queryRes := api.QueryLatestEventsAndStateResponse{}
|
queryRes := api.QueryLatestEventsAndStateResponse{}
|
||||||
event, err := eventutil.BuildEvent(ctx, builder, cfg.Matrix, evTime, rsAPI, &queryRes)
|
event, err := eventutil.QueryAndBuildEvent(ctx, builder, cfg.Matrix, evTime, rsAPI, &queryRes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = api.SendEvents(
|
return api.SendEvents(
|
||||||
ctx, rsAPI,
|
ctx, rsAPI,
|
||||||
[]gomatrixserverlib.HeaderedEvent{
|
[]gomatrixserverlib.HeaderedEvent{
|
||||||
(*event).Headered(queryRes.RoomVersion),
|
(*event).Headered(queryRes.RoomVersion),
|
||||||
|
@ -367,5 +367,4 @@ func emit3PIDInviteEvent(
|
||||||
cfg.Matrix.ServerName,
|
cfg.Matrix.ServerName,
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -155,7 +155,7 @@ func main() {
|
||||||
|
|
||||||
stateAPI := currentstateserver.NewInternalAPI(&base.Base.Cfg.CurrentStateServer, base.Base.KafkaConsumer)
|
stateAPI := currentstateserver.NewInternalAPI(&base.Base.Cfg.CurrentStateServer, base.Base.KafkaConsumer)
|
||||||
rsAPI := roomserver.NewInternalAPI(
|
rsAPI := roomserver.NewInternalAPI(
|
||||||
&base.Base, keyRing, federation,
|
&base.Base, keyRing,
|
||||||
)
|
)
|
||||||
eduInputAPI := eduserver.NewInternalAPI(
|
eduInputAPI := eduserver.NewInternalAPI(
|
||||||
&base.Base, cache.New(), userAPI,
|
&base.Base, cache.New(), userAPI,
|
||||||
|
|
|
@ -104,7 +104,7 @@ func main() {
|
||||||
keyAPI.SetUserAPI(userAPI)
|
keyAPI.SetUserAPI(userAPI)
|
||||||
|
|
||||||
rsComponent := roomserver.NewInternalAPI(
|
rsComponent := roomserver.NewInternalAPI(
|
||||||
base, keyRing, federation,
|
base, keyRing,
|
||||||
)
|
)
|
||||||
rsAPI := rsComponent
|
rsAPI := rsComponent
|
||||||
|
|
||||||
|
|
|
@ -81,7 +81,7 @@ func main() {
|
||||||
keyRing := serverKeyAPI.KeyRing()
|
keyRing := serverKeyAPI.KeyRing()
|
||||||
|
|
||||||
rsImpl := roomserver.NewInternalAPI(
|
rsImpl := roomserver.NewInternalAPI(
|
||||||
base, keyRing, federation,
|
base, keyRing,
|
||||||
)
|
)
|
||||||
// call functions directly on the impl unless running in HTTP mode
|
// call functions directly on the impl unless running in HTTP mode
|
||||||
rsAPI := rsImpl
|
rsAPI := rsImpl
|
||||||
|
|
|
@ -23,13 +23,12 @@ func main() {
|
||||||
cfg := setup.ParseFlags(false)
|
cfg := setup.ParseFlags(false)
|
||||||
base := setup.NewBaseDendrite(cfg, "RoomServerAPI", true)
|
base := setup.NewBaseDendrite(cfg, "RoomServerAPI", true)
|
||||||
defer base.Close() // nolint: errcheck
|
defer base.Close() // nolint: errcheck
|
||||||
federation := base.CreateFederationClient()
|
|
||||||
|
|
||||||
serverKeyAPI := base.ServerKeyAPIClient()
|
serverKeyAPI := base.ServerKeyAPIClient()
|
||||||
keyRing := serverKeyAPI.KeyRing()
|
keyRing := serverKeyAPI.KeyRing()
|
||||||
|
|
||||||
fsAPI := base.FederationSenderHTTPClient()
|
fsAPI := base.FederationSenderHTTPClient()
|
||||||
rsAPI := roomserver.NewInternalAPI(base, keyRing, federation)
|
rsAPI := roomserver.NewInternalAPI(base, keyRing)
|
||||||
rsAPI.SetFederationSenderAPI(fsAPI)
|
rsAPI.SetFederationSenderAPI(fsAPI)
|
||||||
roomserver.AddInternalRoutes(base.InternalAPIMux, rsAPI)
|
roomserver.AddInternalRoutes(base.InternalAPIMux, rsAPI)
|
||||||
|
|
||||||
|
|
|
@ -205,7 +205,7 @@ func main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
stateAPI := currentstateserver.NewInternalAPI(&base.Cfg.CurrentStateServer, base.KafkaConsumer)
|
stateAPI := currentstateserver.NewInternalAPI(&base.Cfg.CurrentStateServer, base.KafkaConsumer)
|
||||||
rsAPI := roomserver.NewInternalAPI(base, keyRing, federation)
|
rsAPI := roomserver.NewInternalAPI(base, keyRing)
|
||||||
eduInputAPI := eduserver.NewInternalAPI(base, cache.New(), userAPI)
|
eduInputAPI := eduserver.NewInternalAPI(base, cache.New(), userAPI)
|
||||||
asQuery := appservice.NewInternalAPI(
|
asQuery := appservice.NewInternalAPI(
|
||||||
base, userAPI, rsAPI,
|
base, userAPI, rsAPI,
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/config"
|
"github.com/matrix-org/dendrite/internal/config"
|
||||||
|
@ -8,6 +9,9 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
defaultsForCI := flag.Bool("ci", false, "sane defaults for CI testing")
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
cfg := &config.Dendrite{}
|
cfg := &config.Dendrite{}
|
||||||
cfg.Defaults()
|
cfg.Defaults()
|
||||||
cfg.Global.TrustedIDServers = []string{
|
cfg.Global.TrustedIDServers = []string{
|
||||||
|
@ -56,6 +60,11 @@ func main() {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if *defaultsForCI {
|
||||||
|
cfg.ClientAPI.RateLimiting.Enabled = false
|
||||||
|
cfg.FederationSender.DisableTLSValidation = true
|
||||||
|
}
|
||||||
|
|
||||||
j, err := yaml.Marshal(cfg)
|
j, err := yaml.Marshal(cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
|
|
|
@ -23,17 +23,25 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/currentstateserver/storage"
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type ServerACLDatabase interface {
|
||||||
|
// GetKnownRooms returns a list of all rooms we know about.
|
||||||
|
GetKnownRooms(ctx context.Context) ([]string, error)
|
||||||
|
// GetStateEvent returns the state event of a given type for a given room with a given state key
|
||||||
|
// If no event could be found, returns nil
|
||||||
|
// If there was an issue during the retrieval, returns an error
|
||||||
|
GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error)
|
||||||
|
}
|
||||||
|
|
||||||
type ServerACLs struct {
|
type ServerACLs struct {
|
||||||
acls map[string]*serverACL // room ID -> ACL
|
acls map[string]*serverACL // room ID -> ACL
|
||||||
aclsMutex sync.RWMutex // protects the above
|
aclsMutex sync.RWMutex // protects the above
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewServerACLs(db storage.Database) *ServerACLs {
|
func NewServerACLs(db ServerACLDatabase) *ServerACLs {
|
||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
acls := &ServerACLs{
|
acls := &ServerACLs{
|
||||||
acls: make(map[string]*serverACL),
|
acls: make(map[string]*serverACL),
|
||||||
|
|
|
@ -133,6 +133,14 @@ client_api:
|
||||||
turn_username: ""
|
turn_username: ""
|
||||||
turn_password: ""
|
turn_password: ""
|
||||||
|
|
||||||
|
# Settings for rate-limited endpoints. Rate limiting will kick in after the
|
||||||
|
# threshold number of "slots" have been taken by requests from a specific
|
||||||
|
# host. Each "slot" will be released after the cooloff time in milliseconds.
|
||||||
|
rate_limiting:
|
||||||
|
enabled: true
|
||||||
|
threshold: 5
|
||||||
|
cooloff_ms: 500
|
||||||
|
|
||||||
# Configuration for the Current State Server.
|
# Configuration for the Current State Server.
|
||||||
current_state_server:
|
current_state_server:
|
||||||
internal_api:
|
internal_api:
|
||||||
|
|
|
@ -95,7 +95,7 @@ func MakeJoin(
|
||||||
queryRes := api.QueryLatestEventsAndStateResponse{
|
queryRes := api.QueryLatestEventsAndStateResponse{
|
||||||
RoomVersion: verRes.RoomVersion,
|
RoomVersion: verRes.RoomVersion,
|
||||||
}
|
}
|
||||||
event, err := eventutil.BuildEvent(httpReq.Context(), &builder, cfg.Matrix, time.Now(), rsAPI, &queryRes)
|
event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), &builder, cfg.Matrix, time.Now(), rsAPI, &queryRes)
|
||||||
if err == eventutil.ErrRoomNoExists {
|
if err == eventutil.ErrRoomNoExists {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusNotFound,
|
Code: http.StatusNotFound,
|
||||||
|
@ -266,15 +266,14 @@ func SendJoin(
|
||||||
// We are responsible for notifying other servers that the user has joined
|
// We are responsible for notifying other servers that the user has joined
|
||||||
// the room, so set SendAsServer to cfg.Matrix.ServerName
|
// the room, so set SendAsServer to cfg.Matrix.ServerName
|
||||||
if !alreadyJoined {
|
if !alreadyJoined {
|
||||||
_, err = api.SendEvents(
|
if err = api.SendEvents(
|
||||||
httpReq.Context(), rsAPI,
|
httpReq.Context(), rsAPI,
|
||||||
[]gomatrixserverlib.HeaderedEvent{
|
[]gomatrixserverlib.HeaderedEvent{
|
||||||
event.Headered(stateAndAuthChainResponse.RoomVersion),
|
event.Headered(stateAndAuthChainResponse.RoomVersion),
|
||||||
},
|
},
|
||||||
cfg.Matrix.ServerName,
|
cfg.Matrix.ServerName,
|
||||||
nil,
|
nil,
|
||||||
)
|
); err != nil {
|
||||||
if err != nil {
|
|
||||||
util.GetLogger(httpReq.Context()).WithError(err).Error("SendEvents failed")
|
util.GetLogger(httpReq.Context()).WithError(err).Error("SendEvents failed")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
|
|
|
@ -61,7 +61,7 @@ func MakeLeave(
|
||||||
}
|
}
|
||||||
|
|
||||||
var queryRes api.QueryLatestEventsAndStateResponse
|
var queryRes api.QueryLatestEventsAndStateResponse
|
||||||
event, err := eventutil.BuildEvent(httpReq.Context(), &builder, cfg.Matrix, time.Now(), rsAPI, &queryRes)
|
event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), &builder, cfg.Matrix, time.Now(), rsAPI, &queryRes)
|
||||||
if err == eventutil.ErrRoomNoExists {
|
if err == eventutil.ErrRoomNoExists {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusNotFound,
|
Code: http.StatusNotFound,
|
||||||
|
@ -247,15 +247,14 @@ func SendLeave(
|
||||||
// Send the events to the room server.
|
// Send the events to the room server.
|
||||||
// We are responsible for notifying other servers that the user has left
|
// We are responsible for notifying other servers that the user has left
|
||||||
// the room, so set SendAsServer to cfg.Matrix.ServerName
|
// the room, so set SendAsServer to cfg.Matrix.ServerName
|
||||||
_, err = api.SendEvents(
|
if err = api.SendEvents(
|
||||||
httpReq.Context(), rsAPI,
|
httpReq.Context(), rsAPI,
|
||||||
[]gomatrixserverlib.HeaderedEvent{
|
[]gomatrixserverlib.HeaderedEvent{
|
||||||
event.Headered(verRes.RoomVersion),
|
event.Headered(verRes.RoomVersion),
|
||||||
},
|
},
|
||||||
cfg.Matrix.ServerName,
|
cfg.Matrix.ServerName,
|
||||||
nil,
|
nil,
|
||||||
)
|
); err != nil {
|
||||||
if err != nil {
|
|
||||||
util.GetLogger(httpReq.Context()).WithError(err).Error("producer.SendEvents failed")
|
util.GetLogger(httpReq.Context()).WithError(err).Error("producer.SendEvents failed")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
|
|
|
@ -382,7 +382,7 @@ func (t *txnReq) processEvent(e gomatrixserverlib.Event, isInboundTxn bool) erro
|
||||||
}
|
}
|
||||||
|
|
||||||
// pass the event to the roomserver
|
// pass the event to the roomserver
|
||||||
_, err := api.SendEvents(
|
return api.SendEvents(
|
||||||
t.context, t.rsAPI,
|
t.context, t.rsAPI,
|
||||||
[]gomatrixserverlib.HeaderedEvent{
|
[]gomatrixserverlib.HeaderedEvent{
|
||||||
e.Headered(stateResp.RoomVersion),
|
e.Headered(stateResp.RoomVersion),
|
||||||
|
@ -390,7 +390,6 @@ func (t *txnReq) processEvent(e gomatrixserverlib.Event, isInboundTxn bool) erro
|
||||||
api.DoNotSendToOtherServers,
|
api.DoNotSendToOtherServers,
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkAllowedByState(e gomatrixserverlib.Event, stateEvents []gomatrixserverlib.Event) error {
|
func checkAllowedByState(e gomatrixserverlib.Event, stateEvents []gomatrixserverlib.Event) error {
|
||||||
|
|
|
@ -296,6 +296,30 @@ func (t *testRoomserverAPI) RemoveRoomAlias(
|
||||||
return fmt.Errorf("not implemented")
|
return fmt.Errorf("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *testRoomserverAPI) QueryCurrentState(ctx context.Context, req *api.QueryCurrentStateRequest, res *api.QueryCurrentStateResponse) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *testRoomserverAPI) QueryRoomsForUser(ctx context.Context, req *api.QueryRoomsForUserRequest, res *api.QueryRoomsForUserResponse) error {
|
||||||
|
return fmt.Errorf("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *testRoomserverAPI) QueryBulkStateContent(ctx context.Context, req *api.QueryBulkStateContentRequest, res *api.QueryBulkStateContentResponse) error {
|
||||||
|
return fmt.Errorf("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *testRoomserverAPI) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse) error {
|
||||||
|
return fmt.Errorf("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *testRoomserverAPI) QueryKnownUsers(ctx context.Context, req *api.QueryKnownUsersRequest, res *api.QueryKnownUsersResponse) error {
|
||||||
|
return fmt.Errorf("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *testRoomserverAPI) QueryServerBannedFromRoom(ctx context.Context, req *api.QueryServerBannedFromRoomRequest, res *api.QueryServerBannedFromRoomResponse) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
type testStateAPI struct {
|
type testStateAPI struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -89,7 +89,7 @@ func CreateInvitesFrom3PIDInvites(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send all the events
|
// Send all the events
|
||||||
if _, err := api.SendEvents(req.Context(), rsAPI, evs, cfg.Matrix.ServerName, nil); err != nil {
|
if err := api.SendEvents(req.Context(), rsAPI, evs, cfg.Matrix.ServerName, nil); err != nil {
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed")
|
util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed")
|
||||||
return jsonerror.InternalServerError()
|
return jsonerror.InternalServerError()
|
||||||
}
|
}
|
||||||
|
@ -172,7 +172,7 @@ func ExchangeThirdPartyInvite(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send the event to the roomserver
|
// Send the event to the roomserver
|
||||||
if _, err = api.SendEvents(
|
if err = api.SendEvents(
|
||||||
httpReq.Context(), rsAPI,
|
httpReq.Context(), rsAPI,
|
||||||
[]gomatrixserverlib.HeaderedEvent{
|
[]gomatrixserverlib.HeaderedEvent{
|
||||||
signedEvent.Event.Headered(verRes.RoomVersion),
|
signedEvent.Event.Headered(verRes.RoomVersion),
|
||||||
|
|
|
@ -17,6 +17,7 @@ package routing
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -31,5 +32,13 @@ type server struct {
|
||||||
|
|
||||||
// Version returns the server version
|
// Version returns the server version
|
||||||
func Version() util.JSONResponse {
|
func Version() util.JSONResponse {
|
||||||
return util.JSONResponse{Code: http.StatusOK, JSON: &version{server{"dev", "Dendrite"}}}
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusOK,
|
||||||
|
JSON: &version{
|
||||||
|
server{
|
||||||
|
Name: "Dendrite",
|
||||||
|
Version: internal.VersionString(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,9 +14,12 @@ import (
|
||||||
// implements as proxy calls, with built-in backoff/retries/etc. Errors returned from functions in
|
// implements as proxy calls, with built-in backoff/retries/etc. Errors returned from functions in
|
||||||
// this interface are of type FederationClientError
|
// this interface are of type FederationClientError
|
||||||
type FederationClient interface {
|
type FederationClient interface {
|
||||||
|
gomatrixserverlib.BackfillClient
|
||||||
|
gomatrixserverlib.FederatedStateClient
|
||||||
GetUserDevices(ctx context.Context, s gomatrixserverlib.ServerName, userID string) (res gomatrixserverlib.RespUserDevices, err error)
|
GetUserDevices(ctx context.Context, s gomatrixserverlib.ServerName, userID string) (res gomatrixserverlib.RespUserDevices, err error)
|
||||||
ClaimKeys(ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (res gomatrixserverlib.RespClaimKeys, err error)
|
ClaimKeys(ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (res gomatrixserverlib.RespClaimKeys, err error)
|
||||||
QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error)
|
QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error)
|
||||||
|
GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// FederationClientError is returned from FederationClient methods in the event of a problem.
|
// FederationClientError is returned from FederationClient methods in the event of a problem.
|
||||||
|
|
|
@ -136,3 +136,51 @@ func (a *FederationSenderInternalAPI) QueryKeys(
|
||||||
}
|
}
|
||||||
return ires.(gomatrixserverlib.RespQueryKeys), nil
|
return ires.(gomatrixserverlib.RespQueryKeys), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *FederationSenderInternalAPI) Backfill(
|
||||||
|
ctx context.Context, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string,
|
||||||
|
) (res gomatrixserverlib.Transaction, err error) {
|
||||||
|
ires, err := a.doRequest(s, func() (interface{}, error) {
|
||||||
|
return a.federation.Backfill(ctx, s, roomID, limit, eventIDs)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return gomatrixserverlib.Transaction{}, err
|
||||||
|
}
|
||||||
|
return ires.(gomatrixserverlib.Transaction), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *FederationSenderInternalAPI) LookupState(
|
||||||
|
ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion,
|
||||||
|
) (res gomatrixserverlib.RespState, err error) {
|
||||||
|
ires, err := a.doRequest(s, func() (interface{}, error) {
|
||||||
|
return a.federation.LookupState(ctx, s, roomID, eventID, roomVersion)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return gomatrixserverlib.RespState{}, err
|
||||||
|
}
|
||||||
|
return ires.(gomatrixserverlib.RespState), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *FederationSenderInternalAPI) LookupStateIDs(
|
||||||
|
ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string,
|
||||||
|
) (res gomatrixserverlib.RespStateIDs, err error) {
|
||||||
|
ires, err := a.doRequest(s, func() (interface{}, error) {
|
||||||
|
return a.federation.LookupStateIDs(ctx, s, roomID, eventID)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return gomatrixserverlib.RespStateIDs{}, err
|
||||||
|
}
|
||||||
|
return ires.(gomatrixserverlib.RespStateIDs), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *FederationSenderInternalAPI) GetEvent(
|
||||||
|
ctx context.Context, s gomatrixserverlib.ServerName, eventID string,
|
||||||
|
) (res gomatrixserverlib.Transaction, err error) {
|
||||||
|
ires, err := a.doRequest(s, func() (interface{}, error) {
|
||||||
|
return a.federation.GetEvent(ctx, s, eventID)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return gomatrixserverlib.Transaction{}, err
|
||||||
|
}
|
||||||
|
return ires.(gomatrixserverlib.Transaction), nil
|
||||||
|
}
|
||||||
|
|
|
@ -26,6 +26,10 @@ const (
|
||||||
FederationSenderGetUserDevicesPath = "/federationsender/client/getUserDevices"
|
FederationSenderGetUserDevicesPath = "/federationsender/client/getUserDevices"
|
||||||
FederationSenderClaimKeysPath = "/federationsender/client/claimKeys"
|
FederationSenderClaimKeysPath = "/federationsender/client/claimKeys"
|
||||||
FederationSenderQueryKeysPath = "/federationsender/client/queryKeys"
|
FederationSenderQueryKeysPath = "/federationsender/client/queryKeys"
|
||||||
|
FederationSenderBackfillPath = "/federationsender/client/backfill"
|
||||||
|
FederationSenderLookupStatePath = "/federationsender/client/lookupState"
|
||||||
|
FederationSenderLookupStateIDsPath = "/federationsender/client/lookupStateIDs"
|
||||||
|
FederationSenderGetEventPath = "/federationsender/client/getEvent"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewFederationSenderClient creates a FederationSenderInternalAPI implemented by talking to a HTTP POST API.
|
// NewFederationSenderClient creates a FederationSenderInternalAPI implemented by talking to a HTTP POST API.
|
||||||
|
@ -228,3 +232,129 @@ func (h *httpFederationSenderInternalAPI) QueryKeys(
|
||||||
}
|
}
|
||||||
return *response.Res, nil
|
return *response.Res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type backfill struct {
|
||||||
|
S gomatrixserverlib.ServerName
|
||||||
|
RoomID string
|
||||||
|
Limit int
|
||||||
|
EventIDs []string
|
||||||
|
Res *gomatrixserverlib.Transaction
|
||||||
|
Err *api.FederationClientError
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *httpFederationSenderInternalAPI) Backfill(
|
||||||
|
ctx context.Context, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string,
|
||||||
|
) (gomatrixserverlib.Transaction, error) {
|
||||||
|
span, ctx := opentracing.StartSpanFromContext(ctx, "Backfill")
|
||||||
|
defer span.Finish()
|
||||||
|
|
||||||
|
request := backfill{
|
||||||
|
S: s,
|
||||||
|
RoomID: roomID,
|
||||||
|
Limit: limit,
|
||||||
|
EventIDs: eventIDs,
|
||||||
|
}
|
||||||
|
var response backfill
|
||||||
|
apiURL := h.federationSenderURL + FederationSenderBackfillPath
|
||||||
|
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
|
||||||
|
if err != nil {
|
||||||
|
return gomatrixserverlib.Transaction{}, err
|
||||||
|
}
|
||||||
|
if response.Err != nil {
|
||||||
|
return gomatrixserverlib.Transaction{}, response.Err
|
||||||
|
}
|
||||||
|
return *response.Res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type lookupState struct {
|
||||||
|
S gomatrixserverlib.ServerName
|
||||||
|
RoomID string
|
||||||
|
EventID string
|
||||||
|
RoomVersion gomatrixserverlib.RoomVersion
|
||||||
|
Res *gomatrixserverlib.RespState
|
||||||
|
Err *api.FederationClientError
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *httpFederationSenderInternalAPI) LookupState(
|
||||||
|
ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion,
|
||||||
|
) (gomatrixserverlib.RespState, error) {
|
||||||
|
span, ctx := opentracing.StartSpanFromContext(ctx, "LookupState")
|
||||||
|
defer span.Finish()
|
||||||
|
|
||||||
|
request := lookupState{
|
||||||
|
S: s,
|
||||||
|
RoomID: roomID,
|
||||||
|
EventID: eventID,
|
||||||
|
RoomVersion: roomVersion,
|
||||||
|
}
|
||||||
|
var response lookupState
|
||||||
|
apiURL := h.federationSenderURL + FederationSenderLookupStatePath
|
||||||
|
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
|
||||||
|
if err != nil {
|
||||||
|
return gomatrixserverlib.RespState{}, err
|
||||||
|
}
|
||||||
|
if response.Err != nil {
|
||||||
|
return gomatrixserverlib.RespState{}, response.Err
|
||||||
|
}
|
||||||
|
return *response.Res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type lookupStateIDs struct {
|
||||||
|
S gomatrixserverlib.ServerName
|
||||||
|
RoomID string
|
||||||
|
EventID string
|
||||||
|
Res *gomatrixserverlib.RespStateIDs
|
||||||
|
Err *api.FederationClientError
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *httpFederationSenderInternalAPI) LookupStateIDs(
|
||||||
|
ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string,
|
||||||
|
) (gomatrixserverlib.RespStateIDs, error) {
|
||||||
|
span, ctx := opentracing.StartSpanFromContext(ctx, "LookupStateIDs")
|
||||||
|
defer span.Finish()
|
||||||
|
|
||||||
|
request := lookupStateIDs{
|
||||||
|
S: s,
|
||||||
|
RoomID: roomID,
|
||||||
|
EventID: eventID,
|
||||||
|
}
|
||||||
|
var response lookupStateIDs
|
||||||
|
apiURL := h.federationSenderURL + FederationSenderLookupStateIDsPath
|
||||||
|
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
|
||||||
|
if err != nil {
|
||||||
|
return gomatrixserverlib.RespStateIDs{}, err
|
||||||
|
}
|
||||||
|
if response.Err != nil {
|
||||||
|
return gomatrixserverlib.RespStateIDs{}, response.Err
|
||||||
|
}
|
||||||
|
return *response.Res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type getEvent struct {
|
||||||
|
S gomatrixserverlib.ServerName
|
||||||
|
EventID string
|
||||||
|
Res *gomatrixserverlib.Transaction
|
||||||
|
Err *api.FederationClientError
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *httpFederationSenderInternalAPI) GetEvent(
|
||||||
|
ctx context.Context, s gomatrixserverlib.ServerName, eventID string,
|
||||||
|
) (gomatrixserverlib.Transaction, error) {
|
||||||
|
span, ctx := opentracing.StartSpanFromContext(ctx, "GetEvent")
|
||||||
|
defer span.Finish()
|
||||||
|
|
||||||
|
request := getEvent{
|
||||||
|
S: s,
|
||||||
|
EventID: eventID,
|
||||||
|
}
|
||||||
|
var response getEvent
|
||||||
|
apiURL := h.federationSenderURL + FederationSenderGetEventPath
|
||||||
|
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
|
||||||
|
if err != nil {
|
||||||
|
return gomatrixserverlib.Transaction{}, err
|
||||||
|
}
|
||||||
|
if response.Err != nil {
|
||||||
|
return gomatrixserverlib.Transaction{}, response.Err
|
||||||
|
}
|
||||||
|
return *response.Res, nil
|
||||||
|
}
|
||||||
|
|
|
@ -175,4 +175,92 @@ func AddRoutes(intAPI api.FederationSenderInternalAPI, internalAPIMux *mux.Route
|
||||||
return util.JSONResponse{Code: http.StatusOK, JSON: request}
|
return util.JSONResponse{Code: http.StatusOK, JSON: request}
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
internalAPIMux.Handle(
|
||||||
|
FederationSenderBackfillPath,
|
||||||
|
httputil.MakeInternalAPI("Backfill", func(req *http.Request) util.JSONResponse {
|
||||||
|
var request backfill
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||||
|
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||||
|
}
|
||||||
|
res, err := intAPI.Backfill(req.Context(), request.S, request.RoomID, request.Limit, request.EventIDs)
|
||||||
|
if err != nil {
|
||||||
|
ferr, ok := err.(*api.FederationClientError)
|
||||||
|
if ok {
|
||||||
|
request.Err = ferr
|
||||||
|
} else {
|
||||||
|
request.Err = &api.FederationClientError{
|
||||||
|
Err: err.Error(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
request.Res = &res
|
||||||
|
return util.JSONResponse{Code: http.StatusOK, JSON: request}
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
internalAPIMux.Handle(
|
||||||
|
FederationSenderLookupStatePath,
|
||||||
|
httputil.MakeInternalAPI("LookupState", func(req *http.Request) util.JSONResponse {
|
||||||
|
var request lookupState
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||||
|
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||||
|
}
|
||||||
|
res, err := intAPI.LookupState(req.Context(), request.S, request.RoomID, request.EventID, request.RoomVersion)
|
||||||
|
if err != nil {
|
||||||
|
ferr, ok := err.(*api.FederationClientError)
|
||||||
|
if ok {
|
||||||
|
request.Err = ferr
|
||||||
|
} else {
|
||||||
|
request.Err = &api.FederationClientError{
|
||||||
|
Err: err.Error(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
request.Res = &res
|
||||||
|
return util.JSONResponse{Code: http.StatusOK, JSON: request}
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
internalAPIMux.Handle(
|
||||||
|
FederationSenderLookupStateIDsPath,
|
||||||
|
httputil.MakeInternalAPI("LookupStateIDs", func(req *http.Request) util.JSONResponse {
|
||||||
|
var request lookupStateIDs
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||||
|
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||||
|
}
|
||||||
|
res, err := intAPI.LookupStateIDs(req.Context(), request.S, request.RoomID, request.EventID)
|
||||||
|
if err != nil {
|
||||||
|
ferr, ok := err.(*api.FederationClientError)
|
||||||
|
if ok {
|
||||||
|
request.Err = ferr
|
||||||
|
} else {
|
||||||
|
request.Err = &api.FederationClientError{
|
||||||
|
Err: err.Error(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
request.Res = &res
|
||||||
|
return util.JSONResponse{Code: http.StatusOK, JSON: request}
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
internalAPIMux.Handle(
|
||||||
|
FederationSenderGetEventPath,
|
||||||
|
httputil.MakeInternalAPI("GetEvent", func(req *http.Request) util.JSONResponse {
|
||||||
|
var request getEvent
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||||
|
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||||
|
}
|
||||||
|
res, err := intAPI.GetEvent(req.Context(), request.S, request.EventID)
|
||||||
|
if err != nil {
|
||||||
|
ferr, ok := err.(*api.FederationClientError)
|
||||||
|
if ok {
|
||||||
|
request.Err = ferr
|
||||||
|
} else {
|
||||||
|
request.Err = &api.FederationClientError{
|
||||||
|
Err: err.Error(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
request.Res = &res
|
||||||
|
return util.JSONResponse{Code: http.StatusOK, JSON: request}
|
||||||
|
}),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
2
go.mod
2
go.mod
|
@ -21,7 +21,7 @@ require (
|
||||||
github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4
|
github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4
|
||||||
github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3
|
github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3
|
||||||
github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd
|
github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd
|
||||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20200817100842-9d02141812f2
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20200902135805-f7a5b5e89750
|
||||||
github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91
|
github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91
|
||||||
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
|
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
|
||||||
github.com/mattn/go-sqlite3 v1.14.2
|
github.com/mattn/go-sqlite3 v1.14.2
|
||||||
|
|
4
go.sum
4
go.sum
|
@ -567,8 +567,8 @@ github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 h1:Hr3zjRsq2bh
|
||||||
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
|
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
|
||||||
github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd h1:xVrqJK3xHREMNjwjljkAUaadalWc0rRbmVuQatzmgwg=
|
github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd h1:xVrqJK3xHREMNjwjljkAUaadalWc0rRbmVuQatzmgwg=
|
||||||
github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
|
github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
|
||||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20200817100842-9d02141812f2 h1:9wKwfd5KDcXuqZ7/kAaYe0QM4DGM+2awjjvXQtrDa6k=
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20200902135805-f7a5b5e89750 h1:k5vsLfpylXHOXgN51N0QNbak9i+4bT33Puk/ZJgcdDw=
|
||||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20200817100842-9d02141812f2/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU=
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20200902135805-f7a5b5e89750/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU=
|
||||||
github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91 h1:HJ6U3S3ljJqNffYMcIeAncp5qT/i+ZMiJ2JC2F0aXP4=
|
github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91 h1:HJ6U3S3ljJqNffYMcIeAncp5qT/i+ZMiJ2JC2F0aXP4=
|
||||||
github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91/go.mod h1:sjyPyRxKM5uw1nD2cJ6O2OxI6GOqyVBfNXqKjBZTBZE=
|
github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91/go.mod h1:sjyPyRxKM5uw1nD2cJ6O2OxI6GOqyVBfNXqKjBZTBZE=
|
||||||
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 h1:ntrLa/8xVzeSs8vHFHK25k0C+NV74sYMJnNSg5NoSRo=
|
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 h1:ntrLa/8xVzeSs8vHFHK25k0C+NV74sYMJnNSg5NoSRo=
|
||||||
|
|
|
@ -34,6 +34,9 @@ type ClientAPI struct {
|
||||||
|
|
||||||
// TURN options
|
// TURN options
|
||||||
TURN TURN `yaml:"turn"`
|
TURN TURN `yaml:"turn"`
|
||||||
|
|
||||||
|
// Rate-limiting options
|
||||||
|
RateLimiting RateLimiting `yaml:"rate_limiting"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ClientAPI) Defaults() {
|
func (c *ClientAPI) Defaults() {
|
||||||
|
@ -47,6 +50,7 @@ func (c *ClientAPI) Defaults() {
|
||||||
c.RecaptchaBypassSecret = ""
|
c.RecaptchaBypassSecret = ""
|
||||||
c.RecaptchaSiteVerifyAPI = ""
|
c.RecaptchaSiteVerifyAPI = ""
|
||||||
c.RegistrationDisabled = false
|
c.RegistrationDisabled = false
|
||||||
|
c.RateLimiting.Defaults()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ClientAPI) Verify(configErrs *ConfigErrors, isMonolith bool) {
|
func (c *ClientAPI) Verify(configErrs *ConfigErrors, isMonolith bool) {
|
||||||
|
@ -61,6 +65,7 @@ func (c *ClientAPI) Verify(configErrs *ConfigErrors, isMonolith bool) {
|
||||||
checkNotEmpty(configErrs, "client_api.recaptcha_siteverify_api", string(c.RecaptchaSiteVerifyAPI))
|
checkNotEmpty(configErrs, "client_api.recaptcha_siteverify_api", string(c.RecaptchaSiteVerifyAPI))
|
||||||
}
|
}
|
||||||
c.TURN.Verify(configErrs)
|
c.TURN.Verify(configErrs)
|
||||||
|
c.RateLimiting.Verify(configErrs)
|
||||||
}
|
}
|
||||||
|
|
||||||
type TURN struct {
|
type TURN struct {
|
||||||
|
@ -90,3 +95,29 @@ func (c *TURN) Verify(configErrs *ConfigErrors) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type RateLimiting struct {
|
||||||
|
// Is rate limiting enabled or disabled?
|
||||||
|
Enabled bool `yaml:"enabled"`
|
||||||
|
|
||||||
|
// How many "slots" a user can occupy sending requests to a rate-limited
|
||||||
|
// endpoint before we apply rate-limiting
|
||||||
|
Threshold int64 `yaml:"threshold"`
|
||||||
|
|
||||||
|
// The cooloff period in milliseconds after a request before the "slot"
|
||||||
|
// is freed again
|
||||||
|
CooloffMS int64 `yaml:"cooloff_ms"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RateLimiting) Verify(configErrs *ConfigErrors) {
|
||||||
|
if r.Enabled {
|
||||||
|
checkPositive(configErrs, "client_api.rate_limiting.threshold", r.Threshold)
|
||||||
|
checkPositive(configErrs, "client_api.rate_limiting.cooloff_ms", r.CooloffMS)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RateLimiting) Defaults() {
|
||||||
|
r.Enabled = true
|
||||||
|
r.Threshold = 5
|
||||||
|
r.CooloffMS = 500
|
||||||
|
}
|
||||||
|
|
|
@ -30,13 +30,13 @@ import (
|
||||||
// doesn't exist
|
// doesn't exist
|
||||||
var ErrRoomNoExists = errors.New("Room does not exist")
|
var ErrRoomNoExists = errors.New("Room does not exist")
|
||||||
|
|
||||||
// BuildEvent builds a Matrix event using the event builder and roomserver query
|
// QueryAndBuildEvent builds a Matrix event using the event builder and roomserver query
|
||||||
// API client provided. If also fills roomserver query API response (if provided)
|
// API client provided. If also fills roomserver query API response (if provided)
|
||||||
// in case the function calling FillBuilder needs to use it.
|
// in case the function calling FillBuilder needs to use it.
|
||||||
// Returns ErrRoomNoExists if the state of the room could not be retrieved because
|
// Returns ErrRoomNoExists if the state of the room could not be retrieved because
|
||||||
// the room doesn't exist
|
// the room doesn't exist
|
||||||
// Returns an error if something else went wrong
|
// Returns an error if something else went wrong
|
||||||
func BuildEvent(
|
func QueryAndBuildEvent(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
builder *gomatrixserverlib.EventBuilder, cfg *config.Global, evTime time.Time,
|
builder *gomatrixserverlib.EventBuilder, cfg *config.Global, evTime time.Time,
|
||||||
rsAPI api.RoomserverInternalAPI, queryRes *api.QueryLatestEventsAndStateResponse,
|
rsAPI api.RoomserverInternalAPI, queryRes *api.QueryLatestEventsAndStateResponse,
|
||||||
|
@ -45,11 +45,25 @@ func BuildEvent(
|
||||||
queryRes = &api.QueryLatestEventsAndStateResponse{}
|
queryRes = &api.QueryLatestEventsAndStateResponse{}
|
||||||
}
|
}
|
||||||
|
|
||||||
ver, err := AddPrevEventsToEvent(ctx, builder, rsAPI, queryRes)
|
eventsNeeded, err := queryRequiredEventsForBuilder(ctx, builder, rsAPI, queryRes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// This can pass through a ErrRoomNoExists to the caller
|
// This can pass through a ErrRoomNoExists to the caller
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
return BuildEvent(ctx, builder, cfg, evTime, eventsNeeded, queryRes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildEvent builds a Matrix event from the builder and QueryLatestEventsAndStateResponse
|
||||||
|
// provided.
|
||||||
|
func BuildEvent(
|
||||||
|
ctx context.Context,
|
||||||
|
builder *gomatrixserverlib.EventBuilder, cfg *config.Global, evTime time.Time,
|
||||||
|
eventsNeeded *gomatrixserverlib.StateNeeded, queryRes *api.QueryLatestEventsAndStateResponse,
|
||||||
|
) (*gomatrixserverlib.HeaderedEvent, error) {
|
||||||
|
err := addPrevEventsToEvent(builder, eventsNeeded, queryRes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
event, err := builder.Build(
|
event, err := builder.Build(
|
||||||
evTime, cfg.ServerName, cfg.KeyID,
|
evTime, cfg.ServerName, cfg.KeyID,
|
||||||
|
@ -59,23 +73,23 @@ func BuildEvent(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
h := event.Headered(ver)
|
h := event.Headered(queryRes.RoomVersion)
|
||||||
return &h, nil
|
return &h, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddPrevEventsToEvent fills out the prev_events and auth_events fields in builder
|
// queryRequiredEventsForBuilder queries the roomserver for auth/prev events needed for this builder.
|
||||||
func AddPrevEventsToEvent(
|
func queryRequiredEventsForBuilder(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
builder *gomatrixserverlib.EventBuilder,
|
builder *gomatrixserverlib.EventBuilder,
|
||||||
rsAPI api.RoomserverInternalAPI, queryRes *api.QueryLatestEventsAndStateResponse,
|
rsAPI api.RoomserverInternalAPI, queryRes *api.QueryLatestEventsAndStateResponse,
|
||||||
) (gomatrixserverlib.RoomVersion, error) {
|
) (*gomatrixserverlib.StateNeeded, error) {
|
||||||
eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder)
|
eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("gomatrixserverlib.StateNeededForEventBuilder: %w", err)
|
return nil, fmt.Errorf("gomatrixserverlib.StateNeededForEventBuilder: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(eventsNeeded.Tuples()) == 0 {
|
if len(eventsNeeded.Tuples()) == 0 {
|
||||||
return "", errors.New("expecting state tuples for event builder, got none")
|
return nil, errors.New("expecting state tuples for event builder, got none")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ask the roomserver for information about this room
|
// Ask the roomserver for information about this room
|
||||||
|
@ -83,17 +97,22 @@ func AddPrevEventsToEvent(
|
||||||
RoomID: builder.RoomID,
|
RoomID: builder.RoomID,
|
||||||
StateToFetch: eventsNeeded.Tuples(),
|
StateToFetch: eventsNeeded.Tuples(),
|
||||||
}
|
}
|
||||||
if err = rsAPI.QueryLatestEventsAndState(ctx, &queryReq, queryRes); err != nil {
|
return &eventsNeeded, rsAPI.QueryLatestEventsAndState(ctx, &queryReq, queryRes)
|
||||||
return "", fmt.Errorf("rsAPI.QueryLatestEventsAndState: %w", err)
|
}
|
||||||
}
|
|
||||||
|
|
||||||
|
// addPrevEventsToEvent fills out the prev_events and auth_events fields in builder
|
||||||
|
func addPrevEventsToEvent(
|
||||||
|
builder *gomatrixserverlib.EventBuilder,
|
||||||
|
eventsNeeded *gomatrixserverlib.StateNeeded,
|
||||||
|
queryRes *api.QueryLatestEventsAndStateResponse,
|
||||||
|
) error {
|
||||||
if !queryRes.RoomExists {
|
if !queryRes.RoomExists {
|
||||||
return "", ErrRoomNoExists
|
return ErrRoomNoExists
|
||||||
}
|
}
|
||||||
|
|
||||||
eventFormat, err := queryRes.RoomVersion.EventFormat()
|
eventFormat, err := queryRes.RoomVersion.EventFormat()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("queryRes.RoomVersion.EventFormat: %w", err)
|
return fmt.Errorf("queryRes.RoomVersion.EventFormat: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
builder.Depth = queryRes.Depth
|
builder.Depth = queryRes.Depth
|
||||||
|
@ -103,13 +122,13 @@ func AddPrevEventsToEvent(
|
||||||
for i := range queryRes.StateEvents {
|
for i := range queryRes.StateEvents {
|
||||||
err = authEvents.AddEvent(&queryRes.StateEvents[i].Event)
|
err = authEvents.AddEvent(&queryRes.StateEvents[i].Event)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("authEvents.AddEvent: %w", err)
|
return fmt.Errorf("authEvents.AddEvent: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
refs, err := eventsNeeded.AuthEventReferences(&authEvents)
|
refs, err := eventsNeeded.AuthEventReferences(&authEvents)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("eventsNeeded.AuthEventReferences: %w", err)
|
return fmt.Errorf("eventsNeeded.AuthEventReferences: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
truncAuth, truncPrev := truncateAuthAndPrevEvents(refs, queryRes.LatestEvents)
|
truncAuth, truncPrev := truncateAuthAndPrevEvents(refs, queryRes.LatestEvents)
|
||||||
|
@ -129,7 +148,7 @@ func AddPrevEventsToEvent(
|
||||||
builder.PrevEvents = v2PrevRefs
|
builder.PrevEvents = v2PrevRefs
|
||||||
}
|
}
|
||||||
|
|
||||||
return queryRes.RoomVersion, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// truncateAuthAndPrevEvents limits the number of events we add into
|
// truncateAuthAndPrevEvents limits the number of events we add into
|
||||||
|
|
|
@ -100,6 +100,8 @@ func NewBaseDendrite(cfg *config.Dendrite, componentName string, useHTTPAPIs boo
|
||||||
internal.SetupHookLogging(cfg.Logging, componentName)
|
internal.SetupHookLogging(cfg.Logging, componentName)
|
||||||
internal.SetupPprof()
|
internal.SetupPprof()
|
||||||
|
|
||||||
|
logrus.Infof("Dendrite version %s", internal.VersionString())
|
||||||
|
|
||||||
closer, err := cfg.SetupTracing("Dendrite" + componentName)
|
closer, err := cfg.SetupTracing("Dendrite" + componentName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).Panicf("failed to start opentracing")
|
logrus.WithError(err).Panicf("failed to start opentracing")
|
||||||
|
|
26
internal/version.go
Normal file
26
internal/version.go
Normal file
|
@ -0,0 +1,26 @@
|
||||||
|
package internal
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
// -ldflags "-X github.com/matrix-org/dendrite/internal.branch=master"
|
||||||
|
var branch string
|
||||||
|
|
||||||
|
// -ldflags "-X github.com/matrix-org/dendrite/internal.build=alpha"
|
||||||
|
var build string
|
||||||
|
|
||||||
|
const (
|
||||||
|
VersionMajor = 0
|
||||||
|
VersionMinor = 0
|
||||||
|
VersionPatch = 0
|
||||||
|
)
|
||||||
|
|
||||||
|
func VersionString() string {
|
||||||
|
version := fmt.Sprintf("%d.%d.%d", VersionMajor, VersionMinor, VersionPatch)
|
||||||
|
if branch != "" {
|
||||||
|
version += fmt.Sprintf("-%s", branch)
|
||||||
|
}
|
||||||
|
if build != "" {
|
||||||
|
version += fmt.Sprintf("+%s", build)
|
||||||
|
}
|
||||||
|
return version
|
||||||
|
}
|
164
roomserver/acls/acls.go
Normal file
164
roomserver/acls/acls.go
Normal file
|
@ -0,0 +1,164 @@
|
||||||
|
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package acls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ServerACLDatabase interface {
|
||||||
|
// GetKnownRooms returns a list of all rooms we know about.
|
||||||
|
GetKnownRooms(ctx context.Context) ([]string, error)
|
||||||
|
// GetStateEvent returns the state event of a given type for a given room with a given state key
|
||||||
|
// If no event could be found, returns nil
|
||||||
|
// If there was an issue during the retrieval, returns an error
|
||||||
|
GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type ServerACLs struct {
|
||||||
|
acls map[string]*serverACL // room ID -> ACL
|
||||||
|
aclsMutex sync.RWMutex // protects the above
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewServerACLs(db ServerACLDatabase) *ServerACLs {
|
||||||
|
ctx := context.TODO()
|
||||||
|
acls := &ServerACLs{
|
||||||
|
acls: make(map[string]*serverACL),
|
||||||
|
}
|
||||||
|
// Look up all of the rooms that the current state server knows about.
|
||||||
|
rooms, err := db.GetKnownRooms(ctx)
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).Fatalf("Failed to get known rooms")
|
||||||
|
}
|
||||||
|
// For each room, let's see if we have a server ACL state event. If we
|
||||||
|
// do then we'll process it into memory so that we have the regexes to
|
||||||
|
// hand.
|
||||||
|
for _, room := range rooms {
|
||||||
|
state, err := db.GetStateEvent(ctx, room, "m.room.server_acl", "")
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).Errorf("Failed to get server ACLs for room %q", room)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if state != nil {
|
||||||
|
acls.OnServerACLUpdate(&state.Event)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return acls
|
||||||
|
}
|
||||||
|
|
||||||
|
type ServerACL struct {
|
||||||
|
Allowed []string `json:"allow"`
|
||||||
|
Denied []string `json:"deny"`
|
||||||
|
AllowIPLiterals bool `json:"allow_ip_literals"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type serverACL struct {
|
||||||
|
ServerACL
|
||||||
|
allowedRegexes []*regexp.Regexp
|
||||||
|
deniedRegexes []*regexp.Regexp
|
||||||
|
}
|
||||||
|
|
||||||
|
func compileACLRegex(orig string) (*regexp.Regexp, error) {
|
||||||
|
escaped := regexp.QuoteMeta(orig)
|
||||||
|
escaped = strings.Replace(escaped, "\\?", ".", -1)
|
||||||
|
escaped = strings.Replace(escaped, "\\*", ".*", -1)
|
||||||
|
return regexp.Compile(escaped)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerACLs) OnServerACLUpdate(state *gomatrixserverlib.Event) {
|
||||||
|
acls := &serverACL{}
|
||||||
|
if err := json.Unmarshal(state.Content(), &acls.ServerACL); err != nil {
|
||||||
|
logrus.WithError(err).Errorf("Failed to unmarshal state content for server ACLs")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// The spec calls only for * (zero or more chars) and ? (exactly one char)
|
||||||
|
// to be supported as wildcard components, so we will escape all of the regex
|
||||||
|
// special characters and then replace * and ? with their regex counterparts.
|
||||||
|
// https://matrix.org/docs/spec/client_server/r0.6.1#m-room-server-acl
|
||||||
|
for _, orig := range acls.Allowed {
|
||||||
|
if expr, err := compileACLRegex(orig); err != nil {
|
||||||
|
logrus.WithError(err).Errorf("Failed to compile allowed regex")
|
||||||
|
} else {
|
||||||
|
acls.allowedRegexes = append(acls.allowedRegexes, expr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, orig := range acls.Denied {
|
||||||
|
if expr, err := compileACLRegex(orig); err != nil {
|
||||||
|
logrus.WithError(err).Errorf("Failed to compile denied regex")
|
||||||
|
} else {
|
||||||
|
acls.deniedRegexes = append(acls.deniedRegexes, expr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
logrus.WithFields(logrus.Fields{
|
||||||
|
"allow_ip_literals": acls.AllowIPLiterals,
|
||||||
|
"num_allowed": len(acls.allowedRegexes),
|
||||||
|
"num_denied": len(acls.deniedRegexes),
|
||||||
|
}).Debugf("Updating server ACLs for %q", state.RoomID())
|
||||||
|
s.aclsMutex.Lock()
|
||||||
|
defer s.aclsMutex.Unlock()
|
||||||
|
s.acls[state.RoomID()] = acls
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerACLs) IsServerBannedFromRoom(serverName gomatrixserverlib.ServerName, roomID string) bool {
|
||||||
|
s.aclsMutex.RLock()
|
||||||
|
// First of all check if we have an ACL for this room. If we don't then
|
||||||
|
// no servers are banned from the room.
|
||||||
|
acls, ok := s.acls[roomID]
|
||||||
|
if !ok {
|
||||||
|
s.aclsMutex.RUnlock()
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
s.aclsMutex.RUnlock()
|
||||||
|
// Split the host and port apart. This is because the spec calls on us to
|
||||||
|
// validate the hostname only in cases where the port is also present.
|
||||||
|
if serverNameOnly, _, err := net.SplitHostPort(string(serverName)); err == nil {
|
||||||
|
serverName = gomatrixserverlib.ServerName(serverNameOnly)
|
||||||
|
}
|
||||||
|
// Check if the hostname is an IPv4 or IPv6 literal. We cheat here by adding
|
||||||
|
// a /0 prefix length just to trick ParseCIDR into working. If we find that
|
||||||
|
// the server is an IP literal and we don't allow those then stop straight
|
||||||
|
// away.
|
||||||
|
if _, _, err := net.ParseCIDR(fmt.Sprintf("%s/0", serverName)); err == nil {
|
||||||
|
if !acls.AllowIPLiterals {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Check if the hostname matches one of the denied regexes. If it does then
|
||||||
|
// the server is banned from the room.
|
||||||
|
for _, expr := range acls.deniedRegexes {
|
||||||
|
if expr.MatchString(string(serverName)) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Check if the hostname matches one of the allowed regexes. If it does then
|
||||||
|
// the server is NOT banned from the room.
|
||||||
|
for _, expr := range acls.allowedRegexes {
|
||||||
|
if expr.MatchString(string(serverName)) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// If we've got to this point then we haven't matched any regexes or an IP
|
||||||
|
// hostname if disallowed. The spec calls for default-deny here.
|
||||||
|
return true
|
||||||
|
}
|
105
roomserver/acls/acls_test.go
Normal file
105
roomserver/acls/acls_test.go
Normal file
|
@ -0,0 +1,105 @@
|
||||||
|
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package acls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"regexp"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestOpenACLsWithBlacklist(t *testing.T) {
|
||||||
|
roomID := "!test:test.com"
|
||||||
|
allowRegex, err := compileACLRegex("*")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf(err.Error())
|
||||||
|
}
|
||||||
|
denyRegex, err := compileACLRegex("foo.com")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf(err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
acls := ServerACLs{
|
||||||
|
acls: make(map[string]*serverACL),
|
||||||
|
}
|
||||||
|
|
||||||
|
acls.acls[roomID] = &serverACL{
|
||||||
|
ServerACL: ServerACL{
|
||||||
|
AllowIPLiterals: true,
|
||||||
|
},
|
||||||
|
allowedRegexes: []*regexp.Regexp{allowRegex},
|
||||||
|
deniedRegexes: []*regexp.Regexp{denyRegex},
|
||||||
|
}
|
||||||
|
|
||||||
|
if acls.IsServerBannedFromRoom("1.2.3.4", roomID) {
|
||||||
|
t.Fatal("Expected 1.2.3.4 to be allowed but wasn't")
|
||||||
|
}
|
||||||
|
if acls.IsServerBannedFromRoom("1.2.3.4:2345", roomID) {
|
||||||
|
t.Fatal("Expected 1.2.3.4:2345 to be allowed but wasn't")
|
||||||
|
}
|
||||||
|
if !acls.IsServerBannedFromRoom("foo.com", roomID) {
|
||||||
|
t.Fatal("Expected foo.com to be banned but wasn't")
|
||||||
|
}
|
||||||
|
if !acls.IsServerBannedFromRoom("foo.com:3456", roomID) {
|
||||||
|
t.Fatal("Expected foo.com:3456 to be banned but wasn't")
|
||||||
|
}
|
||||||
|
if acls.IsServerBannedFromRoom("bar.com", roomID) {
|
||||||
|
t.Fatal("Expected bar.com to be allowed but wasn't")
|
||||||
|
}
|
||||||
|
if acls.IsServerBannedFromRoom("bar.com:4567", roomID) {
|
||||||
|
t.Fatal("Expected bar.com:4567 to be allowed but wasn't")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultACLsWithWhitelist(t *testing.T) {
|
||||||
|
roomID := "!test:test.com"
|
||||||
|
allowRegex, err := compileACLRegex("foo.com")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf(err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
acls := ServerACLs{
|
||||||
|
acls: make(map[string]*serverACL),
|
||||||
|
}
|
||||||
|
|
||||||
|
acls.acls[roomID] = &serverACL{
|
||||||
|
ServerACL: ServerACL{
|
||||||
|
AllowIPLiterals: false,
|
||||||
|
},
|
||||||
|
allowedRegexes: []*regexp.Regexp{allowRegex},
|
||||||
|
deniedRegexes: []*regexp.Regexp{},
|
||||||
|
}
|
||||||
|
|
||||||
|
if !acls.IsServerBannedFromRoom("1.2.3.4", roomID) {
|
||||||
|
t.Fatal("Expected 1.2.3.4 to be banned but wasn't")
|
||||||
|
}
|
||||||
|
if !acls.IsServerBannedFromRoom("1.2.3.4:2345", roomID) {
|
||||||
|
t.Fatal("Expected 1.2.3.4:2345 to be banned but wasn't")
|
||||||
|
}
|
||||||
|
if acls.IsServerBannedFromRoom("foo.com", roomID) {
|
||||||
|
t.Fatal("Expected foo.com to be allowed but wasn't")
|
||||||
|
}
|
||||||
|
if acls.IsServerBannedFromRoom("foo.com:3456", roomID) {
|
||||||
|
t.Fatal("Expected foo.com:3456 to be allowed but wasn't")
|
||||||
|
}
|
||||||
|
if !acls.IsServerBannedFromRoom("bar.com", roomID) {
|
||||||
|
t.Fatal("Expected bar.com to be allowed but wasn't")
|
||||||
|
}
|
||||||
|
if !acls.IsServerBannedFromRoom("baz.com", roomID) {
|
||||||
|
t.Fatal("Expected baz.com to be allowed but wasn't")
|
||||||
|
}
|
||||||
|
if !acls.IsServerBannedFromRoom("qux.com:4567", roomID) {
|
||||||
|
t.Fatal("Expected qux.com:4567 to be allowed but wasn't")
|
||||||
|
}
|
||||||
|
}
|
|
@ -112,6 +112,20 @@ type RoomserverInternalAPI interface {
|
||||||
response *QueryStateAndAuthChainResponse,
|
response *QueryStateAndAuthChainResponse,
|
||||||
) error
|
) error
|
||||||
|
|
||||||
|
// QueryCurrentState retrieves the requested state events. If state events are not found, they will be missing from
|
||||||
|
// the response.
|
||||||
|
QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error
|
||||||
|
// QueryRoomsForUser retrieves a list of room IDs matching the given query.
|
||||||
|
QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error
|
||||||
|
// QueryBulkStateContent does a bulk query for state event content in the given rooms.
|
||||||
|
QueryBulkStateContent(ctx context.Context, req *QueryBulkStateContentRequest, res *QueryBulkStateContentResponse) error
|
||||||
|
// QuerySharedUsers returns a list of users who share at least 1 room in common with the given user.
|
||||||
|
QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error
|
||||||
|
// QueryKnownUsers returns a list of users that we know about from our joined rooms.
|
||||||
|
QueryKnownUsers(ctx context.Context, req *QueryKnownUsersRequest, res *QueryKnownUsersResponse) error
|
||||||
|
// QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs.
|
||||||
|
QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error
|
||||||
|
|
||||||
// Query a given amount (or less) of events prior to a given set of events.
|
// Query a given amount (or less) of events prior to a given set of events.
|
||||||
PerformBackfill(
|
PerformBackfill(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
|
|
@ -245,6 +245,47 @@ func (t *RoomserverInternalAPITrace) RemoveRoomAlias(
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *RoomserverInternalAPITrace) QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error {
|
||||||
|
err := t.Impl.QueryCurrentState(ctx, req, res)
|
||||||
|
util.GetLogger(ctx).WithError(err).Infof("QueryCurrentState req=%+v res=%+v", js(req), js(res))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryRoomsForUser retrieves a list of room IDs matching the given query.
|
||||||
|
func (t *RoomserverInternalAPITrace) QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error {
|
||||||
|
err := t.Impl.QueryRoomsForUser(ctx, req, res)
|
||||||
|
util.GetLogger(ctx).WithError(err).Infof("QueryRoomsForUser req=%+v res=%+v", js(req), js(res))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryBulkStateContent does a bulk query for state event content in the given rooms.
|
||||||
|
func (t *RoomserverInternalAPITrace) QueryBulkStateContent(ctx context.Context, req *QueryBulkStateContentRequest, res *QueryBulkStateContentResponse) error {
|
||||||
|
err := t.Impl.QueryBulkStateContent(ctx, req, res)
|
||||||
|
util.GetLogger(ctx).WithError(err).Infof("QueryBulkStateContent req=%+v res=%+v", js(req), js(res))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// QuerySharedUsers returns a list of users who share at least 1 room in common with the given user.
|
||||||
|
func (t *RoomserverInternalAPITrace) QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error {
|
||||||
|
err := t.Impl.QuerySharedUsers(ctx, req, res)
|
||||||
|
util.GetLogger(ctx).WithError(err).Infof("QuerySharedUsers req=%+v res=%+v", js(req), js(res))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryKnownUsers returns a list of users that we know about from our joined rooms.
|
||||||
|
func (t *RoomserverInternalAPITrace) QueryKnownUsers(ctx context.Context, req *QueryKnownUsersRequest, res *QueryKnownUsersResponse) error {
|
||||||
|
err := t.Impl.QueryKnownUsers(ctx, req, res)
|
||||||
|
util.GetLogger(ctx).WithError(err).Infof("QueryKnownUsers req=%+v res=%+v", js(req), js(res))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs.
|
||||||
|
func (t *RoomserverInternalAPITrace) QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error {
|
||||||
|
err := t.Impl.QueryServerBannedFromRoom(ctx, req, res)
|
||||||
|
util.GetLogger(ctx).WithError(err).Infof("QueryServerBannedFromRoom req=%+v res=%+v", js(req), js(res))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func js(thing interface{}) string {
|
func js(thing interface{}) string {
|
||||||
b, err := json.Marshal(thing)
|
b, err := json.Marshal(thing)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -83,5 +83,4 @@ type InputRoomEventsRequest struct {
|
||||||
|
|
||||||
// InputRoomEventsResponse is a response to InputRoomEvents
|
// InputRoomEventsResponse is a response to InputRoomEvents
|
||||||
type InputRoomEventsResponse struct {
|
type InputRoomEventsResponse struct {
|
||||||
EventID string `json:"event_id"`
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,6 +17,11 @@
|
||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -225,3 +230,102 @@ type QueryPublishedRoomsResponse struct {
|
||||||
// The list of published rooms.
|
// The list of published rooms.
|
||||||
RoomIDs []string
|
RoomIDs []string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type QuerySharedUsersRequest struct {
|
||||||
|
UserID string
|
||||||
|
ExcludeRoomIDs []string
|
||||||
|
IncludeRoomIDs []string
|
||||||
|
}
|
||||||
|
|
||||||
|
type QuerySharedUsersResponse struct {
|
||||||
|
UserIDsToCount map[string]int
|
||||||
|
}
|
||||||
|
|
||||||
|
type QueryRoomsForUserRequest struct {
|
||||||
|
UserID string
|
||||||
|
// The desired membership of the user. If this is the empty string then no rooms are returned.
|
||||||
|
WantMembership string
|
||||||
|
}
|
||||||
|
|
||||||
|
type QueryRoomsForUserResponse struct {
|
||||||
|
RoomIDs []string
|
||||||
|
}
|
||||||
|
|
||||||
|
type QueryBulkStateContentRequest struct {
|
||||||
|
// Returns state events in these rooms
|
||||||
|
RoomIDs []string
|
||||||
|
// If true, treats the '*' StateKey as "all state events of this type" rather than a literal value of '*'
|
||||||
|
AllowWildcards bool
|
||||||
|
// The state events to return. Only a small subset of tuples are allowed in this request as only certain events
|
||||||
|
// have their content fields extracted. Specifically, the tuple Type must be one of:
|
||||||
|
// m.room.avatar
|
||||||
|
// m.room.create
|
||||||
|
// m.room.canonical_alias
|
||||||
|
// m.room.guest_access
|
||||||
|
// m.room.history_visibility
|
||||||
|
// m.room.join_rules
|
||||||
|
// m.room.member
|
||||||
|
// m.room.name
|
||||||
|
// m.room.topic
|
||||||
|
// Any other tuple type will result in the query failing.
|
||||||
|
StateTuples []gomatrixserverlib.StateKeyTuple
|
||||||
|
}
|
||||||
|
type QueryBulkStateContentResponse struct {
|
||||||
|
// map of room ID -> tuple -> content_value
|
||||||
|
Rooms map[string]map[gomatrixserverlib.StateKeyTuple]string
|
||||||
|
}
|
||||||
|
|
||||||
|
type QueryCurrentStateRequest struct {
|
||||||
|
RoomID string
|
||||||
|
StateTuples []gomatrixserverlib.StateKeyTuple
|
||||||
|
}
|
||||||
|
|
||||||
|
type QueryCurrentStateResponse struct {
|
||||||
|
StateEvents map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent
|
||||||
|
}
|
||||||
|
|
||||||
|
type QueryKnownUsersRequest struct {
|
||||||
|
UserID string `json:"user_id"`
|
||||||
|
SearchString string `json:"search_string"`
|
||||||
|
Limit int `json:"limit"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type QueryKnownUsersResponse struct {
|
||||||
|
Users []authtypes.FullyQualifiedProfile `json:"profiles"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type QueryServerBannedFromRoomRequest struct {
|
||||||
|
ServerName gomatrixserverlib.ServerName `json:"server_name"`
|
||||||
|
RoomID string `json:"room_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type QueryServerBannedFromRoomResponse struct {
|
||||||
|
Banned bool `json:"banned"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON stringifies the StateKeyTuple keys so they can be sent over the wire in HTTP API mode.
|
||||||
|
func (r *QueryCurrentStateResponse) MarshalJSON() ([]byte, error) {
|
||||||
|
se := make(map[string]*gomatrixserverlib.HeaderedEvent, len(r.StateEvents))
|
||||||
|
for k, v := range r.StateEvents {
|
||||||
|
// use 0x1F (unit separator) as the delimiter between type/state key,
|
||||||
|
se[fmt.Sprintf("%s\x1F%s", k.EventType, k.StateKey)] = v
|
||||||
|
}
|
||||||
|
return json.Marshal(se)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *QueryCurrentStateResponse) UnmarshalJSON(data []byte) error {
|
||||||
|
res := make(map[string]*gomatrixserverlib.HeaderedEvent)
|
||||||
|
err := json.Unmarshal(data, &res)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
r.StateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent, len(res))
|
||||||
|
for k, v := range res {
|
||||||
|
fields := strings.Split(k, "\x1F")
|
||||||
|
r.StateEvents[gomatrixserverlib.StateKeyTuple{
|
||||||
|
EventType: fields[0],
|
||||||
|
StateKey: fields[1],
|
||||||
|
}] = v
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
@ -26,7 +26,7 @@ import (
|
||||||
func SendEvents(
|
func SendEvents(
|
||||||
ctx context.Context, rsAPI RoomserverInternalAPI, events []gomatrixserverlib.HeaderedEvent,
|
ctx context.Context, rsAPI RoomserverInternalAPI, events []gomatrixserverlib.HeaderedEvent,
|
||||||
sendAsServer gomatrixserverlib.ServerName, txnID *TransactionID,
|
sendAsServer gomatrixserverlib.ServerName, txnID *TransactionID,
|
||||||
) (string, error) {
|
) error {
|
||||||
ires := make([]InputRoomEvent, len(events))
|
ires := make([]InputRoomEvent, len(events))
|
||||||
for i, event := range events {
|
for i, event := range events {
|
||||||
ires[i] = InputRoomEvent{
|
ires[i] = InputRoomEvent{
|
||||||
|
@ -77,19 +77,16 @@ func SendEventWithState(
|
||||||
StateEventIDs: stateEventIDs,
|
StateEventIDs: stateEventIDs,
|
||||||
})
|
})
|
||||||
|
|
||||||
_, err = SendInputRoomEvents(ctx, rsAPI, ires)
|
return SendInputRoomEvents(ctx, rsAPI, ires)
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendInputRoomEvents to the roomserver.
|
// SendInputRoomEvents to the roomserver.
|
||||||
func SendInputRoomEvents(
|
func SendInputRoomEvents(
|
||||||
ctx context.Context, rsAPI RoomserverInternalAPI, ires []InputRoomEvent,
|
ctx context.Context, rsAPI RoomserverInternalAPI, ires []InputRoomEvent,
|
||||||
) (eventID string, err error) {
|
) error {
|
||||||
request := InputRoomEventsRequest{InputRoomEvents: ires}
|
request := InputRoomEventsRequest{InputRoomEvents: ires}
|
||||||
var response InputRoomEventsResponse
|
var response InputRoomEventsResponse
|
||||||
err = rsAPI.InputRoomEvents(ctx, &request, &response)
|
return rsAPI.InputRoomEvents(ctx, &request, &response)
|
||||||
eventID = response.EventID
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendInvite event to the roomserver.
|
// SendInvite event to the roomserver.
|
||||||
|
@ -136,3 +133,102 @@ func GetEvent(ctx context.Context, rsAPI RoomserverInternalAPI, eventID string)
|
||||||
}
|
}
|
||||||
return &res.Events[0]
|
return &res.Events[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetStateEvent returns the current state event in the room or nil.
|
||||||
|
func GetStateEvent(ctx context.Context, rsAPI RoomserverInternalAPI, roomID string, tuple gomatrixserverlib.StateKeyTuple) *gomatrixserverlib.HeaderedEvent {
|
||||||
|
var res QueryCurrentStateResponse
|
||||||
|
err := rsAPI.QueryCurrentState(ctx, &QueryCurrentStateRequest{
|
||||||
|
RoomID: roomID,
|
||||||
|
StateTuples: []gomatrixserverlib.StateKeyTuple{tuple},
|
||||||
|
}, &res)
|
||||||
|
if err != nil {
|
||||||
|
util.GetLogger(ctx).WithError(err).Error("Failed to QueryCurrentState")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
ev, ok := res.StateEvents[tuple]
|
||||||
|
if ok {
|
||||||
|
return ev
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsServerBannedFromRoom returns whether the server is banned from a room by server ACLs.
|
||||||
|
func IsServerBannedFromRoom(ctx context.Context, rsAPI RoomserverInternalAPI, roomID string, serverName gomatrixserverlib.ServerName) bool {
|
||||||
|
req := &QueryServerBannedFromRoomRequest{
|
||||||
|
ServerName: serverName,
|
||||||
|
RoomID: roomID,
|
||||||
|
}
|
||||||
|
res := &QueryServerBannedFromRoomResponse{}
|
||||||
|
if err := rsAPI.QueryServerBannedFromRoom(ctx, req, res); err != nil {
|
||||||
|
util.GetLogger(ctx).WithError(err).Error("Failed to QueryServerBannedFromRoom")
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return res.Banned
|
||||||
|
}
|
||||||
|
|
||||||
|
// PopulatePublicRooms extracts PublicRoom information for all the provided room IDs. The IDs are not checked to see if they are visible in the
|
||||||
|
// published room directory.
|
||||||
|
// due to lots of switches
|
||||||
|
// nolint:gocyclo
|
||||||
|
func PopulatePublicRooms(ctx context.Context, roomIDs []string, rsAPI RoomserverInternalAPI) ([]gomatrixserverlib.PublicRoom, error) {
|
||||||
|
avatarTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.avatar", StateKey: ""}
|
||||||
|
nameTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.name", StateKey: ""}
|
||||||
|
canonicalTuple := gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomCanonicalAlias, StateKey: ""}
|
||||||
|
topicTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.topic", StateKey: ""}
|
||||||
|
guestTuple := gomatrixserverlib.StateKeyTuple{EventType: "m.room.guest_access", StateKey: ""}
|
||||||
|
visibilityTuple := gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomHistoryVisibility, StateKey: ""}
|
||||||
|
joinRuleTuple := gomatrixserverlib.StateKeyTuple{EventType: gomatrixserverlib.MRoomJoinRules, StateKey: ""}
|
||||||
|
|
||||||
|
var stateRes QueryBulkStateContentResponse
|
||||||
|
err := rsAPI.QueryBulkStateContent(ctx, &QueryBulkStateContentRequest{
|
||||||
|
RoomIDs: roomIDs,
|
||||||
|
AllowWildcards: true,
|
||||||
|
StateTuples: []gomatrixserverlib.StateKeyTuple{
|
||||||
|
nameTuple, canonicalTuple, topicTuple, guestTuple, visibilityTuple, joinRuleTuple, avatarTuple,
|
||||||
|
{EventType: gomatrixserverlib.MRoomMember, StateKey: "*"},
|
||||||
|
},
|
||||||
|
}, &stateRes)
|
||||||
|
if err != nil {
|
||||||
|
util.GetLogger(ctx).WithError(err).Error("QueryBulkStateContent failed")
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
chunk := make([]gomatrixserverlib.PublicRoom, len(roomIDs))
|
||||||
|
i := 0
|
||||||
|
for roomID, data := range stateRes.Rooms {
|
||||||
|
pub := gomatrixserverlib.PublicRoom{
|
||||||
|
RoomID: roomID,
|
||||||
|
}
|
||||||
|
joinCount := 0
|
||||||
|
var joinRule, guestAccess string
|
||||||
|
for tuple, contentVal := range data {
|
||||||
|
if tuple.EventType == gomatrixserverlib.MRoomMember && contentVal == "join" {
|
||||||
|
joinCount++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch tuple {
|
||||||
|
case avatarTuple:
|
||||||
|
pub.AvatarURL = contentVal
|
||||||
|
case nameTuple:
|
||||||
|
pub.Name = contentVal
|
||||||
|
case topicTuple:
|
||||||
|
pub.Topic = contentVal
|
||||||
|
case canonicalTuple:
|
||||||
|
pub.CanonicalAlias = contentVal
|
||||||
|
case visibilityTuple:
|
||||||
|
pub.WorldReadable = contentVal == "world_readable"
|
||||||
|
// need both of these to determine whether guests can join
|
||||||
|
case joinRuleTuple:
|
||||||
|
joinRule = contentVal
|
||||||
|
case guestTuple:
|
||||||
|
guestAccess = contentVal
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if joinRule == gomatrixserverlib.Public && guestAccess == "can_join" {
|
||||||
|
pub.GuestCanJoin = true
|
||||||
|
}
|
||||||
|
pub.JoinedMembersCount = joinCount
|
||||||
|
chunk[i] = pub
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
return chunk, nil
|
||||||
|
}
|
||||||
|
|
|
@ -18,6 +18,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
|
@ -239,16 +240,19 @@ func (r *RoomserverInternalAPI) sendUpdatedAliasesEvent(
|
||||||
}
|
}
|
||||||
builder.AuthEvents = refs
|
builder.AuthEvents = refs
|
||||||
|
|
||||||
roomVersion, err := r.DB.GetRoomVersionForRoom(ctx, roomID)
|
roomInfo, err := r.DB.RoomInfo(ctx, roomID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
if roomInfo == nil {
|
||||||
|
return fmt.Errorf("room %s does not exist", roomID)
|
||||||
|
}
|
||||||
|
|
||||||
// Build the event
|
// Build the event
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
event, err := builder.Build(
|
event, err := builder.Build(
|
||||||
now, r.Cfg.Matrix.ServerName, r.Cfg.Matrix.KeyID,
|
now, r.Cfg.Matrix.ServerName, r.Cfg.Matrix.KeyID,
|
||||||
r.Cfg.Matrix.PrivateKey, roomVersion,
|
r.Cfg.Matrix.PrivateKey, roomInfo.RoomVersion,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -257,7 +261,7 @@ func (r *RoomserverInternalAPI) sendUpdatedAliasesEvent(
|
||||||
// Create the request
|
// Create the request
|
||||||
ire := api.InputRoomEvent{
|
ire := api.InputRoomEvent{
|
||||||
Kind: api.KindNew,
|
Kind: api.KindNew,
|
||||||
Event: event.Headered(roomVersion),
|
Event: event.Headered(roomInfo.RoomVersion),
|
||||||
AuthEventIDs: event.AuthEventIDs(),
|
AuthEventIDs: event.AuthEventIDs(),
|
||||||
SendAsServer: serverName,
|
SendAsServer: serverName,
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,26 +1,129 @@
|
||||||
package internal
|
package internal
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"sync"
|
"context"
|
||||||
|
|
||||||
"github.com/Shopify/sarama"
|
"github.com/Shopify/sarama"
|
||||||
fsAPI "github.com/matrix-org/dendrite/federationsender/api"
|
fsAPI "github.com/matrix-org/dendrite/federationsender/api"
|
||||||
"github.com/matrix-org/dendrite/internal/caching"
|
"github.com/matrix-org/dendrite/internal/caching"
|
||||||
"github.com/matrix-org/dendrite/internal/config"
|
"github.com/matrix-org/dendrite/internal/config"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/acls"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/internal/input"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/internal/perform"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/internal/query"
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage"
|
"github.com/matrix-org/dendrite/roomserver/storage"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
// RoomserverInternalAPI is an implementation of api.RoomserverInternalAPI
|
// RoomserverInternalAPI is an implementation of api.RoomserverInternalAPI
|
||||||
type RoomserverInternalAPI struct {
|
type RoomserverInternalAPI struct {
|
||||||
|
*input.Inputer
|
||||||
|
*query.Queryer
|
||||||
|
*perform.Inviter
|
||||||
|
*perform.Joiner
|
||||||
|
*perform.Leaver
|
||||||
|
*perform.Publisher
|
||||||
|
*perform.Backfiller
|
||||||
DB storage.Database
|
DB storage.Database
|
||||||
Cfg *config.RoomServer
|
Cfg *config.RoomServer
|
||||||
Producer sarama.SyncProducer
|
Producer sarama.SyncProducer
|
||||||
Cache caching.RoomServerCaches
|
Cache caching.RoomServerCaches
|
||||||
ServerName gomatrixserverlib.ServerName
|
ServerName gomatrixserverlib.ServerName
|
||||||
KeyRing gomatrixserverlib.JSONVerifier
|
KeyRing gomatrixserverlib.JSONVerifier
|
||||||
FedClient *gomatrixserverlib.FederationClient
|
|
||||||
OutputRoomEventTopic string // Kafka topic for new output room events
|
|
||||||
mutexes sync.Map // room ID -> *sync.Mutex, protects calls to processRoomEvent
|
|
||||||
fsAPI fsAPI.FederationSenderInternalAPI
|
fsAPI fsAPI.FederationSenderInternalAPI
|
||||||
|
OutputRoomEventTopic string // Kafka topic for new output room events
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRoomserverAPI(
|
||||||
|
cfg *config.RoomServer, roomserverDB storage.Database, producer sarama.SyncProducer,
|
||||||
|
outputRoomEventTopic string, caches caching.RoomServerCaches,
|
||||||
|
keyRing gomatrixserverlib.JSONVerifier,
|
||||||
|
) *RoomserverInternalAPI {
|
||||||
|
a := &RoomserverInternalAPI{
|
||||||
|
DB: roomserverDB,
|
||||||
|
Cfg: cfg,
|
||||||
|
Cache: caches,
|
||||||
|
ServerName: cfg.Matrix.ServerName,
|
||||||
|
KeyRing: keyRing,
|
||||||
|
Queryer: &query.Queryer{
|
||||||
|
DB: roomserverDB,
|
||||||
|
Cache: caches,
|
||||||
|
ServerACLs: acls.NewServerACLs(roomserverDB),
|
||||||
|
},
|
||||||
|
Inputer: &input.Inputer{
|
||||||
|
DB: roomserverDB,
|
||||||
|
OutputRoomEventTopic: outputRoomEventTopic,
|
||||||
|
Producer: producer,
|
||||||
|
ServerName: cfg.Matrix.ServerName,
|
||||||
|
},
|
||||||
|
// perform-er structs get initialised when we have a federation sender to use
|
||||||
|
}
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetFederationSenderInputAPI passes in a federation sender input API reference
|
||||||
|
// so that we can avoid the chicken-and-egg problem of both the roomserver input API
|
||||||
|
// and the federation sender input API being interdependent.
|
||||||
|
func (r *RoomserverInternalAPI) SetFederationSenderAPI(fsAPI fsAPI.FederationSenderInternalAPI) {
|
||||||
|
r.fsAPI = fsAPI
|
||||||
|
|
||||||
|
r.Inviter = &perform.Inviter{
|
||||||
|
DB: r.DB,
|
||||||
|
Cfg: r.Cfg,
|
||||||
|
FSAPI: r.fsAPI,
|
||||||
|
Inputer: r.Inputer,
|
||||||
|
}
|
||||||
|
r.Joiner = &perform.Joiner{
|
||||||
|
ServerName: r.Cfg.Matrix.ServerName,
|
||||||
|
Cfg: r.Cfg,
|
||||||
|
DB: r.DB,
|
||||||
|
FSAPI: r.fsAPI,
|
||||||
|
Inputer: r.Inputer,
|
||||||
|
}
|
||||||
|
r.Leaver = &perform.Leaver{
|
||||||
|
Cfg: r.Cfg,
|
||||||
|
DB: r.DB,
|
||||||
|
FSAPI: r.fsAPI,
|
||||||
|
Inputer: r.Inputer,
|
||||||
|
}
|
||||||
|
r.Publisher = &perform.Publisher{
|
||||||
|
DB: r.DB,
|
||||||
|
}
|
||||||
|
r.Backfiller = &perform.Backfiller{
|
||||||
|
ServerName: r.ServerName,
|
||||||
|
DB: r.DB,
|
||||||
|
FSAPI: r.fsAPI,
|
||||||
|
KeyRing: r.KeyRing,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RoomserverInternalAPI) PerformInvite(
|
||||||
|
ctx context.Context,
|
||||||
|
req *api.PerformInviteRequest,
|
||||||
|
res *api.PerformInviteResponse,
|
||||||
|
) error {
|
||||||
|
outputEvents, err := r.Inviter.PerformInvite(ctx, req, res)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if len(outputEvents) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return r.WriteOutputEvents(req.Event.RoomID(), outputEvents)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RoomserverInternalAPI) PerformLeave(
|
||||||
|
ctx context.Context,
|
||||||
|
req *api.PerformLeaveRequest,
|
||||||
|
res *api.PerformLeaveResponse,
|
||||||
|
) error {
|
||||||
|
outputEvents, err := r.Leaver.PerformLeave(ctx, req, res)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if len(outputEvents) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return r.WriteOutputEvents(req.RoomID, outputEvents)
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
package internal
|
package helpers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
@ -23,9 +23,9 @@ import (
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
// checkAuthEvents checks that the event passes authentication checks
|
// CheckAuthEvents checks that the event passes authentication checks
|
||||||
// Returns the numeric IDs for the auth events.
|
// Returns the numeric IDs for the auth events.
|
||||||
func checkAuthEvents(
|
func CheckAuthEvents(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
db storage.Database,
|
db storage.Database,
|
||||||
event gomatrixserverlib.HeaderedEvent,
|
event gomatrixserverlib.HeaderedEvent,
|
||||||
|
@ -63,7 +63,7 @@ func checkAuthEvents(
|
||||||
type authEvents struct {
|
type authEvents struct {
|
||||||
stateKeyNIDMap map[string]types.EventStateKeyNID
|
stateKeyNIDMap map[string]types.EventStateKeyNID
|
||||||
state stateEntryMap
|
state stateEntryMap
|
||||||
events eventMap
|
events EventMap
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create implements gomatrixserverlib.AuthEventProvider
|
// Create implements gomatrixserverlib.AuthEventProvider
|
||||||
|
@ -99,7 +99,7 @@ func (ae *authEvents) lookupEventWithEmptyStateKey(typeNID types.EventTypeNID) *
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
event, ok := ae.events.lookup(eventNID)
|
event, ok := ae.events.Lookup(eventNID)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -118,7 +118,7 @@ func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) *
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
event, ok := ae.events.lookup(eventNID)
|
event, ok := ae.events.Lookup(eventNID)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -224,10 +224,10 @@ func (m stateEntryMap) lookup(stateKey types.StateKeyTuple) (eventNID types.Even
|
||||||
|
|
||||||
// Map from numeric event ID to event.
|
// Map from numeric event ID to event.
|
||||||
// Implemented using binary search on a sorted array.
|
// Implemented using binary search on a sorted array.
|
||||||
type eventMap []types.Event
|
type EventMap []types.Event
|
||||||
|
|
||||||
// lookup an entry in the event map.
|
// lookup an entry in the event map.
|
||||||
func (m eventMap) lookup(eventNID types.EventNID) (event *types.Event, ok bool) {
|
func (m EventMap) Lookup(eventNID types.EventNID) (event *types.Event, ok bool) {
|
||||||
// Since the list is sorted we can implement this using binary search.
|
// Since the list is sorted we can implement this using binary search.
|
||||||
// This is faster than using a hash map.
|
// This is faster than using a hash map.
|
||||||
// We don't have to worry about pathological cases because the keys are fixed
|
// We don't have to worry about pathological cases because the keys are fixed
|
|
@ -12,7 +12,7 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
package internal
|
package helpers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -95,7 +95,7 @@ func TestStateEntryMap(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEventMap(t *testing.T) {
|
func TestEventMap(t *testing.T) {
|
||||||
events := eventMap([]types.Event{
|
events := EventMap([]types.Event{
|
||||||
{EventNID: 1},
|
{EventNID: 1},
|
||||||
{EventNID: 2},
|
{EventNID: 2},
|
||||||
{EventNID: 3},
|
{EventNID: 3},
|
||||||
|
@ -123,7 +123,7 @@ func TestEventMap(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, testCase := range testCases {
|
for _, testCase := range testCases {
|
||||||
gotEvent, gotOK := events.lookup(testCase.inputEventNID)
|
gotEvent, gotOK := events.Lookup(testCase.inputEventNID)
|
||||||
if testCase.wantOK != gotOK {
|
if testCase.wantOK != gotOK {
|
||||||
t.Fatalf("eventMap lookup(%v): want ok to be %v, got %v", testCase.inputEventNID, testCase.wantOK, gotOK)
|
t.Fatalf("eventMap lookup(%v): want ok to be %v, got %v", testCase.inputEventNID, testCase.wantOK, gotOK)
|
||||||
}
|
}
|
379
roomserver/internal/helpers/helpers.go
Normal file
379
roomserver/internal/helpers/helpers.go
Normal file
|
@ -0,0 +1,379 @@
|
||||||
|
package helpers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/auth"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/state"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/storage"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/matrix-org/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TODO: temporary package which has helper functions used by both internal/perform packages.
|
||||||
|
// Move these to a more sensible place.
|
||||||
|
|
||||||
|
func UpdateToInviteMembership(
|
||||||
|
mu *shared.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent,
|
||||||
|
roomVersion gomatrixserverlib.RoomVersion,
|
||||||
|
) ([]api.OutputEvent, error) {
|
||||||
|
// We may have already sent the invite to the user, either because we are
|
||||||
|
// reprocessing this event, or because the we received this invite from a
|
||||||
|
// remote server via the federation invite API. In those cases we don't need
|
||||||
|
// to send the event.
|
||||||
|
needsSending, err := mu.SetToInvite(*add)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if needsSending {
|
||||||
|
// We notify the consumers using a special event even though we will
|
||||||
|
// notify them about the change in current state as part of the normal
|
||||||
|
// room event stream. This ensures that the consumers only have to
|
||||||
|
// consider a single stream of events when determining whether a user
|
||||||
|
// is invited, rather than having to combine multiple streams themselves.
|
||||||
|
onie := api.OutputNewInviteEvent{
|
||||||
|
Event: add.Headered(roomVersion),
|
||||||
|
RoomVersion: roomVersion,
|
||||||
|
}
|
||||||
|
updates = append(updates, api.OutputEvent{
|
||||||
|
Type: api.OutputTypeNewInviteEvent,
|
||||||
|
NewInviteEvent: &onie,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return updates, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverName gomatrixserverlib.ServerName, roomID string) (bool, error) {
|
||||||
|
info, err := db.RoomInfo(ctx, roomID)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
if info == nil {
|
||||||
|
return false, fmt.Errorf("unknown room %s", roomID)
|
||||||
|
}
|
||||||
|
|
||||||
|
eventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, true, false)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
events, err := db.Events(ctx, eventNIDs)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
gmslEvents := make([]gomatrixserverlib.Event, len(events))
|
||||||
|
for i := range events {
|
||||||
|
gmslEvents[i] = events[i].Event
|
||||||
|
}
|
||||||
|
return auth.IsAnyUserOnServerWithMembership(serverName, gmslEvents, gomatrixserverlib.Join), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsInvitePending(
|
||||||
|
ctx context.Context, db storage.Database,
|
||||||
|
roomID, userID string,
|
||||||
|
) (bool, string, string, error) {
|
||||||
|
// Look up the room NID for the supplied room ID.
|
||||||
|
info, err := db.RoomInfo(ctx, roomID)
|
||||||
|
if err != nil {
|
||||||
|
return false, "", "", fmt.Errorf("r.DB.RoomInfo: %w", err)
|
||||||
|
}
|
||||||
|
if info == nil {
|
||||||
|
return false, "", "", fmt.Errorf("cannot get RoomInfo: unknown room ID %s", roomID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Look up the state key NID for the supplied user ID.
|
||||||
|
targetUserNIDs, err := db.EventStateKeyNIDs(ctx, []string{userID})
|
||||||
|
if err != nil {
|
||||||
|
return false, "", "", fmt.Errorf("r.DB.EventStateKeyNIDs: %w", err)
|
||||||
|
}
|
||||||
|
targetUserNID, targetUserFound := targetUserNIDs[userID]
|
||||||
|
if !targetUserFound {
|
||||||
|
return false, "", "", fmt.Errorf("missing NID for user %q (%+v)", userID, targetUserNIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Let's see if we have an event active for the user in the room. If
|
||||||
|
// we do then it will contain a server name that we can direct the
|
||||||
|
// send_leave to.
|
||||||
|
senderUserNIDs, eventIDs, err := db.GetInvitesForUser(ctx, info.RoomNID, targetUserNID)
|
||||||
|
if err != nil {
|
||||||
|
return false, "", "", fmt.Errorf("r.DB.GetInvitesForUser: %w", err)
|
||||||
|
}
|
||||||
|
if len(senderUserNIDs) == 0 {
|
||||||
|
return false, "", "", nil
|
||||||
|
}
|
||||||
|
userNIDToEventID := make(map[types.EventStateKeyNID]string)
|
||||||
|
for i, nid := range senderUserNIDs {
|
||||||
|
userNIDToEventID[nid] = eventIDs[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Look up the user ID from the NID.
|
||||||
|
senderUsers, err := db.EventStateKeys(ctx, senderUserNIDs)
|
||||||
|
if err != nil {
|
||||||
|
return false, "", "", fmt.Errorf("r.DB.EventStateKeys: %w", err)
|
||||||
|
}
|
||||||
|
if len(senderUsers) == 0 {
|
||||||
|
return false, "", "", fmt.Errorf("no senderUsers")
|
||||||
|
}
|
||||||
|
|
||||||
|
senderUser, senderUserFound := senderUsers[senderUserNIDs[0]]
|
||||||
|
if !senderUserFound {
|
||||||
|
return false, "", "", fmt.Errorf("missing user for NID %d (%+v)", senderUserNIDs[0], senderUsers)
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, senderUser, userNIDToEventID[senderUserNIDs[0]], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMembershipsAtState filters the state events to
|
||||||
|
// only keep the "m.room.member" events with a "join" membership. These events are returned.
|
||||||
|
// Returns an error if there was an issue fetching the events.
|
||||||
|
func GetMembershipsAtState(
|
||||||
|
ctx context.Context, db storage.Database, stateEntries []types.StateEntry, joinedOnly bool,
|
||||||
|
) ([]types.Event, error) {
|
||||||
|
|
||||||
|
var eventNIDs []types.EventNID
|
||||||
|
for _, entry := range stateEntries {
|
||||||
|
// Filter the events to retrieve to only keep the membership events
|
||||||
|
if entry.EventTypeNID == types.MRoomMemberNID {
|
||||||
|
eventNIDs = append(eventNIDs, entry.EventNID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get all of the events in this state
|
||||||
|
stateEvents, err := db.Events(ctx, eventNIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !joinedOnly {
|
||||||
|
return stateEvents, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter the events to only keep the "join" membership events
|
||||||
|
var events []types.Event
|
||||||
|
for _, event := range stateEvents {
|
||||||
|
membership, err := event.Membership()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if membership == gomatrixserverlib.Join {
|
||||||
|
events = append(events, event)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return events, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func StateBeforeEvent(ctx context.Context, db storage.Database, info types.RoomInfo, eventNID types.EventNID) ([]types.StateEntry, error) {
|
||||||
|
roomState := state.NewStateResolution(db, info)
|
||||||
|
// Lookup the event NID
|
||||||
|
eIDs, err := db.EventIDs(ctx, []types.EventNID{eventNID})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
eventIDs := []string{eIDs[eventNID]}
|
||||||
|
|
||||||
|
prevState, err := db.StateAtEventIDs(ctx, eventIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fetch the state as it was when this event was fired
|
||||||
|
return roomState.LoadCombinedStateAfterEvents(ctx, prevState)
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadEvents(
|
||||||
|
ctx context.Context, db storage.Database, eventNIDs []types.EventNID,
|
||||||
|
) ([]gomatrixserverlib.Event, error) {
|
||||||
|
stateEvents, err := db.Events(ctx, eventNIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
result := make([]gomatrixserverlib.Event, len(stateEvents))
|
||||||
|
for i := range stateEvents {
|
||||||
|
result[i] = stateEvents[i].Event
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadStateEvents(
|
||||||
|
ctx context.Context, db storage.Database, stateEntries []types.StateEntry,
|
||||||
|
) ([]gomatrixserverlib.Event, error) {
|
||||||
|
eventNIDs := make([]types.EventNID, len(stateEntries))
|
||||||
|
for i := range stateEntries {
|
||||||
|
eventNIDs[i] = stateEntries[i].EventNID
|
||||||
|
}
|
||||||
|
return LoadEvents(ctx, db, eventNIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func CheckServerAllowedToSeeEvent(
|
||||||
|
ctx context.Context, db storage.Database, info types.RoomInfo, eventID string, serverName gomatrixserverlib.ServerName, isServerInRoom bool,
|
||||||
|
) (bool, error) {
|
||||||
|
roomState := state.NewStateResolution(db, info)
|
||||||
|
stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: We probably want to make it so that we don't have to pull
|
||||||
|
// out all the state if possible.
|
||||||
|
stateAtEvent, err := LoadStateEvents(ctx, db, stateEntries)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return auth.IsServerAllowed(serverName, isServerInRoom, stateAtEvent), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Remove this when we have tests to assert correctness of this function
|
||||||
|
// nolint:gocyclo
|
||||||
|
func ScanEventTree(
|
||||||
|
ctx context.Context, db storage.Database, info types.RoomInfo, front []string, visited map[string]bool, limit int,
|
||||||
|
serverName gomatrixserverlib.ServerName,
|
||||||
|
) ([]types.EventNID, error) {
|
||||||
|
var resultNIDs []types.EventNID
|
||||||
|
var err error
|
||||||
|
var allowed bool
|
||||||
|
var events []types.Event
|
||||||
|
var next []string
|
||||||
|
var pre string
|
||||||
|
|
||||||
|
// TODO: add tests for this function to ensure it meets the contract that callers expect (and doc what that is supposed to be)
|
||||||
|
// Currently, callers like PerformBackfill will call scanEventTree with a pre-populated `visited` map, assuming that by doing
|
||||||
|
// so means that the events in that map will NOT be returned from this function. That is not currently true, resulting in
|
||||||
|
// duplicate events being sent in response to /backfill requests.
|
||||||
|
initialIgnoreList := make(map[string]bool, len(visited))
|
||||||
|
for k, v := range visited {
|
||||||
|
initialIgnoreList[k] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
resultNIDs = make([]types.EventNID, 0, limit)
|
||||||
|
|
||||||
|
var checkedServerInRoom bool
|
||||||
|
var isServerInRoom bool
|
||||||
|
|
||||||
|
// Loop through the event IDs to retrieve the requested events and go
|
||||||
|
// through the whole tree (up to the provided limit) using the events'
|
||||||
|
// "prev_event" key.
|
||||||
|
BFSLoop:
|
||||||
|
for len(front) > 0 {
|
||||||
|
// Prevent unnecessary allocations: reset the slice only when not empty.
|
||||||
|
if len(next) > 0 {
|
||||||
|
next = make([]string, 0)
|
||||||
|
}
|
||||||
|
// Retrieve the events to process from the database.
|
||||||
|
events, err = db.EventsFromIDs(ctx, front)
|
||||||
|
if err != nil {
|
||||||
|
return resultNIDs, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !checkedServerInRoom && len(events) > 0 {
|
||||||
|
// It's nasty that we have to extract the room ID from an event, but many federation requests
|
||||||
|
// only talk in event IDs, no room IDs at all (!!!)
|
||||||
|
ev := events[0]
|
||||||
|
isServerInRoom, err = IsServerCurrentlyInRoom(ctx, db, serverName, ev.RoomID())
|
||||||
|
if err != nil {
|
||||||
|
util.GetLogger(ctx).WithError(err).Error("Failed to check if server is currently in room, assuming not.")
|
||||||
|
}
|
||||||
|
checkedServerInRoom = true
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ev := range events {
|
||||||
|
// Break out of the loop if the provided limit is reached.
|
||||||
|
if len(resultNIDs) == limit {
|
||||||
|
break BFSLoop
|
||||||
|
}
|
||||||
|
|
||||||
|
if !initialIgnoreList[ev.EventID()] {
|
||||||
|
// Update the list of events to retrieve.
|
||||||
|
resultNIDs = append(resultNIDs, ev.EventNID)
|
||||||
|
}
|
||||||
|
// Loop through the event's parents.
|
||||||
|
for _, pre = range ev.PrevEventIDs() {
|
||||||
|
// Only add an event to the list of next events to process if it
|
||||||
|
// hasn't been seen before.
|
||||||
|
if !visited[pre] {
|
||||||
|
visited[pre] = true
|
||||||
|
allowed, err = CheckServerAllowedToSeeEvent(ctx, db, info, pre, serverName, isServerInRoom)
|
||||||
|
if err != nil {
|
||||||
|
util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).WithError(err).Error(
|
||||||
|
"Error checking if allowed to see event",
|
||||||
|
)
|
||||||
|
return resultNIDs, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the event hasn't been seen before and the HS
|
||||||
|
// requesting to retrieve it is allowed to do so, add it to
|
||||||
|
// the list of events to retrieve.
|
||||||
|
if allowed {
|
||||||
|
next = append(next, pre)
|
||||||
|
} else {
|
||||||
|
util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).Info("Not allowed to see event")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Repeat the same process with the parent events we just processed.
|
||||||
|
front = next
|
||||||
|
}
|
||||||
|
|
||||||
|
return resultNIDs, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func QueryLatestEventsAndState(
|
||||||
|
ctx context.Context, db storage.Database,
|
||||||
|
request *api.QueryLatestEventsAndStateRequest,
|
||||||
|
response *api.QueryLatestEventsAndStateResponse,
|
||||||
|
) error {
|
||||||
|
roomInfo, err := db.RoomInfo(ctx, request.RoomID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if roomInfo == nil || roomInfo.IsStub {
|
||||||
|
response.RoomExists = false
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
roomState := state.NewStateResolution(db, *roomInfo)
|
||||||
|
response.RoomExists = true
|
||||||
|
response.RoomVersion = roomInfo.RoomVersion
|
||||||
|
|
||||||
|
var currentStateSnapshotNID types.StateSnapshotNID
|
||||||
|
response.LatestEvents, currentStateSnapshotNID, response.Depth, err =
|
||||||
|
db.LatestEventIDs(ctx, roomInfo.RoomNID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var stateEntries []types.StateEntry
|
||||||
|
if len(request.StateToFetch) == 0 {
|
||||||
|
// Look up all room state.
|
||||||
|
stateEntries, err = roomState.LoadStateAtSnapshot(
|
||||||
|
ctx, currentStateSnapshotNID,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
// Look up the current state for the requested tuples.
|
||||||
|
stateEntries, err = roomState.LoadStateAtSnapshotForStringTuples(
|
||||||
|
ctx, currentStateSnapshotNID, request.StateToFetch,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
stateEvents, err := LoadStateEvents(ctx, db, stateEntries)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, event := range stateEvents {
|
||||||
|
response.StateEvents = append(response.StateEvents, event.Headered(roomInfo.RoomVersion))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -1,89 +0,0 @@
|
||||||
// Copyright 2017 Vector Creations Ltd
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
// Package input contains the code processes new room events
|
|
||||||
package internal
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/Shopify/sarama"
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
fsAPI "github.com/matrix-org/dendrite/federationsender/api"
|
|
||||||
)
|
|
||||||
|
|
||||||
// SetFederationSenderInputAPI passes in a federation sender input API reference
|
|
||||||
// so that we can avoid the chicken-and-egg problem of both the roomserver input API
|
|
||||||
// and the federation sender input API being interdependent.
|
|
||||||
func (r *RoomserverInternalAPI) SetFederationSenderAPI(fsAPI fsAPI.FederationSenderInternalAPI) {
|
|
||||||
r.fsAPI = fsAPI
|
|
||||||
}
|
|
||||||
|
|
||||||
// WriteOutputEvents implements OutputRoomEventWriter
|
|
||||||
func (r *RoomserverInternalAPI) WriteOutputEvents(roomID string, updates []api.OutputEvent) error {
|
|
||||||
messages := make([]*sarama.ProducerMessage, len(updates))
|
|
||||||
for i := range updates {
|
|
||||||
value, err := json.Marshal(updates[i])
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
logger := log.WithFields(log.Fields{
|
|
||||||
"room_id": roomID,
|
|
||||||
"type": updates[i].Type,
|
|
||||||
})
|
|
||||||
if updates[i].NewRoomEvent != nil {
|
|
||||||
logger = logger.WithFields(log.Fields{
|
|
||||||
"event_type": updates[i].NewRoomEvent.Event.Type(),
|
|
||||||
"event_id": updates[i].NewRoomEvent.Event.EventID(),
|
|
||||||
"adds_state": len(updates[i].NewRoomEvent.AddsStateEventIDs),
|
|
||||||
"removes_state": len(updates[i].NewRoomEvent.RemovesStateEventIDs),
|
|
||||||
"send_as_server": updates[i].NewRoomEvent.SendAsServer,
|
|
||||||
"sender": updates[i].NewRoomEvent.Event.Sender(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
logger.Infof("Producing to topic '%s'", r.OutputRoomEventTopic)
|
|
||||||
messages[i] = &sarama.ProducerMessage{
|
|
||||||
Topic: r.OutputRoomEventTopic,
|
|
||||||
Key: sarama.StringEncoder(roomID),
|
|
||||||
Value: sarama.ByteEncoder(value),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return r.Producer.SendMessages(messages)
|
|
||||||
}
|
|
||||||
|
|
||||||
// InputRoomEvents implements api.RoomserverInternalAPI
|
|
||||||
func (r *RoomserverInternalAPI) InputRoomEvents(
|
|
||||||
ctx context.Context,
|
|
||||||
request *api.InputRoomEventsRequest,
|
|
||||||
response *api.InputRoomEventsResponse,
|
|
||||||
) (err error) {
|
|
||||||
for i, e := range request.InputRoomEvents {
|
|
||||||
roomID := "global"
|
|
||||||
if r.DB.SupportsConcurrentRoomInputs() {
|
|
||||||
roomID = e.Event.RoomID()
|
|
||||||
}
|
|
||||||
mutex, _ := r.mutexes.LoadOrStore(roomID, &sync.Mutex{})
|
|
||||||
mutex.(*sync.Mutex).Lock()
|
|
||||||
if response.EventID, err = r.processRoomEvent(ctx, request.InputRoomEvents[i]); err != nil {
|
|
||||||
mutex.(*sync.Mutex).Unlock()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
mutex.(*sync.Mutex).Unlock()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
157
roomserver/internal/input/input.go
Normal file
157
roomserver/internal/input/input.go
Normal file
|
@ -0,0 +1,157 @@
|
||||||
|
// Copyright 2017 Vector Creations Ltd
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
// Package input contains the code processes new room events
|
||||||
|
package input
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Shopify/sarama"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/storage"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"go.uber.org/atomic"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Inputer struct {
|
||||||
|
DB storage.Database
|
||||||
|
Producer sarama.SyncProducer
|
||||||
|
ServerName gomatrixserverlib.ServerName
|
||||||
|
OutputRoomEventTopic string
|
||||||
|
|
||||||
|
workers sync.Map // room ID -> *inputWorker
|
||||||
|
}
|
||||||
|
|
||||||
|
type inputTask struct {
|
||||||
|
ctx context.Context
|
||||||
|
event *api.InputRoomEvent
|
||||||
|
wg *sync.WaitGroup
|
||||||
|
err error // written back by worker, only safe to read when all tasks are done
|
||||||
|
}
|
||||||
|
|
||||||
|
type inputWorker struct {
|
||||||
|
r *Inputer
|
||||||
|
running atomic.Bool
|
||||||
|
input chan *inputTask
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *inputWorker) start() {
|
||||||
|
if !w.running.CAS(false, true) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer w.running.Store(false)
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case task := <-w.input:
|
||||||
|
_, task.err = w.r.processRoomEvent(task.ctx, task.event)
|
||||||
|
task.wg.Done()
|
||||||
|
case <-time.After(time.Second * 5):
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteOutputEvents implements OutputRoomEventWriter
|
||||||
|
func (r *Inputer) WriteOutputEvents(roomID string, updates []api.OutputEvent) error {
|
||||||
|
messages := make([]*sarama.ProducerMessage, len(updates))
|
||||||
|
for i := range updates {
|
||||||
|
value, err := json.Marshal(updates[i])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
logger := log.WithFields(log.Fields{
|
||||||
|
"room_id": roomID,
|
||||||
|
"type": updates[i].Type,
|
||||||
|
})
|
||||||
|
if updates[i].NewRoomEvent != nil {
|
||||||
|
logger = logger.WithFields(log.Fields{
|
||||||
|
"event_type": updates[i].NewRoomEvent.Event.Type(),
|
||||||
|
"event_id": updates[i].NewRoomEvent.Event.EventID(),
|
||||||
|
"adds_state": len(updates[i].NewRoomEvent.AddsStateEventIDs),
|
||||||
|
"removes_state": len(updates[i].NewRoomEvent.RemovesStateEventIDs),
|
||||||
|
"send_as_server": updates[i].NewRoomEvent.SendAsServer,
|
||||||
|
"sender": updates[i].NewRoomEvent.Event.Sender(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
logger.Infof("Producing to topic '%s'", r.OutputRoomEventTopic)
|
||||||
|
messages[i] = &sarama.ProducerMessage{
|
||||||
|
Topic: r.OutputRoomEventTopic,
|
||||||
|
Key: sarama.StringEncoder(roomID),
|
||||||
|
Value: sarama.ByteEncoder(value),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return r.Producer.SendMessages(messages)
|
||||||
|
}
|
||||||
|
|
||||||
|
// InputRoomEvents implements api.RoomserverInternalAPI
|
||||||
|
func (r *Inputer) InputRoomEvents(
|
||||||
|
ctx context.Context,
|
||||||
|
request *api.InputRoomEventsRequest,
|
||||||
|
response *api.InputRoomEventsResponse,
|
||||||
|
) error {
|
||||||
|
// Create a wait group. Each task that we dispatch will call Done on
|
||||||
|
// this wait group so that we know when all of our events have been
|
||||||
|
// processed.
|
||||||
|
wg := &sync.WaitGroup{}
|
||||||
|
wg.Add(len(request.InputRoomEvents))
|
||||||
|
tasks := make([]*inputTask, len(request.InputRoomEvents))
|
||||||
|
|
||||||
|
for i, e := range request.InputRoomEvents {
|
||||||
|
// Work out if we are running per-room workers or if we're just doing
|
||||||
|
// it on a global basis (e.g. SQLite).
|
||||||
|
roomID := "global"
|
||||||
|
if r.DB.SupportsConcurrentRoomInputs() {
|
||||||
|
roomID = e.Event.RoomID()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Look up the worker, or create it if it doesn't exist. This channel
|
||||||
|
// is buffered to reduce the chance that we'll be blocked by another
|
||||||
|
// room - the channel will be quite small as it's just pointer types.
|
||||||
|
w, _ := r.workers.LoadOrStore(roomID, &inputWorker{
|
||||||
|
r: r,
|
||||||
|
input: make(chan *inputTask, 10),
|
||||||
|
})
|
||||||
|
worker := w.(*inputWorker)
|
||||||
|
|
||||||
|
// Create a task. This contains the input event and a reference to
|
||||||
|
// the wait group, so that the worker can notify us when this specific
|
||||||
|
// task has been finished.
|
||||||
|
tasks[i] = &inputTask{
|
||||||
|
ctx: ctx,
|
||||||
|
event: &request.InputRoomEvents[i],
|
||||||
|
wg: wg,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send the task to the worker.
|
||||||
|
go worker.start()
|
||||||
|
worker.input <- tasks[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for all of the workers to return results about our tasks.
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
// If any of the tasks returned an error, we should probably report
|
||||||
|
// that back to the caller.
|
||||||
|
for _, task := range tasks {
|
||||||
|
if task.err != nil {
|
||||||
|
return task.err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -14,7 +14,7 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
package internal
|
package input
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
@ -22,6 +22,7 @@ import (
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/eventutil"
|
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/internal/helpers"
|
||||||
"github.com/matrix-org/dendrite/roomserver/state"
|
"github.com/matrix-org/dendrite/roomserver/state"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
@ -35,9 +36,9 @@ import (
|
||||||
// state deltas when sending to kafka streams
|
// state deltas when sending to kafka streams
|
||||||
// TODO: Break up function - we should probably do transaction ID checks before calling this.
|
// TODO: Break up function - we should probably do transaction ID checks before calling this.
|
||||||
// nolint:gocyclo
|
// nolint:gocyclo
|
||||||
func (r *RoomserverInternalAPI) processRoomEvent(
|
func (r *Inputer) processRoomEvent(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
input api.InputRoomEvent,
|
input *api.InputRoomEvent,
|
||||||
) (eventID string, err error) {
|
) (eventID string, err error) {
|
||||||
// Parse and validate the event JSON
|
// Parse and validate the event JSON
|
||||||
headered := input.Event
|
headered := input.Event
|
||||||
|
@ -45,7 +46,7 @@ func (r *RoomserverInternalAPI) processRoomEvent(
|
||||||
|
|
||||||
// Check that the event passes authentication checks and work out
|
// Check that the event passes authentication checks and work out
|
||||||
// the numeric IDs for the auth events.
|
// the numeric IDs for the auth events.
|
||||||
authEventNIDs, err := checkAuthEvents(ctx, r.DB, headered, input.AuthEventIDs)
|
authEventNIDs, err := helpers.CheckAuthEvents(ctx, r.DB, headered, input.AuthEventIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", input.AuthEventIDs).Error("processRoomEvent.checkAuthEvents failed for event")
|
logrus.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", input.AuthEventIDs).Error("processRoomEvent.checkAuthEvents failed for event")
|
||||||
return
|
return
|
||||||
|
@ -64,7 +65,7 @@ func (r *RoomserverInternalAPI) processRoomEvent(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Store the event.
|
// Store the event.
|
||||||
roomNID, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, input.TransactionID, authEventNIDs)
|
_, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, input.TransactionID, authEventNIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("r.DB.StoreEvent: %w", err)
|
return "", fmt.Errorf("r.DB.StoreEvent: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -89,10 +90,18 @@ func (r *RoomserverInternalAPI) processRoomEvent(
|
||||||
return event.EventID(), nil
|
return event.EventID(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
roomInfo, err := r.DB.RoomInfo(ctx, event.RoomID())
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("r.DB.RoomInfo: %w", err)
|
||||||
|
}
|
||||||
|
if roomInfo == nil {
|
||||||
|
return "", fmt.Errorf("r.DB.RoomInfo missing for room %s", event.RoomID())
|
||||||
|
}
|
||||||
|
|
||||||
if stateAtEvent.BeforeStateSnapshotNID == 0 {
|
if stateAtEvent.BeforeStateSnapshotNID == 0 {
|
||||||
// We haven't calculated a state for this event yet.
|
// We haven't calculated a state for this event yet.
|
||||||
// Lets calculate one.
|
// Lets calculate one.
|
||||||
err = r.calculateAndSetState(ctx, input, roomNID, &stateAtEvent, event)
|
err = r.calculateAndSetState(ctx, input, *roomInfo, &stateAtEvent, event)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("r.calculateAndSetState: %w", err)
|
return "", fmt.Errorf("r.calculateAndSetState: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -100,7 +109,7 @@ func (r *RoomserverInternalAPI) processRoomEvent(
|
||||||
|
|
||||||
if err = r.updateLatestEvents(
|
if err = r.updateLatestEvents(
|
||||||
ctx, // context
|
ctx, // context
|
||||||
roomNID, // room NID to update
|
roomInfo, // room info for the room being updated
|
||||||
stateAtEvent, // state at event (below)
|
stateAtEvent, // state at event (below)
|
||||||
event, // event
|
event, // event
|
||||||
input.SendAsServer, // send as server
|
input.SendAsServer, // send as server
|
||||||
|
@ -132,22 +141,22 @@ func (r *RoomserverInternalAPI) processRoomEvent(
|
||||||
return event.EventID(), nil
|
return event.EventID(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RoomserverInternalAPI) calculateAndSetState(
|
func (r *Inputer) calculateAndSetState(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
input api.InputRoomEvent,
|
input *api.InputRoomEvent,
|
||||||
roomNID types.RoomNID,
|
roomInfo types.RoomInfo,
|
||||||
stateAtEvent *types.StateAtEvent,
|
stateAtEvent *types.StateAtEvent,
|
||||||
event gomatrixserverlib.Event,
|
event gomatrixserverlib.Event,
|
||||||
) error {
|
) error {
|
||||||
var err error
|
var err error
|
||||||
roomState := state.NewStateResolution(r.DB)
|
roomState := state.NewStateResolution(r.DB, roomInfo)
|
||||||
|
|
||||||
if input.HasState {
|
if input.HasState {
|
||||||
// Check here if we think we're in the room already.
|
// Check here if we think we're in the room already.
|
||||||
stateAtEvent.Overwrite = true
|
stateAtEvent.Overwrite = true
|
||||||
var joinEventNIDs []types.EventNID
|
var joinEventNIDs []types.EventNID
|
||||||
// Request join memberships only for local users only.
|
// Request join memberships only for local users only.
|
||||||
if joinEventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, true, true); err == nil {
|
if joinEventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true); err == nil {
|
||||||
// If we have no local users that are joined to the room then any state about
|
// If we have no local users that are joined to the room then any state about
|
||||||
// the room that we have is quite possibly out of date. Therefore in that case
|
// the room that we have is quite possibly out of date. Therefore in that case
|
||||||
// we should overwrite it rather than merge it.
|
// we should overwrite it rather than merge it.
|
||||||
|
@ -161,14 +170,14 @@ func (r *RoomserverInternalAPI) calculateAndSetState(
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if stateAtEvent.BeforeStateSnapshotNID, err = r.DB.AddState(ctx, roomNID, nil, entries); err != nil {
|
if stateAtEvent.BeforeStateSnapshotNID, err = r.DB.AddState(ctx, roomInfo.RoomNID, nil, entries); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
stateAtEvent.Overwrite = false
|
stateAtEvent.Overwrite = false
|
||||||
|
|
||||||
// We haven't been told what the state at the event is so we need to calculate it from the prev_events
|
// We haven't been told what the state at the event is so we need to calculate it from the prev_events
|
||||||
if stateAtEvent.BeforeStateSnapshotNID, err = roomState.CalculateAndStoreStateBeforeEvent(ctx, event, roomNID); err != nil {
|
if stateAtEvent.BeforeStateSnapshotNID, err = roomState.CalculateAndStoreStateBeforeEvent(ctx, event); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -14,7 +14,7 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
package internal
|
package input
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
@ -47,15 +47,15 @@ import (
|
||||||
// 7 <----- latest
|
// 7 <----- latest
|
||||||
//
|
//
|
||||||
// Can only be called once at a time
|
// Can only be called once at a time
|
||||||
func (r *RoomserverInternalAPI) updateLatestEvents(
|
func (r *Inputer) updateLatestEvents(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
roomNID types.RoomNID,
|
roomInfo *types.RoomInfo,
|
||||||
stateAtEvent types.StateAtEvent,
|
stateAtEvent types.StateAtEvent,
|
||||||
event gomatrixserverlib.Event,
|
event gomatrixserverlib.Event,
|
||||||
sendAsServer string,
|
sendAsServer string,
|
||||||
transactionID *api.TransactionID,
|
transactionID *api.TransactionID,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
updater, err := r.DB.GetLatestEventsForUpdate(ctx, roomNID)
|
updater, err := r.DB.GetLatestEventsForUpdate(ctx, *roomInfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("r.DB.GetLatestEventsForUpdate: %w", err)
|
return fmt.Errorf("r.DB.GetLatestEventsForUpdate: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -66,7 +66,7 @@ func (r *RoomserverInternalAPI) updateLatestEvents(
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
api: r,
|
api: r,
|
||||||
updater: updater,
|
updater: updater,
|
||||||
roomNID: roomNID,
|
roomInfo: roomInfo,
|
||||||
stateAtEvent: stateAtEvent,
|
stateAtEvent: stateAtEvent,
|
||||||
event: event,
|
event: event,
|
||||||
sendAsServer: sendAsServer,
|
sendAsServer: sendAsServer,
|
||||||
|
@ -87,9 +87,9 @@ func (r *RoomserverInternalAPI) updateLatestEvents(
|
||||||
// when there are so many variables to pass around.
|
// when there are so many variables to pass around.
|
||||||
type latestEventsUpdater struct {
|
type latestEventsUpdater struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
api *RoomserverInternalAPI
|
api *Inputer
|
||||||
updater *shared.LatestEventsUpdater
|
updater *shared.LatestEventsUpdater
|
||||||
roomNID types.RoomNID
|
roomInfo *types.RoomInfo
|
||||||
stateAtEvent types.StateAtEvent
|
stateAtEvent types.StateAtEvent
|
||||||
event gomatrixserverlib.Event
|
event gomatrixserverlib.Event
|
||||||
transactionID *api.TransactionID
|
transactionID *api.TransactionID
|
||||||
|
@ -196,7 +196,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
|
||||||
return fmt.Errorf("u.api.WriteOutputEvents: %w", err)
|
return fmt.Errorf("u.api.WriteOutputEvents: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = u.updater.SetLatestEvents(u.roomNID, u.latest, u.stateAtEvent.EventNID, u.newStateNID); err != nil {
|
if err = u.updater.SetLatestEvents(u.roomInfo.RoomNID, u.latest, u.stateAtEvent.EventNID, u.newStateNID); err != nil {
|
||||||
return fmt.Errorf("u.updater.SetLatestEvents: %w", err)
|
return fmt.Errorf("u.updater.SetLatestEvents: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -209,7 +209,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
|
||||||
|
|
||||||
func (u *latestEventsUpdater) latestState() error {
|
func (u *latestEventsUpdater) latestState() error {
|
||||||
var err error
|
var err error
|
||||||
roomState := state.NewStateResolution(u.api.DB)
|
roomState := state.NewStateResolution(u.api.DB, *u.roomInfo)
|
||||||
|
|
||||||
// Get a list of the current latest events.
|
// Get a list of the current latest events.
|
||||||
latestStateAtEvents := make([]types.StateAtEvent, len(u.latest))
|
latestStateAtEvents := make([]types.StateAtEvent, len(u.latest))
|
||||||
|
@ -221,7 +221,7 @@ func (u *latestEventsUpdater) latestState() error {
|
||||||
// of the state after the events. The snapshot state will be resolved
|
// of the state after the events. The snapshot state will be resolved
|
||||||
// using the correct state resolution algorithm for the room.
|
// using the correct state resolution algorithm for the room.
|
||||||
u.newStateNID, err = roomState.CalculateAndStoreStateAfterEvents(
|
u.newStateNID, err = roomState.CalculateAndStoreStateAfterEvents(
|
||||||
u.ctx, u.roomNID, latestStateAtEvents,
|
u.ctx, latestStateAtEvents,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("roomState.CalculateAndStoreStateAfterEvents: %w", err)
|
return fmt.Errorf("roomState.CalculateAndStoreStateAfterEvents: %w", err)
|
||||||
|
@ -303,13 +303,8 @@ func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error)
|
||||||
latestEventIDs[i] = u.latest[i].EventID
|
latestEventIDs[i] = u.latest[i].EventID
|
||||||
}
|
}
|
||||||
|
|
||||||
roomVersion, err := u.api.DB.GetRoomVersionForRoom(u.ctx, u.event.RoomID())
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
ore := api.OutputNewRoomEvent{
|
ore := api.OutputNewRoomEvent{
|
||||||
Event: u.event.Headered(roomVersion),
|
Event: u.event.Headered(u.roomInfo.RoomVersion),
|
||||||
LastSentEventID: u.lastEventIDSent,
|
LastSentEventID: u.lastEventIDSent,
|
||||||
LatestEventIDs: latestEventIDs,
|
LatestEventIDs: latestEventIDs,
|
||||||
TransactionID: u.transactionID,
|
TransactionID: u.transactionID,
|
||||||
|
@ -337,7 +332,7 @@ func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error)
|
||||||
// include extra state events if they were added as nearly every downstream component will care about it
|
// include extra state events if they were added as nearly every downstream component will care about it
|
||||||
// and we'd rather not have them all hit QueryEventsByID at the same time!
|
// and we'd rather not have them all hit QueryEventsByID at the same time!
|
||||||
if len(ore.AddsStateEventIDs) > 0 {
|
if len(ore.AddsStateEventIDs) > 0 {
|
||||||
ore.AddStateEvents, err = u.extraEventsForIDs(roomVersion, ore.AddsStateEventIDs)
|
ore.AddStateEvents, err = u.extraEventsForIDs(u.roomInfo.RoomVersion, ore.AddsStateEventIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to load add_state_events from db: %w", err)
|
return nil, fmt.Errorf("failed to load add_state_events from db: %w", err)
|
||||||
}
|
}
|
|
@ -12,13 +12,14 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
package internal
|
package input
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/internal/helpers"
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
@ -28,7 +29,7 @@ import (
|
||||||
// user affected by a change in the current state of the room.
|
// user affected by a change in the current state of the room.
|
||||||
// Returns a list of output events to write to the kafka log to inform the
|
// Returns a list of output events to write to the kafka log to inform the
|
||||||
// consumers about the invites added or retired by the change in current state.
|
// consumers about the invites added or retired by the change in current state.
|
||||||
func (r *RoomserverInternalAPI) updateMemberships(
|
func (r *Inputer) updateMemberships(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
updater *shared.LatestEventsUpdater,
|
updater *shared.LatestEventsUpdater,
|
||||||
removed, added []types.StateEntry,
|
removed, added []types.StateEntry,
|
||||||
|
@ -59,13 +60,13 @@ func (r *RoomserverInternalAPI) updateMemberships(
|
||||||
var re *gomatrixserverlib.Event
|
var re *gomatrixserverlib.Event
|
||||||
targetUserNID := change.EventStateKeyNID
|
targetUserNID := change.EventStateKeyNID
|
||||||
if change.removedEventNID != 0 {
|
if change.removedEventNID != 0 {
|
||||||
ev, _ := eventMap(events).lookup(change.removedEventNID)
|
ev, _ := helpers.EventMap(events).Lookup(change.removedEventNID)
|
||||||
if ev != nil {
|
if ev != nil {
|
||||||
re = &ev.Event
|
re = &ev.Event
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if change.addedEventNID != 0 {
|
if change.addedEventNID != 0 {
|
||||||
ev, _ := eventMap(events).lookup(change.addedEventNID)
|
ev, _ := helpers.EventMap(events).Lookup(change.addedEventNID)
|
||||||
if ev != nil {
|
if ev != nil {
|
||||||
ae = &ev.Event
|
ae = &ev.Event
|
||||||
}
|
}
|
||||||
|
@ -77,7 +78,7 @@ func (r *RoomserverInternalAPI) updateMemberships(
|
||||||
return updates, nil
|
return updates, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RoomserverInternalAPI) updateMembership(
|
func (r *Inputer) updateMembership(
|
||||||
updater *shared.LatestEventsUpdater,
|
updater *shared.LatestEventsUpdater,
|
||||||
targetUserNID types.EventStateKeyNID,
|
targetUserNID types.EventStateKeyNID,
|
||||||
remove, add *gomatrixserverlib.Event,
|
remove, add *gomatrixserverlib.Event,
|
||||||
|
@ -120,7 +121,7 @@ func (r *RoomserverInternalAPI) updateMembership(
|
||||||
|
|
||||||
switch newMembership {
|
switch newMembership {
|
||||||
case gomatrixserverlib.Invite:
|
case gomatrixserverlib.Invite:
|
||||||
return updateToInviteMembership(mu, add, updates, updater.RoomVersion())
|
return helpers.UpdateToInviteMembership(mu, add, updates, updater.RoomVersion())
|
||||||
case gomatrixserverlib.Join:
|
case gomatrixserverlib.Join:
|
||||||
return updateToJoinMembership(mu, add, updates)
|
return updateToJoinMembership(mu, add, updates)
|
||||||
case gomatrixserverlib.Leave, gomatrixserverlib.Ban:
|
case gomatrixserverlib.Leave, gomatrixserverlib.Ban:
|
||||||
|
@ -132,45 +133,15 @@ func (r *RoomserverInternalAPI) updateMembership(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RoomserverInternalAPI) isLocalTarget(event *gomatrixserverlib.Event) bool {
|
func (r *Inputer) isLocalTarget(event *gomatrixserverlib.Event) bool {
|
||||||
isTargetLocalUser := false
|
isTargetLocalUser := false
|
||||||
if statekey := event.StateKey(); statekey != nil {
|
if statekey := event.StateKey(); statekey != nil {
|
||||||
_, domain, _ := gomatrixserverlib.SplitID('@', *statekey)
|
_, domain, _ := gomatrixserverlib.SplitID('@', *statekey)
|
||||||
isTargetLocalUser = domain == r.Cfg.Matrix.ServerName
|
isTargetLocalUser = domain == r.ServerName
|
||||||
}
|
}
|
||||||
return isTargetLocalUser
|
return isTargetLocalUser
|
||||||
}
|
}
|
||||||
|
|
||||||
func updateToInviteMembership(
|
|
||||||
mu *shared.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent,
|
|
||||||
roomVersion gomatrixserverlib.RoomVersion,
|
|
||||||
) ([]api.OutputEvent, error) {
|
|
||||||
// We may have already sent the invite to the user, either because we are
|
|
||||||
// reprocessing this event, or because the we received this invite from a
|
|
||||||
// remote server via the federation invite API. In those cases we don't need
|
|
||||||
// to send the event.
|
|
||||||
needsSending, err := mu.SetToInvite(*add)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if needsSending {
|
|
||||||
// We notify the consumers using a special event even though we will
|
|
||||||
// notify them about the change in current state as part of the normal
|
|
||||||
// room event stream. This ensures that the consumers only have to
|
|
||||||
// consider a single stream of events when determining whether a user
|
|
||||||
// is invited, rather than having to combine multiple streams themselves.
|
|
||||||
onie := api.OutputNewInviteEvent{
|
|
||||||
Event: add.Headered(roomVersion),
|
|
||||||
RoomVersion: roomVersion,
|
|
||||||
}
|
|
||||||
updates = append(updates, api.OutputEvent{
|
|
||||||
Type: api.OutputTypeNewInviteEvent,
|
|
||||||
NewInviteEvent: &onie,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return updates, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func updateToJoinMembership(
|
func updateToJoinMembership(
|
||||||
mu *shared.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent,
|
mu *shared.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent,
|
||||||
) ([]api.OutputEvent, error) {
|
) ([]api.OutputEvent, error) {
|
562
roomserver/internal/perform/perform_backfill.go
Normal file
562
roomserver/internal/perform/perform_backfill.go
Normal file
|
@ -0,0 +1,562 @@
|
||||||
|
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package perform
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api"
|
||||||
|
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/auth"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/internal/helpers"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/storage"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/matrix-org/util"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Backfiller struct {
|
||||||
|
ServerName gomatrixserverlib.ServerName
|
||||||
|
DB storage.Database
|
||||||
|
FSAPI federationSenderAPI.FederationSenderInternalAPI
|
||||||
|
KeyRing gomatrixserverlib.JSONVerifier
|
||||||
|
}
|
||||||
|
|
||||||
|
// PerformBackfill implements api.RoomServerQueryAPI
|
||||||
|
func (r *Backfiller) PerformBackfill(
|
||||||
|
ctx context.Context,
|
||||||
|
request *api.PerformBackfillRequest,
|
||||||
|
response *api.PerformBackfillResponse,
|
||||||
|
) error {
|
||||||
|
// if we are requesting the backfill then we need to do a federation hit
|
||||||
|
// TODO: we could be more sensible and fetch as many events we already have then request the rest
|
||||||
|
// which is what the syncapi does already.
|
||||||
|
if request.ServerName == r.ServerName {
|
||||||
|
return r.backfillViaFederation(ctx, request, response)
|
||||||
|
}
|
||||||
|
// someone else is requesting the backfill, try to service their request.
|
||||||
|
var err error
|
||||||
|
var front []string
|
||||||
|
|
||||||
|
// The limit defines the maximum number of events to retrieve, so it also
|
||||||
|
// defines the highest number of elements in the map below.
|
||||||
|
visited := make(map[string]bool, request.Limit)
|
||||||
|
|
||||||
|
// this will include these events which is what we want
|
||||||
|
front = request.PrevEventIDs()
|
||||||
|
|
||||||
|
info, err := r.DB.RoomInfo(ctx, request.RoomID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if info == nil || info.IsStub {
|
||||||
|
return fmt.Errorf("PerformBackfill: missing room info for room %s", request.RoomID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan the event tree for events to send back.
|
||||||
|
resultNIDs, err := helpers.ScanEventTree(ctx, r.DB, *info, front, visited, request.Limit, request.ServerName)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Retrieve events from the list that was filled previously.
|
||||||
|
var loadedEvents []gomatrixserverlib.Event
|
||||||
|
loadedEvents, err = helpers.LoadEvents(ctx, r.DB, resultNIDs)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, event := range loadedEvents {
|
||||||
|
response.Events = append(response.Events, event.Headered(info.RoomVersion))
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.PerformBackfillRequest, res *api.PerformBackfillResponse) error {
|
||||||
|
info, err := r.DB.RoomInfo(ctx, req.RoomID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if info == nil || info.IsStub {
|
||||||
|
return fmt.Errorf("backfillViaFederation: missing room info for room %s", req.RoomID)
|
||||||
|
}
|
||||||
|
requester := newBackfillRequester(r.DB, r.FSAPI, r.ServerName, req.BackwardsExtremities)
|
||||||
|
// Request 100 items regardless of what the query asks for.
|
||||||
|
// We don't want to go much higher than this.
|
||||||
|
// We can't honour exactly the limit as some sytests rely on requesting more for tests to pass
|
||||||
|
// (so we don't need to hit /state_ids which the test has no listener for)
|
||||||
|
// Specifically the test "Outbound federation can backfill events"
|
||||||
|
events, err := gomatrixserverlib.RequestBackfill(
|
||||||
|
ctx, requester,
|
||||||
|
r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
logrus.WithField("room_id", req.RoomID).Infof("backfilled %d events", len(events))
|
||||||
|
|
||||||
|
// persist these new events - auth checks have already been done
|
||||||
|
roomNID, backfilledEventMap := persistEvents(ctx, r.DB, events)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ev := range backfilledEventMap {
|
||||||
|
// now add state for these events
|
||||||
|
stateIDs, ok := requester.eventIDToBeforeStateIDs[ev.EventID()]
|
||||||
|
if !ok {
|
||||||
|
// this should be impossible as all events returned must have pass Step 5 of the PDU checks
|
||||||
|
// which requires a list of state IDs.
|
||||||
|
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to find state IDs for event which passed auth checks")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var entries []types.StateEntry
|
||||||
|
if entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs); err != nil {
|
||||||
|
// attempt to fetch the missing events
|
||||||
|
r.fetchAndStoreMissingEvents(ctx, info.RoomVersion, requester, stateIDs)
|
||||||
|
// try again
|
||||||
|
entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs)
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to get state entries for event")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var beforeStateSnapshotNID types.StateSnapshotNID
|
||||||
|
if beforeStateSnapshotNID, err = r.DB.AddState(ctx, roomNID, nil, entries); err != nil {
|
||||||
|
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to persist state entries to get snapshot nid")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err = r.DB.SetState(ctx, ev.EventNID, beforeStateSnapshotNID); err != nil {
|
||||||
|
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to persist snapshot nid")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: update backwards extremities, as that should be moved from syncapi to roomserver at some point.
|
||||||
|
|
||||||
|
res.Events = events
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// fetchAndStoreMissingEvents does a best-effort fetch and store of missing events specified in stateIDs. Returns no error as it is just
|
||||||
|
// best effort.
|
||||||
|
func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gomatrixserverlib.RoomVersion,
|
||||||
|
backfillRequester *backfillRequester, stateIDs []string) {
|
||||||
|
|
||||||
|
servers := backfillRequester.servers
|
||||||
|
|
||||||
|
// work out which are missing
|
||||||
|
nidMap, err := r.DB.EventNIDs(ctx, stateIDs)
|
||||||
|
if err != nil {
|
||||||
|
util.GetLogger(ctx).WithError(err).Warn("cannot query missing events")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
missingMap := make(map[string]*gomatrixserverlib.HeaderedEvent) // id -> event
|
||||||
|
for _, id := range stateIDs {
|
||||||
|
if _, ok := nidMap[id]; !ok {
|
||||||
|
missingMap[id] = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
util.GetLogger(ctx).Infof("Fetching %d missing state events (from %d possible servers)", len(missingMap), len(servers))
|
||||||
|
|
||||||
|
// fetch the events from federation. Loop the servers first so if we find one that works we stick with them
|
||||||
|
for _, srv := range servers {
|
||||||
|
for id, ev := range missingMap {
|
||||||
|
if ev != nil {
|
||||||
|
continue // already found
|
||||||
|
}
|
||||||
|
logger := util.GetLogger(ctx).WithField("server", srv).WithField("event_id", id)
|
||||||
|
res, err := r.FSAPI.GetEvent(ctx, srv, id)
|
||||||
|
if err != nil {
|
||||||
|
logger.WithError(err).Warn("failed to get event from server")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
loader := gomatrixserverlib.NewEventsLoader(roomVer, r.KeyRing, backfillRequester, backfillRequester.ProvideEvents, false)
|
||||||
|
result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents)
|
||||||
|
if err != nil {
|
||||||
|
logger.WithError(err).Warn("failed to load and verify event")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
logger.Infof("returned %d PDUs which made events %+v", len(res.PDUs), result)
|
||||||
|
for _, res := range result {
|
||||||
|
if res.Error != nil {
|
||||||
|
logger.WithError(err).Warn("event failed PDU checks")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
missingMap[id] = res.Event
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var newEvents []gomatrixserverlib.HeaderedEvent
|
||||||
|
for _, ev := range missingMap {
|
||||||
|
if ev != nil {
|
||||||
|
newEvents = append(newEvents, *ev)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
util.GetLogger(ctx).Infof("Persisting %d new events", len(newEvents))
|
||||||
|
persistEvents(ctx, r.DB, newEvents)
|
||||||
|
}
|
||||||
|
|
||||||
|
// backfillRequester implements gomatrixserverlib.BackfillRequester
|
||||||
|
type backfillRequester struct {
|
||||||
|
db storage.Database
|
||||||
|
fsAPI federationSenderAPI.FederationSenderInternalAPI
|
||||||
|
thisServer gomatrixserverlib.ServerName
|
||||||
|
bwExtrems map[string][]string
|
||||||
|
|
||||||
|
// per-request state
|
||||||
|
servers []gomatrixserverlib.ServerName
|
||||||
|
eventIDToBeforeStateIDs map[string][]string
|
||||||
|
eventIDMap map[string]gomatrixserverlib.Event
|
||||||
|
}
|
||||||
|
|
||||||
|
func newBackfillRequester(db storage.Database, fsAPI federationSenderAPI.FederationSenderInternalAPI, thisServer gomatrixserverlib.ServerName, bwExtrems map[string][]string) *backfillRequester {
|
||||||
|
return &backfillRequester{
|
||||||
|
db: db,
|
||||||
|
fsAPI: fsAPI,
|
||||||
|
thisServer: thisServer,
|
||||||
|
eventIDToBeforeStateIDs: make(map[string][]string),
|
||||||
|
eventIDMap: make(map[string]gomatrixserverlib.Event),
|
||||||
|
bwExtrems: bwExtrems,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *backfillRequester) StateIDsBeforeEvent(ctx context.Context, targetEvent gomatrixserverlib.HeaderedEvent) ([]string, error) {
|
||||||
|
b.eventIDMap[targetEvent.EventID()] = targetEvent.Unwrap()
|
||||||
|
if ids, ok := b.eventIDToBeforeStateIDs[targetEvent.EventID()]; ok {
|
||||||
|
return ids, nil
|
||||||
|
}
|
||||||
|
if len(targetEvent.PrevEventIDs()) == 0 && targetEvent.Type() == "m.room.create" && targetEvent.StateKeyEquals("") {
|
||||||
|
util.GetLogger(ctx).WithField("room_id", targetEvent.RoomID()).Info("Backfilled to the beginning of the room")
|
||||||
|
b.eventIDToBeforeStateIDs[targetEvent.EventID()] = []string{}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
// if we have exactly 1 prev event and we know the state of the room at that prev event, then just roll forward the prev event.
|
||||||
|
// Else, we have to hit /state_ids because either we don't know the state at all at this event (new backwards extremity) or
|
||||||
|
// we don't know the result of state res to merge forks (2 or more prev_events)
|
||||||
|
if len(targetEvent.PrevEventIDs()) == 1 {
|
||||||
|
prevEventID := targetEvent.PrevEventIDs()[0]
|
||||||
|
prevEvent, ok := b.eventIDMap[prevEventID]
|
||||||
|
if !ok {
|
||||||
|
goto FederationHit
|
||||||
|
}
|
||||||
|
prevEventStateIDs, ok := b.eventIDToBeforeStateIDs[prevEventID]
|
||||||
|
if !ok {
|
||||||
|
goto FederationHit
|
||||||
|
}
|
||||||
|
newStateIDs := b.calculateNewStateIDs(targetEvent.Unwrap(), prevEvent, prevEventStateIDs)
|
||||||
|
if newStateIDs != nil {
|
||||||
|
b.eventIDToBeforeStateIDs[targetEvent.EventID()] = newStateIDs
|
||||||
|
return newStateIDs, nil
|
||||||
|
}
|
||||||
|
// else we failed to calculate the new state, so fallthrough
|
||||||
|
}
|
||||||
|
|
||||||
|
FederationHit:
|
||||||
|
var lastErr error
|
||||||
|
logrus.WithField("event_id", targetEvent.EventID()).Info("Requesting /state_ids at event")
|
||||||
|
for _, srv := range b.servers { // hit any valid server
|
||||||
|
c := gomatrixserverlib.FederatedStateProvider{
|
||||||
|
FedClient: b.fsAPI,
|
||||||
|
RememberAuthEvents: false,
|
||||||
|
Server: srv,
|
||||||
|
}
|
||||||
|
res, err := c.StateIDsBeforeEvent(ctx, targetEvent)
|
||||||
|
if err != nil {
|
||||||
|
lastErr = err
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
b.eventIDToBeforeStateIDs[targetEvent.EventID()] = res
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
return nil, lastErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *backfillRequester) calculateNewStateIDs(targetEvent, prevEvent gomatrixserverlib.Event, prevEventStateIDs []string) []string {
|
||||||
|
newStateIDs := prevEventStateIDs[:]
|
||||||
|
if prevEvent.StateKey() == nil {
|
||||||
|
// state is the same as the previous event
|
||||||
|
b.eventIDToBeforeStateIDs[targetEvent.EventID()] = newStateIDs
|
||||||
|
return newStateIDs
|
||||||
|
}
|
||||||
|
|
||||||
|
missingState := false // true if we are missing the info for a state event ID
|
||||||
|
foundEvent := false // true if we found a (type, state_key) match
|
||||||
|
// find which state ID to replace, if any
|
||||||
|
for i, id := range newStateIDs {
|
||||||
|
ev, ok := b.eventIDMap[id]
|
||||||
|
if !ok {
|
||||||
|
missingState = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// The state IDs BEFORE the target event are the state IDs BEFORE the prev_event PLUS the prev_event itself
|
||||||
|
if ev.Type() == prevEvent.Type() && ev.StateKeyEquals(*prevEvent.StateKey()) {
|
||||||
|
newStateIDs[i] = prevEvent.EventID()
|
||||||
|
foundEvent = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !foundEvent && !missingState {
|
||||||
|
// we can be certain that this is new state
|
||||||
|
newStateIDs = append(newStateIDs, prevEvent.EventID())
|
||||||
|
foundEvent = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if foundEvent {
|
||||||
|
b.eventIDToBeforeStateIDs[targetEvent.EventID()] = newStateIDs
|
||||||
|
return newStateIDs
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *backfillRequester) StateBeforeEvent(ctx context.Context, roomVer gomatrixserverlib.RoomVersion,
|
||||||
|
event gomatrixserverlib.HeaderedEvent, eventIDs []string) (map[string]*gomatrixserverlib.Event, error) {
|
||||||
|
|
||||||
|
// try to fetch the events from the database first
|
||||||
|
events, err := b.ProvideEvents(roomVer, eventIDs)
|
||||||
|
if err != nil {
|
||||||
|
// non-fatal, fallthrough
|
||||||
|
logrus.WithError(err).Info("Failed to fetch events")
|
||||||
|
} else {
|
||||||
|
logrus.Infof("Fetched %d/%d events from the database", len(events), len(eventIDs))
|
||||||
|
if len(events) == len(eventIDs) {
|
||||||
|
result := make(map[string]*gomatrixserverlib.Event)
|
||||||
|
for i := range events {
|
||||||
|
result[events[i].EventID()] = &events[i]
|
||||||
|
b.eventIDMap[events[i].EventID()] = events[i]
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c := gomatrixserverlib.FederatedStateProvider{
|
||||||
|
FedClient: b.fsAPI,
|
||||||
|
RememberAuthEvents: false,
|
||||||
|
Server: b.servers[0],
|
||||||
|
}
|
||||||
|
result, err := c.StateBeforeEvent(ctx, roomVer, event, eventIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for eventID, ev := range result {
|
||||||
|
b.eventIDMap[eventID] = *ev
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServersAtEvent is called when trying to determine which server to request from.
|
||||||
|
// It returns a list of servers which can be queried for backfill requests. These servers
|
||||||
|
// will be servers that are in the room already. The entries at the beginning are preferred servers
|
||||||
|
// and will be tried first. An empty list will fail the request.
|
||||||
|
// nolint:gocyclo
|
||||||
|
func (b *backfillRequester) ServersAtEvent(ctx context.Context, roomID, eventID string) []gomatrixserverlib.ServerName {
|
||||||
|
// eventID will be a prev_event ID of a backwards extremity, meaning we will not have a database entry for it. Instead, use
|
||||||
|
// its successor, so look it up.
|
||||||
|
successor := ""
|
||||||
|
FindSuccessor:
|
||||||
|
for sucID, prevEventIDs := range b.bwExtrems {
|
||||||
|
for _, pe := range prevEventIDs {
|
||||||
|
if pe == eventID {
|
||||||
|
successor = sucID
|
||||||
|
break FindSuccessor
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if successor == "" {
|
||||||
|
logrus.WithField("event_id", eventID).Error("ServersAtEvent: failed to find successor of this event to determine room state")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
eventID = successor
|
||||||
|
|
||||||
|
// getMembershipsBeforeEventNID requires a NID, so retrieving the NID for
|
||||||
|
// the event is necessary.
|
||||||
|
NIDs, err := b.db.EventNIDs(ctx, []string{eventID})
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to get event NID for event")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
info, err := b.db.RoomInfo(ctx, roomID)
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).WithField("room_id", roomID).Error("ServersAtEvent: failed to get RoomInfo for room")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if info == nil || info.IsStub {
|
||||||
|
logrus.WithField("room_id", roomID).Error("ServersAtEvent: failed to get RoomInfo for room, room is missing")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
stateEntries, err := helpers.StateBeforeEvent(ctx, b.db, *info, NIDs[eventID])
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to load state before event")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// possibly return all joined servers depending on history visiblity
|
||||||
|
memberEventsFromVis, err := joinEventsFromHistoryVisibility(ctx, b.db, roomID, stateEntries)
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).Error("ServersAtEvent: failed calculate servers from history visibility rules")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
logrus.Infof("ServersAtEvent including %d current events from history visibility", len(memberEventsFromVis))
|
||||||
|
|
||||||
|
// Retrieve all "m.room.member" state events of "join" membership, which
|
||||||
|
// contains the list of users in the room before the event, therefore all
|
||||||
|
// the servers in it at that moment.
|
||||||
|
memberEvents, err := helpers.GetMembershipsAtState(ctx, b.db, stateEntries, true)
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to get memberships before event")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
memberEvents = append(memberEvents, memberEventsFromVis...)
|
||||||
|
|
||||||
|
// Store the server names in a temporary map to avoid duplicates.
|
||||||
|
serverSet := make(map[gomatrixserverlib.ServerName]bool)
|
||||||
|
for _, event := range memberEvents {
|
||||||
|
serverSet[event.Origin()] = true
|
||||||
|
}
|
||||||
|
var servers []gomatrixserverlib.ServerName
|
||||||
|
for server := range serverSet {
|
||||||
|
if server == b.thisServer {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
servers = append(servers, server)
|
||||||
|
}
|
||||||
|
b.servers = servers
|
||||||
|
return servers
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backfill performs a backfill request to the given server.
|
||||||
|
// https://matrix.org/docs/spec/server_server/latest#get-matrix-federation-v1-backfill-roomid
|
||||||
|
func (b *backfillRequester) Backfill(ctx context.Context, server gomatrixserverlib.ServerName, roomID string,
|
||||||
|
limit int, fromEventIDs []string) (gomatrixserverlib.Transaction, error) {
|
||||||
|
|
||||||
|
tx, err := b.fsAPI.Backfill(ctx, server, roomID, limit, fromEventIDs)
|
||||||
|
return tx, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion, eventIDs []string) ([]gomatrixserverlib.Event, error) {
|
||||||
|
ctx := context.Background()
|
||||||
|
nidMap, err := b.db.EventNIDs(ctx, eventIDs)
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).WithField("event_ids", eventIDs).Error("Failed to find events")
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
eventNIDs := make([]types.EventNID, len(nidMap))
|
||||||
|
i := 0
|
||||||
|
for _, nid := range nidMap {
|
||||||
|
eventNIDs[i] = nid
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
eventsWithNids, err := b.db.Events(ctx, eventNIDs)
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).WithField("event_nids", eventNIDs).Error("Failed to load events")
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
events := make([]gomatrixserverlib.Event, len(eventsWithNids))
|
||||||
|
for i := range eventsWithNids {
|
||||||
|
events[i] = eventsWithNids[i].Event
|
||||||
|
}
|
||||||
|
return events, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// joinEventsFromHistoryVisibility returns all CURRENTLY joined members if the provided state indicated a 'shared' history visibility.
|
||||||
|
// TODO: Long term we probably want a history_visibility table which stores eventNID | visibility_enum so we can just
|
||||||
|
// pull all events and then filter by that table.
|
||||||
|
func joinEventsFromHistoryVisibility(
|
||||||
|
ctx context.Context, db storage.Database, roomID string, stateEntries []types.StateEntry) ([]types.Event, error) {
|
||||||
|
|
||||||
|
var eventNIDs []types.EventNID
|
||||||
|
for _, entry := range stateEntries {
|
||||||
|
// Filter the events to retrieve to only keep the membership events
|
||||||
|
if entry.EventTypeNID == types.MRoomHistoryVisibilityNID && entry.EventStateKeyNID == types.EmptyStateKeyNID {
|
||||||
|
eventNIDs = append(eventNIDs, entry.EventNID)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get all of the events in this state
|
||||||
|
stateEvents, err := db.Events(ctx, eventNIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
events := make([]gomatrixserverlib.Event, len(stateEvents))
|
||||||
|
for i := range stateEvents {
|
||||||
|
events[i] = stateEvents[i].Event
|
||||||
|
}
|
||||||
|
visibility := auth.HistoryVisibilityForRoom(events)
|
||||||
|
if visibility != "shared" {
|
||||||
|
logrus.Infof("ServersAtEvent history visibility not shared: %s", visibility)
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
// get joined members
|
||||||
|
info, err := db.RoomInfo(ctx, roomID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
joinEventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, true, false)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return db.Events(ctx, joinEventNIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func persistEvents(ctx context.Context, db storage.Database, events []gomatrixserverlib.HeaderedEvent) (types.RoomNID, map[string]types.Event) {
|
||||||
|
var roomNID types.RoomNID
|
||||||
|
backfilledEventMap := make(map[string]types.Event)
|
||||||
|
for j, ev := range events {
|
||||||
|
nidMap, err := db.EventNIDs(ctx, ev.AuthEventIDs())
|
||||||
|
if err != nil { // this shouldn't happen as RequestBackfill already found them
|
||||||
|
logrus.WithError(err).WithField("auth_events", ev.AuthEventIDs()).Error("Failed to find one or more auth events")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
authNids := make([]types.EventNID, len(nidMap))
|
||||||
|
i := 0
|
||||||
|
for _, nid := range nidMap {
|
||||||
|
authNids[i] = nid
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
var stateAtEvent types.StateAtEvent
|
||||||
|
var redactedEventID string
|
||||||
|
var redactionEvent *gomatrixserverlib.Event
|
||||||
|
roomNID, stateAtEvent, redactionEvent, redactedEventID, err = db.StoreEvent(ctx, ev.Unwrap(), nil, authNids)
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to persist event")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// If storing this event results in it being redacted, then do so.
|
||||||
|
// It's also possible for this event to be a redaction which results in another event being
|
||||||
|
// redacted, which we don't care about since we aren't returning it in this backfill.
|
||||||
|
if redactedEventID == ev.EventID() {
|
||||||
|
eventToRedact := ev.Unwrap()
|
||||||
|
redactedEvent, err := eventutil.RedactEvent(redactionEvent, &eventToRedact)
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to redact event")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ev = redactedEvent.Headered(ev.RoomVersion)
|
||||||
|
events[j] = ev
|
||||||
|
}
|
||||||
|
backfilledEventMap[ev.EventID()] = types.Event{
|
||||||
|
EventNID: stateAtEvent.StateEntry.EventNID,
|
||||||
|
Event: ev.Unwrap(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return roomNID, backfilledEventMap
|
||||||
|
}
|
|
@ -1,11 +1,28 @@
|
||||||
package internal
|
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package perform
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api"
|
federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api"
|
||||||
|
"github.com/matrix-org/dendrite/internal/config"
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/internal/helpers"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/internal/input"
|
||||||
"github.com/matrix-org/dendrite/roomserver/state"
|
"github.com/matrix-org/dendrite/roomserver/state"
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage"
|
"github.com/matrix-org/dendrite/roomserver/storage"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
|
@ -13,22 +30,29 @@ import (
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type Inviter struct {
|
||||||
|
DB storage.Database
|
||||||
|
Cfg *config.RoomServer
|
||||||
|
FSAPI federationSenderAPI.FederationSenderInternalAPI
|
||||||
|
Inputer *input.Inputer
|
||||||
|
}
|
||||||
|
|
||||||
// nolint:gocyclo
|
// nolint:gocyclo
|
||||||
func (r *RoomserverInternalAPI) PerformInvite(
|
func (r *Inviter) PerformInvite(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
req *api.PerformInviteRequest,
|
req *api.PerformInviteRequest,
|
||||||
res *api.PerformInviteResponse,
|
res *api.PerformInviteResponse,
|
||||||
) error {
|
) ([]api.OutputEvent, error) {
|
||||||
event := req.Event
|
event := req.Event
|
||||||
if event.StateKey() == nil {
|
if event.StateKey() == nil {
|
||||||
return fmt.Errorf("invite must be a state event")
|
return nil, fmt.Errorf("invite must be a state event")
|
||||||
}
|
}
|
||||||
|
|
||||||
roomID := event.RoomID()
|
roomID := event.RoomID()
|
||||||
targetUserID := *event.StateKey()
|
targetUserID := *event.StateKey()
|
||||||
info, err := r.DB.RoomInfo(ctx, roomID)
|
info, err := r.DB.RoomInfo(ctx, roomID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Failed to load RoomInfo: %w", err)
|
return nil, fmt.Errorf("Failed to load RoomInfo: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.WithFields(log.Fields{
|
log.WithFields(log.Fields{
|
||||||
|
@ -52,11 +76,11 @@ func (r *RoomserverInternalAPI) PerformInvite(
|
||||||
}
|
}
|
||||||
if len(inviteState) == 0 {
|
if len(inviteState) == 0 {
|
||||||
if err = event.SetUnsignedField("invite_room_state", struct{}{}); err != nil {
|
if err = event.SetUnsignedField("invite_room_state", struct{}{}); err != nil {
|
||||||
return fmt.Errorf("event.SetUnsignedField: %w", err)
|
return nil, fmt.Errorf("event.SetUnsignedField: %w", err)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if err = event.SetUnsignedField("invite_room_state", inviteState); err != nil {
|
if err = event.SetUnsignedField("invite_room_state", inviteState); err != nil {
|
||||||
return fmt.Errorf("event.SetUnsignedField: %w", err)
|
return nil, fmt.Errorf("event.SetUnsignedField: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -64,7 +88,7 @@ func (r *RoomserverInternalAPI) PerformInvite(
|
||||||
if info != nil {
|
if info != nil {
|
||||||
_, isAlreadyJoined, err = r.DB.GetMembership(ctx, info.RoomNID, *event.StateKey())
|
_, isAlreadyJoined, err = r.DB.GetMembership(ctx, info.RoomNID, *event.StateKey())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("r.DB.GetMembership: %w", err)
|
return nil, fmt.Errorf("r.DB.GetMembership: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if isAlreadyJoined {
|
if isAlreadyJoined {
|
||||||
|
@ -99,7 +123,7 @@ func (r *RoomserverInternalAPI) PerformInvite(
|
||||||
Code: api.PerformErrorNotAllowed,
|
Code: api.PerformErrorNotAllowed,
|
||||||
Msg: "User is already joined to room",
|
Msg: "User is already joined to room",
|
||||||
}
|
}
|
||||||
return nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if isOriginLocal {
|
if isOriginLocal {
|
||||||
|
@ -107,7 +131,7 @@ func (r *RoomserverInternalAPI) PerformInvite(
|
||||||
// try and see if the user is allowed to make this invite. We can't do
|
// try and see if the user is allowed to make this invite. We can't do
|
||||||
// this for invites coming in over federation - we have to take those on
|
// this for invites coming in over federation - we have to take those on
|
||||||
// trust.
|
// trust.
|
||||||
_, err = checkAuthEvents(ctx, r.DB, event, event.AuthEventIDs())
|
_, err = helpers.CheckAuthEvents(ctx, r.DB, event, event.AuthEventIDs())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", event.AuthEventIDs()).Error(
|
log.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", event.AuthEventIDs()).Error(
|
||||||
"processInviteEvent.checkAuthEvents failed for event",
|
"processInviteEvent.checkAuthEvents failed for event",
|
||||||
|
@ -117,9 +141,9 @@ func (r *RoomserverInternalAPI) PerformInvite(
|
||||||
Msg: err.Error(),
|
Msg: err.Error(),
|
||||||
Code: api.PerformErrorNotAllowed,
|
Code: api.PerformErrorNotAllowed,
|
||||||
}
|
}
|
||||||
return nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
return fmt.Errorf("checkAuthEvents: %w", err)
|
return nil, fmt.Errorf("checkAuthEvents: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the invite originated from us and the target isn't local then we
|
// If the invite originated from us and the target isn't local then we
|
||||||
|
@ -133,13 +157,13 @@ func (r *RoomserverInternalAPI) PerformInvite(
|
||||||
InviteRoomState: inviteState,
|
InviteRoomState: inviteState,
|
||||||
}
|
}
|
||||||
fsRes := &federationSenderAPI.PerformInviteResponse{}
|
fsRes := &federationSenderAPI.PerformInviteResponse{}
|
||||||
if err = r.fsAPI.PerformInvite(ctx, fsReq, fsRes); err != nil {
|
if err = r.FSAPI.PerformInvite(ctx, fsReq, fsRes); err != nil {
|
||||||
res.Error = &api.PerformError{
|
res.Error = &api.PerformError{
|
||||||
Msg: err.Error(),
|
Msg: err.Error(),
|
||||||
Code: api.PerformErrorNoOperation,
|
Code: api.PerformErrorNoOperation,
|
||||||
}
|
}
|
||||||
log.WithError(err).WithField("event_id", event.EventID()).Error("r.fsAPI.PerformInvite failed")
|
log.WithError(err).WithField("event_id", event.EventID()).Error("r.FSAPI.PerformInvite failed")
|
||||||
return nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
event = fsRes.Event
|
event = fsRes.Event
|
||||||
}
|
}
|
||||||
|
@ -159,8 +183,8 @@ func (r *RoomserverInternalAPI) PerformInvite(
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
inputRes := &api.InputRoomEventsResponse{}
|
inputRes := &api.InputRoomEventsResponse{}
|
||||||
if err = r.InputRoomEvents(context.Background(), inputReq, inputRes); err != nil {
|
if err = r.Inputer.InputRoomEvents(context.Background(), inputReq, inputRes); err != nil {
|
||||||
return fmt.Errorf("r.InputRoomEvents: %w", err)
|
return nil, fmt.Errorf("r.InputRoomEvents: %w", err)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// The invite originated over federation. Process the membership
|
// The invite originated over federation. Process the membership
|
||||||
|
@ -168,25 +192,23 @@ func (r *RoomserverInternalAPI) PerformInvite(
|
||||||
// invite.
|
// invite.
|
||||||
updater, err := r.DB.MembershipUpdater(ctx, roomID, targetUserID, isTargetLocal, req.RoomVersion)
|
updater, err := r.DB.MembershipUpdater(ctx, roomID, targetUserID, isTargetLocal, req.RoomVersion)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("r.DB.MembershipUpdater: %w", err)
|
return nil, fmt.Errorf("r.DB.MembershipUpdater: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
unwrapped := event.Unwrap()
|
unwrapped := event.Unwrap()
|
||||||
outputUpdates, err := updateToInviteMembership(updater, &unwrapped, nil, req.Event.RoomVersion)
|
outputUpdates, err := helpers.UpdateToInviteMembership(updater, &unwrapped, nil, req.Event.RoomVersion)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("updateToInviteMembership: %w", err)
|
return nil, fmt.Errorf("updateToInviteMembership: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = updater.Commit(); err != nil {
|
if err = updater.Commit(); err != nil {
|
||||||
return fmt.Errorf("updater.Commit: %w", err)
|
return nil, fmt.Errorf("updater.Commit: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = r.WriteOutputEvents(roomID, outputUpdates); err != nil {
|
return outputUpdates, nil
|
||||||
return fmt.Errorf("r.WriteOutputEvents: %w", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildInviteStrippedState(
|
func buildInviteStrippedState(
|
||||||
|
@ -208,7 +230,7 @@ func buildInviteStrippedState(
|
||||||
StateKey: "",
|
StateKey: "",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
roomState := state.NewStateResolution(db)
|
roomState := state.NewStateResolution(db, *info)
|
||||||
stateEntries, err := roomState.LoadStateAtSnapshotForStringTuples(
|
stateEntries, err := roomState.LoadStateAtSnapshotForStringTuples(
|
||||||
ctx, info.StateSnapshotNID, stateWanted,
|
ctx, info.StateSnapshotNID, stateWanted,
|
||||||
)
|
)
|
|
@ -1,4 +1,18 @@
|
||||||
package internal
|
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package perform
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
@ -8,14 +22,27 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
fsAPI "github.com/matrix-org/dendrite/federationsender/api"
|
fsAPI "github.com/matrix-org/dendrite/federationsender/api"
|
||||||
|
"github.com/matrix-org/dendrite/internal/config"
|
||||||
"github.com/matrix-org/dendrite/internal/eventutil"
|
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/internal/helpers"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/internal/input"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/storage"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type Joiner struct {
|
||||||
|
ServerName gomatrixserverlib.ServerName
|
||||||
|
Cfg *config.RoomServer
|
||||||
|
FSAPI fsAPI.FederationSenderInternalAPI
|
||||||
|
DB storage.Database
|
||||||
|
|
||||||
|
Inputer *input.Inputer
|
||||||
|
}
|
||||||
|
|
||||||
// PerformJoin handles joining matrix rooms, including over federation by talking to the federationsender.
|
// PerformJoin handles joining matrix rooms, including over federation by talking to the federationsender.
|
||||||
func (r *RoomserverInternalAPI) PerformJoin(
|
func (r *Joiner) PerformJoin(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
req *api.PerformJoinRequest,
|
req *api.PerformJoinRequest,
|
||||||
res *api.PerformJoinResponse,
|
res *api.PerformJoinResponse,
|
||||||
|
@ -34,7 +61,7 @@ func (r *RoomserverInternalAPI) PerformJoin(
|
||||||
res.RoomID = roomID
|
res.RoomID = roomID
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RoomserverInternalAPI) performJoin(
|
func (r *Joiner) performJoin(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
req *api.PerformJoinRequest,
|
req *api.PerformJoinRequest,
|
||||||
) (string, error) {
|
) (string, error) {
|
||||||
|
@ -63,7 +90,7 @@ func (r *RoomserverInternalAPI) performJoin(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RoomserverInternalAPI) performJoinRoomByAlias(
|
func (r *Joiner) performJoinRoomByAlias(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
req *api.PerformJoinRequest,
|
req *api.PerformJoinRequest,
|
||||||
) (string, error) {
|
) (string, error) {
|
||||||
|
@ -85,7 +112,7 @@ func (r *RoomserverInternalAPI) performJoinRoomByAlias(
|
||||||
ServerName: domain, // the server to ask
|
ServerName: domain, // the server to ask
|
||||||
}
|
}
|
||||||
dirRes := fsAPI.PerformDirectoryLookupResponse{}
|
dirRes := fsAPI.PerformDirectoryLookupResponse{}
|
||||||
err = r.fsAPI.PerformDirectoryLookup(ctx, &dirReq, &dirRes)
|
err = r.FSAPI.PerformDirectoryLookup(ctx, &dirReq, &dirRes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).Errorf("error looking up alias %q", req.RoomIDOrAlias)
|
logrus.WithError(err).Errorf("error looking up alias %q", req.RoomIDOrAlias)
|
||||||
return "", fmt.Errorf("Looking up alias %q over federation failed: %w", req.RoomIDOrAlias, err)
|
return "", fmt.Errorf("Looking up alias %q over federation failed: %w", req.RoomIDOrAlias, err)
|
||||||
|
@ -112,7 +139,7 @@ func (r *RoomserverInternalAPI) performJoinRoomByAlias(
|
||||||
|
|
||||||
// TODO: Break this function up a bit
|
// TODO: Break this function up a bit
|
||||||
// nolint:gocyclo
|
// nolint:gocyclo
|
||||||
func (r *RoomserverInternalAPI) performJoinRoomByID(
|
func (r *Joiner) performJoinRoomByID(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
req *api.PerformJoinRequest,
|
req *api.PerformJoinRequest,
|
||||||
) (string, error) {
|
) (string, error) {
|
||||||
|
@ -161,8 +188,8 @@ func (r *RoomserverInternalAPI) performJoinRoomByID(
|
||||||
// where we might think we know about a room in the following
|
// where we might think we know about a room in the following
|
||||||
// section but don't know the latest state as all of our users
|
// section but don't know the latest state as all of our users
|
||||||
// have left.
|
// have left.
|
||||||
serverInRoom, _ := r.isServerCurrentlyInRoom(ctx, r.ServerName, req.RoomIDOrAlias)
|
serverInRoom, _ := helpers.IsServerCurrentlyInRoom(ctx, r.DB, r.ServerName, req.RoomIDOrAlias)
|
||||||
isInvitePending, inviteSender, _, err := r.isInvitePending(ctx, req.RoomIDOrAlias, req.UserID)
|
isInvitePending, inviteSender, _, err := helpers.IsInvitePending(ctx, r.DB, req.RoomIDOrAlias, req.UserID)
|
||||||
if err == nil && isInvitePending && !serverInRoom {
|
if err == nil && isInvitePending && !serverInRoom {
|
||||||
// Check if there's an invite pending.
|
// Check if there's an invite pending.
|
||||||
_, inviterDomain, ierr := gomatrixserverlib.SplitID('@', inviteSender)
|
_, inviterDomain, ierr := gomatrixserverlib.SplitID('@', inviteSender)
|
||||||
|
@ -188,15 +215,7 @@ func (r *RoomserverInternalAPI) performJoinRoomByID(
|
||||||
// locally on the homeserver.
|
// locally on the homeserver.
|
||||||
// TODO: Check what happens if the room exists on the server
|
// TODO: Check what happens if the room exists on the server
|
||||||
// but everyone has since left. I suspect it does the wrong thing.
|
// but everyone has since left. I suspect it does the wrong thing.
|
||||||
buildRes := api.QueryLatestEventsAndStateResponse{}
|
event, buildRes, err := buildEvent(ctx, r.DB, r.Cfg.Matrix, &eb)
|
||||||
event, err := eventutil.BuildEvent(
|
|
||||||
ctx, // the request context
|
|
||||||
&eb, // the template join event
|
|
||||||
r.Cfg.Matrix, // the server configuration
|
|
||||||
time.Now(), // the event timestamp to use
|
|
||||||
r, // the roomserver API to use
|
|
||||||
&buildRes, // the query response
|
|
||||||
)
|
|
||||||
|
|
||||||
switch err {
|
switch err {
|
||||||
case nil:
|
case nil:
|
||||||
|
@ -228,7 +247,7 @@ func (r *RoomserverInternalAPI) performJoinRoomByID(
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
inputRes := api.InputRoomEventsResponse{}
|
inputRes := api.InputRoomEventsResponse{}
|
||||||
if err = r.InputRoomEvents(ctx, &inputReq, &inputRes); err != nil {
|
if err = r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes); err != nil {
|
||||||
var notAllowed *gomatrixserverlib.NotAllowed
|
var notAllowed *gomatrixserverlib.NotAllowed
|
||||||
if errors.As(err, ¬Allowed) {
|
if errors.As(err, ¬Allowed) {
|
||||||
return "", &api.PerformError{
|
return "", &api.PerformError{
|
||||||
|
@ -271,7 +290,7 @@ func (r *RoomserverInternalAPI) performJoinRoomByID(
|
||||||
return req.RoomIDOrAlias, nil
|
return req.RoomIDOrAlias, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RoomserverInternalAPI) performFederatedJoinRoomByID(
|
func (r *Joiner) performFederatedJoinRoomByID(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
req *api.PerformJoinRequest,
|
req *api.PerformJoinRequest,
|
||||||
) error {
|
) error {
|
||||||
|
@ -283,7 +302,7 @@ func (r *RoomserverInternalAPI) performFederatedJoinRoomByID(
|
||||||
Content: req.Content, // the membership event content
|
Content: req.Content, // the membership event content
|
||||||
}
|
}
|
||||||
fedRes := fsAPI.PerformJoinResponse{}
|
fedRes := fsAPI.PerformJoinResponse{}
|
||||||
r.fsAPI.PerformJoin(ctx, &fedReq, &fedRes)
|
r.FSAPI.PerformJoin(ctx, &fedReq, &fedRes)
|
||||||
if fedRes.LastError != nil {
|
if fedRes.LastError != nil {
|
||||||
return &api.PerformError{
|
return &api.PerformError{
|
||||||
Code: api.PerformErrRemote,
|
Code: api.PerformErrRemote,
|
||||||
|
@ -293,3 +312,31 @@ func (r *RoomserverInternalAPI) performFederatedJoinRoomByID(
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func buildEvent(
|
||||||
|
ctx context.Context, db storage.Database, cfg *config.Global, builder *gomatrixserverlib.EventBuilder,
|
||||||
|
) (*gomatrixserverlib.HeaderedEvent, *api.QueryLatestEventsAndStateResponse, error) {
|
||||||
|
eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("gomatrixserverlib.StateNeededForEventBuilder: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(eventsNeeded.Tuples()) == 0 {
|
||||||
|
return nil, nil, errors.New("expecting state tuples for event builder, got none")
|
||||||
|
}
|
||||||
|
|
||||||
|
var queryRes api.QueryLatestEventsAndStateResponse
|
||||||
|
err = helpers.QueryLatestEventsAndState(ctx, db, &api.QueryLatestEventsAndStateRequest{
|
||||||
|
RoomID: builder.RoomID,
|
||||||
|
StateToFetch: eventsNeeded.Tuples(),
|
||||||
|
}, &queryRes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("QueryLatestEventsAndState: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ev, err := eventutil.BuildEvent(ctx, builder, cfg, time.Now(), &eventsNeeded, &queryRes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
return ev, &queryRes, nil
|
||||||
|
}
|
183
roomserver/internal/perform/perform_leave.go
Normal file
183
roomserver/internal/perform/perform_leave.go
Normal file
|
@ -0,0 +1,183 @@
|
||||||
|
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package perform
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
fsAPI "github.com/matrix-org/dendrite/federationsender/api"
|
||||||
|
"github.com/matrix-org/dendrite/internal/config"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/internal/helpers"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/internal/input"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/storage"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Leaver struct {
|
||||||
|
Cfg *config.RoomServer
|
||||||
|
DB storage.Database
|
||||||
|
FSAPI fsAPI.FederationSenderInternalAPI
|
||||||
|
|
||||||
|
Inputer *input.Inputer
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteOutputEvents implements OutputRoomEventWriter
|
||||||
|
func (r *Leaver) PerformLeave(
|
||||||
|
ctx context.Context,
|
||||||
|
req *api.PerformLeaveRequest,
|
||||||
|
res *api.PerformLeaveResponse,
|
||||||
|
) ([]api.OutputEvent, error) {
|
||||||
|
_, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Supplied user ID %q in incorrect format", req.UserID)
|
||||||
|
}
|
||||||
|
if domain != r.Cfg.Matrix.ServerName {
|
||||||
|
return nil, fmt.Errorf("User %q does not belong to this homeserver", req.UserID)
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(req.RoomID, "!") {
|
||||||
|
return r.performLeaveRoomByID(ctx, req, res)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("Room ID %q is invalid", req.RoomID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Leaver) performLeaveRoomByID(
|
||||||
|
ctx context.Context,
|
||||||
|
req *api.PerformLeaveRequest,
|
||||||
|
res *api.PerformLeaveResponse, // nolint:unparam
|
||||||
|
) ([]api.OutputEvent, error) {
|
||||||
|
// If there's an invite outstanding for the room then respond to
|
||||||
|
// that.
|
||||||
|
isInvitePending, senderUser, eventID, err := helpers.IsInvitePending(ctx, r.DB, req.RoomID, req.UserID)
|
||||||
|
if err == nil && isInvitePending {
|
||||||
|
return r.performRejectInvite(ctx, req, res, senderUser, eventID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// There's no invite pending, so first of all we want to find out
|
||||||
|
// if the room exists and if the user is actually in it.
|
||||||
|
latestReq := api.QueryLatestEventsAndStateRequest{
|
||||||
|
RoomID: req.RoomID,
|
||||||
|
StateToFetch: []gomatrixserverlib.StateKeyTuple{
|
||||||
|
{
|
||||||
|
EventType: gomatrixserverlib.MRoomMember,
|
||||||
|
StateKey: req.UserID,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
latestRes := api.QueryLatestEventsAndStateResponse{}
|
||||||
|
if err = helpers.QueryLatestEventsAndState(ctx, r.DB, &latestReq, &latestRes); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if !latestRes.RoomExists {
|
||||||
|
return nil, fmt.Errorf("Room %q does not exist", req.RoomID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now let's see if the user is in the room.
|
||||||
|
if len(latestRes.StateEvents) == 0 {
|
||||||
|
return nil, fmt.Errorf("User %q is not a member of room %q", req.UserID, req.RoomID)
|
||||||
|
}
|
||||||
|
membership, err := latestRes.StateEvents[0].Membership()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Error getting membership: %w", err)
|
||||||
|
}
|
||||||
|
if membership != gomatrixserverlib.Join {
|
||||||
|
// TODO: should be able to handle "invite" in this case too, if
|
||||||
|
// it's a case of kicking or banning or such
|
||||||
|
return nil, fmt.Errorf("User %q is not joined to the room (membership is %q)", req.UserID, membership)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare the template for the leave event.
|
||||||
|
userID := req.UserID
|
||||||
|
eb := gomatrixserverlib.EventBuilder{
|
||||||
|
Type: gomatrixserverlib.MRoomMember,
|
||||||
|
Sender: userID,
|
||||||
|
StateKey: &userID,
|
||||||
|
RoomID: req.RoomID,
|
||||||
|
Redacts: "",
|
||||||
|
}
|
||||||
|
if err = eb.SetContent(map[string]interface{}{"membership": "leave"}); err != nil {
|
||||||
|
return nil, fmt.Errorf("eb.SetContent: %w", err)
|
||||||
|
}
|
||||||
|
if err = eb.SetUnsigned(struct{}{}); err != nil {
|
||||||
|
return nil, fmt.Errorf("eb.SetUnsigned: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// We know that the user is in the room at this point so let's build
|
||||||
|
// a leave event.
|
||||||
|
// TODO: Check what happens if the room exists on the server
|
||||||
|
// but everyone has since left. I suspect it does the wrong thing.
|
||||||
|
event, buildRes, err := buildEvent(ctx, r.DB, r.Cfg.Matrix, &eb)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("eventutil.BuildEvent: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Give our leave event to the roomserver input stream. The
|
||||||
|
// roomserver will process the membership change and notify
|
||||||
|
// downstream automatically.
|
||||||
|
inputReq := api.InputRoomEventsRequest{
|
||||||
|
InputRoomEvents: []api.InputRoomEvent{
|
||||||
|
{
|
||||||
|
Kind: api.KindNew,
|
||||||
|
Event: event.Headered(buildRes.RoomVersion),
|
||||||
|
AuthEventIDs: event.AuthEventIDs(),
|
||||||
|
SendAsServer: string(r.Cfg.Matrix.ServerName),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
inputRes := api.InputRoomEventsResponse{}
|
||||||
|
if err = r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes); err != nil {
|
||||||
|
return nil, fmt.Errorf("r.InputRoomEvents: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Leaver) performRejectInvite(
|
||||||
|
ctx context.Context,
|
||||||
|
req *api.PerformLeaveRequest,
|
||||||
|
res *api.PerformLeaveResponse, // nolint:unparam
|
||||||
|
senderUser, eventID string,
|
||||||
|
) ([]api.OutputEvent, error) {
|
||||||
|
_, domain, err := gomatrixserverlib.SplitID('@', senderUser)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("User ID %q invalid: %w", senderUser, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ask the federation sender to perform a federated leave for us.
|
||||||
|
leaveReq := fsAPI.PerformLeaveRequest{
|
||||||
|
RoomID: req.RoomID,
|
||||||
|
UserID: req.UserID,
|
||||||
|
ServerNames: []gomatrixserverlib.ServerName{domain},
|
||||||
|
}
|
||||||
|
leaveRes := fsAPI.PerformLeaveResponse{}
|
||||||
|
if err := r.FSAPI.PerformLeave(ctx, &leaveReq, &leaveRes); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Withdraw the invite, so that the sync API etc are
|
||||||
|
// notified that we rejected it.
|
||||||
|
return []api.OutputEvent{
|
||||||
|
{
|
||||||
|
Type: api.OutputTypeRetireInviteEvent,
|
||||||
|
RetireInviteEvent: &api.OutputRetireInviteEvent{
|
||||||
|
EventID: eventID,
|
||||||
|
Membership: "leave",
|
||||||
|
TargetUserID: req.UserID,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
39
roomserver/internal/perform/perform_publish.go
Normal file
39
roomserver/internal/perform/perform_publish.go
Normal file
|
@ -0,0 +1,39 @@
|
||||||
|
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package perform
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Publisher struct {
|
||||||
|
DB storage.Database
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Publisher) PerformPublish(
|
||||||
|
ctx context.Context,
|
||||||
|
req *api.PerformPublishRequest,
|
||||||
|
res *api.PerformPublishResponse,
|
||||||
|
) {
|
||||||
|
err := r.DB.PublishRoom(ctx, req.RoomID, req.Visibility == "public")
|
||||||
|
if err != nil {
|
||||||
|
res.Error = &api.PerformError{
|
||||||
|
Msg: err.Error(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,305 +0,0 @@
|
||||||
package internal
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/auth"
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage"
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
|
||||||
"github.com/matrix-org/util"
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
// backfillRequester implements gomatrixserverlib.BackfillRequester
|
|
||||||
type backfillRequester struct {
|
|
||||||
db storage.Database
|
|
||||||
fedClient *gomatrixserverlib.FederationClient
|
|
||||||
thisServer gomatrixserverlib.ServerName
|
|
||||||
bwExtrems map[string][]string
|
|
||||||
|
|
||||||
// per-request state
|
|
||||||
servers []gomatrixserverlib.ServerName
|
|
||||||
eventIDToBeforeStateIDs map[string][]string
|
|
||||||
eventIDMap map[string]gomatrixserverlib.Event
|
|
||||||
}
|
|
||||||
|
|
||||||
func newBackfillRequester(db storage.Database, fedClient *gomatrixserverlib.FederationClient, thisServer gomatrixserverlib.ServerName, bwExtrems map[string][]string) *backfillRequester {
|
|
||||||
return &backfillRequester{
|
|
||||||
db: db,
|
|
||||||
fedClient: fedClient,
|
|
||||||
thisServer: thisServer,
|
|
||||||
eventIDToBeforeStateIDs: make(map[string][]string),
|
|
||||||
eventIDMap: make(map[string]gomatrixserverlib.Event),
|
|
||||||
bwExtrems: bwExtrems,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *backfillRequester) StateIDsBeforeEvent(ctx context.Context, targetEvent gomatrixserverlib.HeaderedEvent) ([]string, error) {
|
|
||||||
b.eventIDMap[targetEvent.EventID()] = targetEvent.Unwrap()
|
|
||||||
if ids, ok := b.eventIDToBeforeStateIDs[targetEvent.EventID()]; ok {
|
|
||||||
return ids, nil
|
|
||||||
}
|
|
||||||
if len(targetEvent.PrevEventIDs()) == 0 && targetEvent.Type() == "m.room.create" && targetEvent.StateKeyEquals("") {
|
|
||||||
util.GetLogger(ctx).WithField("room_id", targetEvent.RoomID()).Info("Backfilled to the beginning of the room")
|
|
||||||
b.eventIDToBeforeStateIDs[targetEvent.EventID()] = []string{}
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
// if we have exactly 1 prev event and we know the state of the room at that prev event, then just roll forward the prev event.
|
|
||||||
// Else, we have to hit /state_ids because either we don't know the state at all at this event (new backwards extremity) or
|
|
||||||
// we don't know the result of state res to merge forks (2 or more prev_events)
|
|
||||||
if len(targetEvent.PrevEventIDs()) == 1 {
|
|
||||||
prevEventID := targetEvent.PrevEventIDs()[0]
|
|
||||||
prevEvent, ok := b.eventIDMap[prevEventID]
|
|
||||||
if !ok {
|
|
||||||
goto FederationHit
|
|
||||||
}
|
|
||||||
prevEventStateIDs, ok := b.eventIDToBeforeStateIDs[prevEventID]
|
|
||||||
if !ok {
|
|
||||||
goto FederationHit
|
|
||||||
}
|
|
||||||
newStateIDs := b.calculateNewStateIDs(targetEvent.Unwrap(), prevEvent, prevEventStateIDs)
|
|
||||||
if newStateIDs != nil {
|
|
||||||
b.eventIDToBeforeStateIDs[targetEvent.EventID()] = newStateIDs
|
|
||||||
return newStateIDs, nil
|
|
||||||
}
|
|
||||||
// else we failed to calculate the new state, so fallthrough
|
|
||||||
}
|
|
||||||
|
|
||||||
FederationHit:
|
|
||||||
var lastErr error
|
|
||||||
logrus.WithField("event_id", targetEvent.EventID()).Info("Requesting /state_ids at event")
|
|
||||||
for _, srv := range b.servers { // hit any valid server
|
|
||||||
c := gomatrixserverlib.FederatedStateProvider{
|
|
||||||
FedClient: b.fedClient,
|
|
||||||
RememberAuthEvents: false,
|
|
||||||
Server: srv,
|
|
||||||
}
|
|
||||||
res, err := c.StateIDsBeforeEvent(ctx, targetEvent)
|
|
||||||
if err != nil {
|
|
||||||
lastErr = err
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
b.eventIDToBeforeStateIDs[targetEvent.EventID()] = res
|
|
||||||
return res, nil
|
|
||||||
}
|
|
||||||
return nil, lastErr
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *backfillRequester) calculateNewStateIDs(targetEvent, prevEvent gomatrixserverlib.Event, prevEventStateIDs []string) []string {
|
|
||||||
newStateIDs := prevEventStateIDs[:]
|
|
||||||
if prevEvent.StateKey() == nil {
|
|
||||||
// state is the same as the previous event
|
|
||||||
b.eventIDToBeforeStateIDs[targetEvent.EventID()] = newStateIDs
|
|
||||||
return newStateIDs
|
|
||||||
}
|
|
||||||
|
|
||||||
missingState := false // true if we are missing the info for a state event ID
|
|
||||||
foundEvent := false // true if we found a (type, state_key) match
|
|
||||||
// find which state ID to replace, if any
|
|
||||||
for i, id := range newStateIDs {
|
|
||||||
ev, ok := b.eventIDMap[id]
|
|
||||||
if !ok {
|
|
||||||
missingState = true
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// The state IDs BEFORE the target event are the state IDs BEFORE the prev_event PLUS the prev_event itself
|
|
||||||
if ev.Type() == prevEvent.Type() && ev.StateKeyEquals(*prevEvent.StateKey()) {
|
|
||||||
newStateIDs[i] = prevEvent.EventID()
|
|
||||||
foundEvent = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !foundEvent && !missingState {
|
|
||||||
// we can be certain that this is new state
|
|
||||||
newStateIDs = append(newStateIDs, prevEvent.EventID())
|
|
||||||
foundEvent = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if foundEvent {
|
|
||||||
b.eventIDToBeforeStateIDs[targetEvent.EventID()] = newStateIDs
|
|
||||||
return newStateIDs
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *backfillRequester) StateBeforeEvent(ctx context.Context, roomVer gomatrixserverlib.RoomVersion,
|
|
||||||
event gomatrixserverlib.HeaderedEvent, eventIDs []string) (map[string]*gomatrixserverlib.Event, error) {
|
|
||||||
|
|
||||||
// try to fetch the events from the database first
|
|
||||||
events, err := b.ProvideEvents(roomVer, eventIDs)
|
|
||||||
if err != nil {
|
|
||||||
// non-fatal, fallthrough
|
|
||||||
logrus.WithError(err).Info("Failed to fetch events")
|
|
||||||
} else {
|
|
||||||
logrus.Infof("Fetched %d/%d events from the database", len(events), len(eventIDs))
|
|
||||||
if len(events) == len(eventIDs) {
|
|
||||||
result := make(map[string]*gomatrixserverlib.Event)
|
|
||||||
for i := range events {
|
|
||||||
result[events[i].EventID()] = &events[i]
|
|
||||||
b.eventIDMap[events[i].EventID()] = events[i]
|
|
||||||
}
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
c := gomatrixserverlib.FederatedStateProvider{
|
|
||||||
FedClient: b.fedClient,
|
|
||||||
RememberAuthEvents: false,
|
|
||||||
Server: b.servers[0],
|
|
||||||
}
|
|
||||||
result, err := c.StateBeforeEvent(ctx, roomVer, event, eventIDs)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
for eventID, ev := range result {
|
|
||||||
b.eventIDMap[eventID] = *ev
|
|
||||||
}
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ServersAtEvent is called when trying to determine which server to request from.
|
|
||||||
// It returns a list of servers which can be queried for backfill requests. These servers
|
|
||||||
// will be servers that are in the room already. The entries at the beginning are preferred servers
|
|
||||||
// and will be tried first. An empty list will fail the request.
|
|
||||||
func (b *backfillRequester) ServersAtEvent(ctx context.Context, roomID, eventID string) []gomatrixserverlib.ServerName {
|
|
||||||
// eventID will be a prev_event ID of a backwards extremity, meaning we will not have a database entry for it. Instead, use
|
|
||||||
// its successor, so look it up.
|
|
||||||
successor := ""
|
|
||||||
FindSuccessor:
|
|
||||||
for sucID, prevEventIDs := range b.bwExtrems {
|
|
||||||
for _, pe := range prevEventIDs {
|
|
||||||
if pe == eventID {
|
|
||||||
successor = sucID
|
|
||||||
break FindSuccessor
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if successor == "" {
|
|
||||||
logrus.WithField("event_id", eventID).Error("ServersAtEvent: failed to find successor of this event to determine room state")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
eventID = successor
|
|
||||||
|
|
||||||
// getMembershipsBeforeEventNID requires a NID, so retrieving the NID for
|
|
||||||
// the event is necessary.
|
|
||||||
NIDs, err := b.db.EventNIDs(ctx, []string{eventID})
|
|
||||||
if err != nil {
|
|
||||||
logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to get event NID for event")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
stateEntries, err := stateBeforeEvent(ctx, b.db, NIDs[eventID])
|
|
||||||
if err != nil {
|
|
||||||
logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to load state before event")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// possibly return all joined servers depending on history visiblity
|
|
||||||
memberEventsFromVis, err := joinEventsFromHistoryVisibility(ctx, b.db, roomID, stateEntries)
|
|
||||||
if err != nil {
|
|
||||||
logrus.WithError(err).Error("ServersAtEvent: failed calculate servers from history visibility rules")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
logrus.Infof("ServersAtEvent including %d current events from history visibility", len(memberEventsFromVis))
|
|
||||||
|
|
||||||
// Retrieve all "m.room.member" state events of "join" membership, which
|
|
||||||
// contains the list of users in the room before the event, therefore all
|
|
||||||
// the servers in it at that moment.
|
|
||||||
memberEvents, err := getMembershipsAtState(ctx, b.db, stateEntries, true)
|
|
||||||
if err != nil {
|
|
||||||
logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to get memberships before event")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
memberEvents = append(memberEvents, memberEventsFromVis...)
|
|
||||||
|
|
||||||
// Store the server names in a temporary map to avoid duplicates.
|
|
||||||
serverSet := make(map[gomatrixserverlib.ServerName]bool)
|
|
||||||
for _, event := range memberEvents {
|
|
||||||
serverSet[event.Origin()] = true
|
|
||||||
}
|
|
||||||
var servers []gomatrixserverlib.ServerName
|
|
||||||
for server := range serverSet {
|
|
||||||
if server == b.thisServer {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
servers = append(servers, server)
|
|
||||||
}
|
|
||||||
b.servers = servers
|
|
||||||
return servers
|
|
||||||
}
|
|
||||||
|
|
||||||
// Backfill performs a backfill request to the given server.
|
|
||||||
// https://matrix.org/docs/spec/server_server/latest#get-matrix-federation-v1-backfill-roomid
|
|
||||||
func (b *backfillRequester) Backfill(ctx context.Context, server gomatrixserverlib.ServerName, roomID string,
|
|
||||||
fromEventIDs []string, limit int) (*gomatrixserverlib.Transaction, error) {
|
|
||||||
|
|
||||||
tx, err := b.fedClient.Backfill(ctx, server, roomID, limit, fromEventIDs)
|
|
||||||
return &tx, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion, eventIDs []string) ([]gomatrixserverlib.Event, error) {
|
|
||||||
ctx := context.Background()
|
|
||||||
nidMap, err := b.db.EventNIDs(ctx, eventIDs)
|
|
||||||
if err != nil {
|
|
||||||
logrus.WithError(err).WithField("event_ids", eventIDs).Error("Failed to find events")
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
eventNIDs := make([]types.EventNID, len(nidMap))
|
|
||||||
i := 0
|
|
||||||
for _, nid := range nidMap {
|
|
||||||
eventNIDs[i] = nid
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
eventsWithNids, err := b.db.Events(ctx, eventNIDs)
|
|
||||||
if err != nil {
|
|
||||||
logrus.WithError(err).WithField("event_nids", eventNIDs).Error("Failed to load events")
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
events := make([]gomatrixserverlib.Event, len(eventsWithNids))
|
|
||||||
for i := range eventsWithNids {
|
|
||||||
events[i] = eventsWithNids[i].Event
|
|
||||||
}
|
|
||||||
return events, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// joinEventsFromHistoryVisibility returns all CURRENTLY joined members if the provided state indicated a 'shared' history visibility.
|
|
||||||
// TODO: Long term we probably want a history_visibility table which stores eventNID | visibility_enum so we can just
|
|
||||||
// pull all events and then filter by that table.
|
|
||||||
func joinEventsFromHistoryVisibility(
|
|
||||||
ctx context.Context, db storage.Database, roomID string, stateEntries []types.StateEntry) ([]types.Event, error) {
|
|
||||||
|
|
||||||
var eventNIDs []types.EventNID
|
|
||||||
for _, entry := range stateEntries {
|
|
||||||
// Filter the events to retrieve to only keep the membership events
|
|
||||||
if entry.EventTypeNID == types.MRoomHistoryVisibilityNID && entry.EventStateKeyNID == types.EmptyStateKeyNID {
|
|
||||||
eventNIDs = append(eventNIDs, entry.EventNID)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get all of the events in this state
|
|
||||||
stateEvents, err := db.Events(ctx, eventNIDs)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
events := make([]gomatrixserverlib.Event, len(stateEvents))
|
|
||||||
for i := range stateEvents {
|
|
||||||
events[i] = stateEvents[i].Event
|
|
||||||
}
|
|
||||||
visibility := auth.HistoryVisibilityForRoom(events)
|
|
||||||
if visibility != "shared" {
|
|
||||||
logrus.Infof("ServersAtEvent history visibility not shared: %s", visibility)
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
// get joined members
|
|
||||||
info, err := db.RoomInfo(ctx, roomID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
joinEventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, true, false)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return db.Events(ctx, joinEventNIDs)
|
|
||||||
}
|
|
|
@ -1,223 +0,0 @@
|
||||||
package internal
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
fsAPI "github.com/matrix-org/dendrite/federationsender/api"
|
|
||||||
"github.com/matrix-org/dendrite/internal/eventutil"
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
|
||||||
)
|
|
||||||
|
|
||||||
// WriteOutputEvents implements OutputRoomEventWriter
|
|
||||||
func (r *RoomserverInternalAPI) PerformLeave(
|
|
||||||
ctx context.Context,
|
|
||||||
req *api.PerformLeaveRequest,
|
|
||||||
res *api.PerformLeaveResponse,
|
|
||||||
) error {
|
|
||||||
_, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("Supplied user ID %q in incorrect format", req.UserID)
|
|
||||||
}
|
|
||||||
if domain != r.Cfg.Matrix.ServerName {
|
|
||||||
return fmt.Errorf("User %q does not belong to this homeserver", req.UserID)
|
|
||||||
}
|
|
||||||
if strings.HasPrefix(req.RoomID, "!") {
|
|
||||||
return r.performLeaveRoomByID(ctx, req, res)
|
|
||||||
}
|
|
||||||
return fmt.Errorf("Room ID %q is invalid", req.RoomID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *RoomserverInternalAPI) performLeaveRoomByID(
|
|
||||||
ctx context.Context,
|
|
||||||
req *api.PerformLeaveRequest,
|
|
||||||
res *api.PerformLeaveResponse, // nolint:unparam
|
|
||||||
) error {
|
|
||||||
// If there's an invite outstanding for the room then respond to
|
|
||||||
// that.
|
|
||||||
isInvitePending, senderUser, eventID, err := r.isInvitePending(ctx, req.RoomID, req.UserID)
|
|
||||||
if err == nil && isInvitePending {
|
|
||||||
return r.performRejectInvite(ctx, req, res, senderUser, eventID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// There's no invite pending, so first of all we want to find out
|
|
||||||
// if the room exists and if the user is actually in it.
|
|
||||||
latestReq := api.QueryLatestEventsAndStateRequest{
|
|
||||||
RoomID: req.RoomID,
|
|
||||||
StateToFetch: []gomatrixserverlib.StateKeyTuple{
|
|
||||||
{
|
|
||||||
EventType: gomatrixserverlib.MRoomMember,
|
|
||||||
StateKey: req.UserID,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
latestRes := api.QueryLatestEventsAndStateResponse{}
|
|
||||||
if err = r.QueryLatestEventsAndState(ctx, &latestReq, &latestRes); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if !latestRes.RoomExists {
|
|
||||||
return fmt.Errorf("Room %q does not exist", req.RoomID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now let's see if the user is in the room.
|
|
||||||
if len(latestRes.StateEvents) == 0 {
|
|
||||||
return fmt.Errorf("User %q is not a member of room %q", req.UserID, req.RoomID)
|
|
||||||
}
|
|
||||||
membership, err := latestRes.StateEvents[0].Membership()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("Error getting membership: %w", err)
|
|
||||||
}
|
|
||||||
if membership != gomatrixserverlib.Join {
|
|
||||||
// TODO: should be able to handle "invite" in this case too, if
|
|
||||||
// it's a case of kicking or banning or such
|
|
||||||
return fmt.Errorf("User %q is not joined to the room (membership is %q)", req.UserID, membership)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Prepare the template for the leave event.
|
|
||||||
userID := req.UserID
|
|
||||||
eb := gomatrixserverlib.EventBuilder{
|
|
||||||
Type: gomatrixserverlib.MRoomMember,
|
|
||||||
Sender: userID,
|
|
||||||
StateKey: &userID,
|
|
||||||
RoomID: req.RoomID,
|
|
||||||
Redacts: "",
|
|
||||||
}
|
|
||||||
if err = eb.SetContent(map[string]interface{}{"membership": "leave"}); err != nil {
|
|
||||||
return fmt.Errorf("eb.SetContent: %w", err)
|
|
||||||
}
|
|
||||||
if err = eb.SetUnsigned(struct{}{}); err != nil {
|
|
||||||
return fmt.Errorf("eb.SetUnsigned: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// We know that the user is in the room at this point so let's build
|
|
||||||
// a leave event.
|
|
||||||
// TODO: Check what happens if the room exists on the server
|
|
||||||
// but everyone has since left. I suspect it does the wrong thing.
|
|
||||||
buildRes := api.QueryLatestEventsAndStateResponse{}
|
|
||||||
event, err := eventutil.BuildEvent(
|
|
||||||
ctx, // the request context
|
|
||||||
&eb, // the template leave event
|
|
||||||
r.Cfg.Matrix, // the server configuration
|
|
||||||
time.Now(), // the event timestamp to use
|
|
||||||
r, // the roomserver API to use
|
|
||||||
&buildRes, // the query response
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("eventutil.BuildEvent: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Give our leave event to the roomserver input stream. The
|
|
||||||
// roomserver will process the membership change and notify
|
|
||||||
// downstream automatically.
|
|
||||||
inputReq := api.InputRoomEventsRequest{
|
|
||||||
InputRoomEvents: []api.InputRoomEvent{
|
|
||||||
{
|
|
||||||
Kind: api.KindNew,
|
|
||||||
Event: event.Headered(buildRes.RoomVersion),
|
|
||||||
AuthEventIDs: event.AuthEventIDs(),
|
|
||||||
SendAsServer: string(r.Cfg.Matrix.ServerName),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
inputRes := api.InputRoomEventsResponse{}
|
|
||||||
if err = r.InputRoomEvents(ctx, &inputReq, &inputRes); err != nil {
|
|
||||||
return fmt.Errorf("r.InputRoomEvents: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *RoomserverInternalAPI) performRejectInvite(
|
|
||||||
ctx context.Context,
|
|
||||||
req *api.PerformLeaveRequest,
|
|
||||||
res *api.PerformLeaveResponse, // nolint:unparam
|
|
||||||
senderUser, eventID string,
|
|
||||||
) error {
|
|
||||||
_, domain, err := gomatrixserverlib.SplitID('@', senderUser)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("User ID %q invalid: %w", senderUser, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ask the federation sender to perform a federated leave for us.
|
|
||||||
leaveReq := fsAPI.PerformLeaveRequest{
|
|
||||||
RoomID: req.RoomID,
|
|
||||||
UserID: req.UserID,
|
|
||||||
ServerNames: []gomatrixserverlib.ServerName{domain},
|
|
||||||
}
|
|
||||||
leaveRes := fsAPI.PerformLeaveResponse{}
|
|
||||||
if err := r.fsAPI.PerformLeave(ctx, &leaveReq, &leaveRes); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Withdraw the invite, so that the sync API etc are
|
|
||||||
// notified that we rejected it.
|
|
||||||
return r.WriteOutputEvents(req.RoomID, []api.OutputEvent{
|
|
||||||
{
|
|
||||||
Type: api.OutputTypeRetireInviteEvent,
|
|
||||||
RetireInviteEvent: &api.OutputRetireInviteEvent{
|
|
||||||
EventID: eventID,
|
|
||||||
Membership: "leave",
|
|
||||||
TargetUserID: req.UserID,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *RoomserverInternalAPI) isInvitePending(
|
|
||||||
ctx context.Context,
|
|
||||||
roomID, userID string,
|
|
||||||
) (bool, string, string, error) {
|
|
||||||
// Look up the room NID for the supplied room ID.
|
|
||||||
info, err := r.DB.RoomInfo(ctx, roomID)
|
|
||||||
if err != nil {
|
|
||||||
return false, "", "", fmt.Errorf("r.DB.RoomInfo: %w", err)
|
|
||||||
}
|
|
||||||
if info == nil {
|
|
||||||
return false, "", "", fmt.Errorf("cannot get RoomInfo: unknown room ID %s", roomID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Look up the state key NID for the supplied user ID.
|
|
||||||
targetUserNIDs, err := r.DB.EventStateKeyNIDs(ctx, []string{userID})
|
|
||||||
if err != nil {
|
|
||||||
return false, "", "", fmt.Errorf("r.DB.EventStateKeyNIDs: %w", err)
|
|
||||||
}
|
|
||||||
targetUserNID, targetUserFound := targetUserNIDs[userID]
|
|
||||||
if !targetUserFound {
|
|
||||||
return false, "", "", fmt.Errorf("missing NID for user %q (%+v)", userID, targetUserNIDs)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Let's see if we have an event active for the user in the room. If
|
|
||||||
// we do then it will contain a server name that we can direct the
|
|
||||||
// send_leave to.
|
|
||||||
senderUserNIDs, eventIDs, err := r.DB.GetInvitesForUser(ctx, info.RoomNID, targetUserNID)
|
|
||||||
if err != nil {
|
|
||||||
return false, "", "", fmt.Errorf("r.DB.GetInvitesForUser: %w", err)
|
|
||||||
}
|
|
||||||
if len(senderUserNIDs) == 0 {
|
|
||||||
return false, "", "", nil
|
|
||||||
}
|
|
||||||
userNIDToEventID := make(map[types.EventStateKeyNID]string)
|
|
||||||
for i, nid := range senderUserNIDs {
|
|
||||||
userNIDToEventID[nid] = eventIDs[i]
|
|
||||||
}
|
|
||||||
|
|
||||||
// Look up the user ID from the NID.
|
|
||||||
senderUsers, err := r.DB.EventStateKeys(ctx, senderUserNIDs)
|
|
||||||
if err != nil {
|
|
||||||
return false, "", "", fmt.Errorf("r.DB.EventStateKeys: %w", err)
|
|
||||||
}
|
|
||||||
if len(senderUsers) == 0 {
|
|
||||||
return false, "", "", fmt.Errorf("no senderUsers")
|
|
||||||
}
|
|
||||||
|
|
||||||
senderUser, senderUserFound := senderUsers[senderUserNIDs[0]]
|
|
||||||
if !senderUserFound {
|
|
||||||
return false, "", "", fmt.Errorf("missing user for NID %d (%+v)", senderUserNIDs[0], senderUsers)
|
|
||||||
}
|
|
||||||
|
|
||||||
return true, senderUser, userNIDToEventID[senderUserNIDs[0]], nil
|
|
||||||
}
|
|
|
@ -1,20 +0,0 @@
|
||||||
package internal
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (r *RoomserverInternalAPI) PerformPublish(
|
|
||||||
ctx context.Context,
|
|
||||||
req *api.PerformPublishRequest,
|
|
||||||
res *api.PerformPublishResponse,
|
|
||||||
) {
|
|
||||||
err := r.DB.PublishRoom(ctx, req.RoomID, req.Visibility == "public")
|
|
||||||
if err != nil {
|
|
||||||
res.Error = &api.PerformError{
|
|
||||||
Msg: err.Error(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,960 +0,0 @@
|
||||||
// Copyright 2017 Vector Creations Ltd
|
|
||||||
// Copyright 2018 New Vector Ltd
|
|
||||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package internal
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/eventutil"
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/auth"
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/state"
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage"
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/version"
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
|
||||||
"github.com/matrix-org/util"
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
// QueryLatestEventsAndState implements api.RoomserverInternalAPI
|
|
||||||
func (r *RoomserverInternalAPI) QueryLatestEventsAndState(
|
|
||||||
ctx context.Context,
|
|
||||||
request *api.QueryLatestEventsAndStateRequest,
|
|
||||||
response *api.QueryLatestEventsAndStateResponse,
|
|
||||||
) error {
|
|
||||||
roomVersion, err := r.DB.GetRoomVersionForRoom(ctx, request.RoomID)
|
|
||||||
if err != nil {
|
|
||||||
response.RoomExists = false
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
roomState := state.NewStateResolution(r.DB)
|
|
||||||
|
|
||||||
info, err := r.DB.RoomInfo(ctx, request.RoomID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if info.IsStub {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
response.RoomExists = true
|
|
||||||
response.RoomVersion = roomVersion
|
|
||||||
|
|
||||||
var currentStateSnapshotNID types.StateSnapshotNID
|
|
||||||
response.LatestEvents, currentStateSnapshotNID, response.Depth, err =
|
|
||||||
r.DB.LatestEventIDs(ctx, info.RoomNID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
var stateEntries []types.StateEntry
|
|
||||||
if len(request.StateToFetch) == 0 {
|
|
||||||
// Look up all room state.
|
|
||||||
stateEntries, err = roomState.LoadStateAtSnapshot(
|
|
||||||
ctx, currentStateSnapshotNID,
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
// Look up the current state for the requested tuples.
|
|
||||||
stateEntries, err = roomState.LoadStateAtSnapshotForStringTuples(
|
|
||||||
ctx, currentStateSnapshotNID, request.StateToFetch,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
stateEvents, err := r.loadStateEvents(ctx, stateEntries)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, event := range stateEvents {
|
|
||||||
response.StateEvents = append(response.StateEvents, event.Headered(roomVersion))
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// QueryStateAfterEvents implements api.RoomserverInternalAPI
|
|
||||||
func (r *RoomserverInternalAPI) QueryStateAfterEvents(
|
|
||||||
ctx context.Context,
|
|
||||||
request *api.QueryStateAfterEventsRequest,
|
|
||||||
response *api.QueryStateAfterEventsResponse,
|
|
||||||
) error {
|
|
||||||
roomVersion, err := r.DB.GetRoomVersionForRoom(ctx, request.RoomID)
|
|
||||||
if err != nil {
|
|
||||||
response.RoomExists = false
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
roomState := state.NewStateResolution(r.DB)
|
|
||||||
|
|
||||||
info, err := r.DB.RoomInfo(ctx, request.RoomID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if info.IsStub {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
response.RoomExists = true
|
|
||||||
response.RoomVersion = roomVersion
|
|
||||||
|
|
||||||
prevStates, err := r.DB.StateAtEventIDs(ctx, request.PrevEventIDs)
|
|
||||||
if err != nil {
|
|
||||||
switch err.(type) {
|
|
||||||
case types.MissingEventError:
|
|
||||||
return nil
|
|
||||||
default:
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
response.PrevEventsExist = true
|
|
||||||
|
|
||||||
// Look up the currrent state for the requested tuples.
|
|
||||||
stateEntries, err := roomState.LoadStateAfterEventsForStringTuples(
|
|
||||||
ctx, info.RoomNID, prevStates, request.StateToFetch,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
stateEvents, err := r.loadStateEvents(ctx, stateEntries)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, event := range stateEvents {
|
|
||||||
response.StateEvents = append(response.StateEvents, event.Headered(roomVersion))
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// QueryEventsByID implements api.RoomserverInternalAPI
|
|
||||||
func (r *RoomserverInternalAPI) QueryEventsByID(
|
|
||||||
ctx context.Context,
|
|
||||||
request *api.QueryEventsByIDRequest,
|
|
||||||
response *api.QueryEventsByIDResponse,
|
|
||||||
) error {
|
|
||||||
eventNIDMap, err := r.DB.EventNIDs(ctx, request.EventIDs)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
var eventNIDs []types.EventNID
|
|
||||||
for _, nid := range eventNIDMap {
|
|
||||||
eventNIDs = append(eventNIDs, nid)
|
|
||||||
}
|
|
||||||
|
|
||||||
events, err := r.loadEvents(ctx, eventNIDs)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, event := range events {
|
|
||||||
roomVersion, verr := r.DB.GetRoomVersionForRoom(ctx, event.RoomID())
|
|
||||||
if verr != nil {
|
|
||||||
return verr
|
|
||||||
}
|
|
||||||
|
|
||||||
response.Events = append(response.Events, event.Headered(roomVersion))
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *RoomserverInternalAPI) loadStateEvents(
|
|
||||||
ctx context.Context, stateEntries []types.StateEntry,
|
|
||||||
) ([]gomatrixserverlib.Event, error) {
|
|
||||||
eventNIDs := make([]types.EventNID, len(stateEntries))
|
|
||||||
for i := range stateEntries {
|
|
||||||
eventNIDs[i] = stateEntries[i].EventNID
|
|
||||||
}
|
|
||||||
return r.loadEvents(ctx, eventNIDs)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *RoomserverInternalAPI) loadEvents(
|
|
||||||
ctx context.Context, eventNIDs []types.EventNID,
|
|
||||||
) ([]gomatrixserverlib.Event, error) {
|
|
||||||
stateEvents, err := r.DB.Events(ctx, eventNIDs)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
result := make([]gomatrixserverlib.Event, len(stateEvents))
|
|
||||||
for i := range stateEvents {
|
|
||||||
result[i] = stateEvents[i].Event
|
|
||||||
}
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// QueryMembershipForUser implements api.RoomserverInternalAPI
|
|
||||||
func (r *RoomserverInternalAPI) QueryMembershipForUser(
|
|
||||||
ctx context.Context,
|
|
||||||
request *api.QueryMembershipForUserRequest,
|
|
||||||
response *api.QueryMembershipForUserResponse,
|
|
||||||
) error {
|
|
||||||
info, err := r.DB.RoomInfo(ctx, request.RoomID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
membershipEventNID, stillInRoom, err := r.DB.GetMembership(ctx, info.RoomNID, request.UserID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if membershipEventNID == 0 {
|
|
||||||
response.HasBeenInRoom = false
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
response.IsInRoom = stillInRoom
|
|
||||||
response.HasBeenInRoom = true
|
|
||||||
|
|
||||||
evs, err := r.DB.Events(ctx, []types.EventNID{membershipEventNID})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if len(evs) != 1 {
|
|
||||||
return fmt.Errorf("failed to load membership event for event NID %d", membershipEventNID)
|
|
||||||
}
|
|
||||||
|
|
||||||
response.EventID = evs[0].EventID()
|
|
||||||
response.Membership, err = evs[0].Membership()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// QueryMembershipsForRoom implements api.RoomserverInternalAPI
|
|
||||||
func (r *RoomserverInternalAPI) QueryMembershipsForRoom(
|
|
||||||
ctx context.Context,
|
|
||||||
request *api.QueryMembershipsForRoomRequest,
|
|
||||||
response *api.QueryMembershipsForRoomResponse,
|
|
||||||
) error {
|
|
||||||
info, err := r.DB.RoomInfo(ctx, request.RoomID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
membershipEventNID, stillInRoom, err := r.DB.GetMembership(ctx, info.RoomNID, request.Sender)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if membershipEventNID == 0 {
|
|
||||||
response.HasBeenInRoom = false
|
|
||||||
response.JoinEvents = nil
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
response.HasBeenInRoom = true
|
|
||||||
response.JoinEvents = []gomatrixserverlib.ClientEvent{}
|
|
||||||
|
|
||||||
var events []types.Event
|
|
||||||
var stateEntries []types.StateEntry
|
|
||||||
if stillInRoom {
|
|
||||||
var eventNIDs []types.EventNID
|
|
||||||
eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, request.JoinedOnly, false)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
events, err = r.DB.Events(ctx, eventNIDs)
|
|
||||||
} else {
|
|
||||||
stateEntries, err = stateBeforeEvent(ctx, r.DB, membershipEventNID)
|
|
||||||
if err != nil {
|
|
||||||
logrus.WithField("membership_event_nid", membershipEventNID).WithError(err).Error("failed to load state before event")
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
events, err = getMembershipsAtState(ctx, r.DB, stateEntries, request.JoinedOnly)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, event := range events {
|
|
||||||
clientEvent := gomatrixserverlib.ToClientEvent(event.Event, gomatrixserverlib.FormatAll)
|
|
||||||
response.JoinEvents = append(response.JoinEvents, clientEvent)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func stateBeforeEvent(ctx context.Context, db storage.Database, eventNID types.EventNID) ([]types.StateEntry, error) {
|
|
||||||
roomState := state.NewStateResolution(db)
|
|
||||||
// Lookup the event NID
|
|
||||||
eIDs, err := db.EventIDs(ctx, []types.EventNID{eventNID})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
eventIDs := []string{eIDs[eventNID]}
|
|
||||||
|
|
||||||
prevState, err := db.StateAtEventIDs(ctx, eventIDs)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fetch the state as it was when this event was fired
|
|
||||||
return roomState.LoadCombinedStateAfterEvents(ctx, prevState)
|
|
||||||
}
|
|
||||||
|
|
||||||
// getMembershipsAtState filters the state events to
|
|
||||||
// only keep the "m.room.member" events with a "join" membership. These events are returned.
|
|
||||||
// Returns an error if there was an issue fetching the events.
|
|
||||||
func getMembershipsAtState(
|
|
||||||
ctx context.Context, db storage.Database, stateEntries []types.StateEntry, joinedOnly bool,
|
|
||||||
) ([]types.Event, error) {
|
|
||||||
|
|
||||||
var eventNIDs []types.EventNID
|
|
||||||
for _, entry := range stateEntries {
|
|
||||||
// Filter the events to retrieve to only keep the membership events
|
|
||||||
if entry.EventTypeNID == types.MRoomMemberNID {
|
|
||||||
eventNIDs = append(eventNIDs, entry.EventNID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get all of the events in this state
|
|
||||||
stateEvents, err := db.Events(ctx, eventNIDs)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !joinedOnly {
|
|
||||||
return stateEvents, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Filter the events to only keep the "join" membership events
|
|
||||||
var events []types.Event
|
|
||||||
for _, event := range stateEvents {
|
|
||||||
membership, err := event.Membership()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if membership == gomatrixserverlib.Join {
|
|
||||||
events = append(events, event)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return events, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// QueryServerAllowedToSeeEvent implements api.RoomserverInternalAPI
|
|
||||||
func (r *RoomserverInternalAPI) QueryServerAllowedToSeeEvent(
|
|
||||||
ctx context.Context,
|
|
||||||
request *api.QueryServerAllowedToSeeEventRequest,
|
|
||||||
response *api.QueryServerAllowedToSeeEventResponse,
|
|
||||||
) (err error) {
|
|
||||||
events, err := r.DB.EventsFromIDs(ctx, []string{request.EventID})
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if len(events) == 0 {
|
|
||||||
response.AllowedToSeeEvent = false // event doesn't exist so not allowed to see
|
|
||||||
return
|
|
||||||
}
|
|
||||||
isServerInRoom, err := r.isServerCurrentlyInRoom(ctx, request.ServerName, events[0].RoomID())
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
response.AllowedToSeeEvent, err = r.checkServerAllowedToSeeEvent(
|
|
||||||
ctx, request.EventID, request.ServerName, isServerInRoom,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *RoomserverInternalAPI) checkServerAllowedToSeeEvent(
|
|
||||||
ctx context.Context, eventID string, serverName gomatrixserverlib.ServerName, isServerInRoom bool,
|
|
||||||
) (bool, error) {
|
|
||||||
roomState := state.NewStateResolution(r.DB)
|
|
||||||
stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: We probably want to make it so that we don't have to pull
|
|
||||||
// out all the state if possible.
|
|
||||||
stateAtEvent, err := r.loadStateEvents(ctx, stateEntries)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return auth.IsServerAllowed(serverName, isServerInRoom, stateAtEvent), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// QueryMissingEvents implements api.RoomserverInternalAPI
|
|
||||||
func (r *RoomserverInternalAPI) QueryMissingEvents(
|
|
||||||
ctx context.Context,
|
|
||||||
request *api.QueryMissingEventsRequest,
|
|
||||||
response *api.QueryMissingEventsResponse,
|
|
||||||
) error {
|
|
||||||
var front []string
|
|
||||||
eventsToFilter := make(map[string]bool, len(request.LatestEvents))
|
|
||||||
visited := make(map[string]bool, request.Limit) // request.Limit acts as a hint to size.
|
|
||||||
for _, id := range request.EarliestEvents {
|
|
||||||
visited[id] = true
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, id := range request.LatestEvents {
|
|
||||||
if !visited[id] {
|
|
||||||
front = append(front, id)
|
|
||||||
eventsToFilter[id] = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
resultNIDs, err := r.scanEventTree(ctx, front, visited, request.Limit, request.ServerName)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
loadedEvents, err := r.loadEvents(ctx, resultNIDs)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
response.Events = make([]gomatrixserverlib.HeaderedEvent, 0, len(loadedEvents)-len(eventsToFilter))
|
|
||||||
for _, event := range loadedEvents {
|
|
||||||
if !eventsToFilter[event.EventID()] {
|
|
||||||
roomVersion, verr := r.DB.GetRoomVersionForRoom(ctx, event.RoomID())
|
|
||||||
if verr != nil {
|
|
||||||
return verr
|
|
||||||
}
|
|
||||||
|
|
||||||
response.Events = append(response.Events, event.Headered(roomVersion))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// PerformBackfill implements api.RoomServerQueryAPI
|
|
||||||
func (r *RoomserverInternalAPI) PerformBackfill(
|
|
||||||
ctx context.Context,
|
|
||||||
request *api.PerformBackfillRequest,
|
|
||||||
response *api.PerformBackfillResponse,
|
|
||||||
) error {
|
|
||||||
// if we are requesting the backfill then we need to do a federation hit
|
|
||||||
// TODO: we could be more sensible and fetch as many events we already have then request the rest
|
|
||||||
// which is what the syncapi does already.
|
|
||||||
if request.ServerName == r.ServerName {
|
|
||||||
return r.backfillViaFederation(ctx, request, response)
|
|
||||||
}
|
|
||||||
// someone else is requesting the backfill, try to service their request.
|
|
||||||
var err error
|
|
||||||
var front []string
|
|
||||||
|
|
||||||
// The limit defines the maximum number of events to retrieve, so it also
|
|
||||||
// defines the highest number of elements in the map below.
|
|
||||||
visited := make(map[string]bool, request.Limit)
|
|
||||||
|
|
||||||
// this will include these events which is what we want
|
|
||||||
front = request.PrevEventIDs()
|
|
||||||
|
|
||||||
// Scan the event tree for events to send back.
|
|
||||||
resultNIDs, err := r.scanEventTree(ctx, front, visited, request.Limit, request.ServerName)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Retrieve events from the list that was filled previously.
|
|
||||||
var loadedEvents []gomatrixserverlib.Event
|
|
||||||
loadedEvents, err = r.loadEvents(ctx, resultNIDs)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, event := range loadedEvents {
|
|
||||||
roomVersion, verr := r.DB.GetRoomVersionForRoom(ctx, event.RoomID())
|
|
||||||
if verr != nil {
|
|
||||||
return verr
|
|
||||||
}
|
|
||||||
|
|
||||||
response.Events = append(response.Events, event.Headered(roomVersion))
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *RoomserverInternalAPI) backfillViaFederation(ctx context.Context, req *api.PerformBackfillRequest, res *api.PerformBackfillResponse) error {
|
|
||||||
roomVer, err := r.DB.GetRoomVersionForRoom(ctx, req.RoomID)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("backfillViaFederation: unknown room version for room %s : %w", req.RoomID, err)
|
|
||||||
}
|
|
||||||
requester := newBackfillRequester(r.DB, r.FedClient, r.ServerName, req.BackwardsExtremities)
|
|
||||||
// Request 100 items regardless of what the query asks for.
|
|
||||||
// We don't want to go much higher than this.
|
|
||||||
// We can't honour exactly the limit as some sytests rely on requesting more for tests to pass
|
|
||||||
// (so we don't need to hit /state_ids which the test has no listener for)
|
|
||||||
// Specifically the test "Outbound federation can backfill events"
|
|
||||||
events, err := gomatrixserverlib.RequestBackfill(
|
|
||||||
ctx, requester,
|
|
||||||
r.KeyRing, req.RoomID, roomVer, req.PrevEventIDs(), 100)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
logrus.WithField("room_id", req.RoomID).Infof("backfilled %d events", len(events))
|
|
||||||
|
|
||||||
// persist these new events - auth checks have already been done
|
|
||||||
roomNID, backfilledEventMap := persistEvents(ctx, r.DB, events)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, ev := range backfilledEventMap {
|
|
||||||
// now add state for these events
|
|
||||||
stateIDs, ok := requester.eventIDToBeforeStateIDs[ev.EventID()]
|
|
||||||
if !ok {
|
|
||||||
// this should be impossible as all events returned must have pass Step 5 of the PDU checks
|
|
||||||
// which requires a list of state IDs.
|
|
||||||
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to find state IDs for event which passed auth checks")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
var entries []types.StateEntry
|
|
||||||
if entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs); err != nil {
|
|
||||||
// attempt to fetch the missing events
|
|
||||||
r.fetchAndStoreMissingEvents(ctx, roomVer, requester, stateIDs)
|
|
||||||
// try again
|
|
||||||
entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs)
|
|
||||||
if err != nil {
|
|
||||||
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to get state entries for event")
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var beforeStateSnapshotNID types.StateSnapshotNID
|
|
||||||
if beforeStateSnapshotNID, err = r.DB.AddState(ctx, roomNID, nil, entries); err != nil {
|
|
||||||
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to persist state entries to get snapshot nid")
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err = r.DB.SetState(ctx, ev.EventNID, beforeStateSnapshotNID); err != nil {
|
|
||||||
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to persist snapshot nid")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: update backwards extremities, as that should be moved from syncapi to roomserver at some point.
|
|
||||||
|
|
||||||
res.Events = events
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *RoomserverInternalAPI) isServerCurrentlyInRoom(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID string) (bool, error) {
|
|
||||||
info, err := r.DB.RoomInfo(ctx, roomID)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
if info == nil {
|
|
||||||
return false, fmt.Errorf("unknown room %s", roomID)
|
|
||||||
}
|
|
||||||
|
|
||||||
eventNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, true, false)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
events, err := r.DB.Events(ctx, eventNIDs)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
gmslEvents := make([]gomatrixserverlib.Event, len(events))
|
|
||||||
for i := range events {
|
|
||||||
gmslEvents[i] = events[i].Event
|
|
||||||
}
|
|
||||||
return auth.IsAnyUserOnServerWithMembership(serverName, gmslEvents, gomatrixserverlib.Join), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// fetchAndStoreMissingEvents does a best-effort fetch and store of missing events specified in stateIDs. Returns no error as it is just
|
|
||||||
// best effort.
|
|
||||||
func (r *RoomserverInternalAPI) fetchAndStoreMissingEvents(ctx context.Context, roomVer gomatrixserverlib.RoomVersion,
|
|
||||||
backfillRequester *backfillRequester, stateIDs []string) {
|
|
||||||
|
|
||||||
servers := backfillRequester.servers
|
|
||||||
|
|
||||||
// work out which are missing
|
|
||||||
nidMap, err := r.DB.EventNIDs(ctx, stateIDs)
|
|
||||||
if err != nil {
|
|
||||||
util.GetLogger(ctx).WithError(err).Warn("cannot query missing events")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
missingMap := make(map[string]*gomatrixserverlib.HeaderedEvent) // id -> event
|
|
||||||
for _, id := range stateIDs {
|
|
||||||
if _, ok := nidMap[id]; !ok {
|
|
||||||
missingMap[id] = nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
util.GetLogger(ctx).Infof("Fetching %d missing state events (from %d possible servers)", len(missingMap), len(servers))
|
|
||||||
|
|
||||||
// fetch the events from federation. Loop the servers first so if we find one that works we stick with them
|
|
||||||
for _, srv := range servers {
|
|
||||||
for id, ev := range missingMap {
|
|
||||||
if ev != nil {
|
|
||||||
continue // already found
|
|
||||||
}
|
|
||||||
logger := util.GetLogger(ctx).WithField("server", srv).WithField("event_id", id)
|
|
||||||
res, err := r.FedClient.GetEvent(ctx, srv, id)
|
|
||||||
if err != nil {
|
|
||||||
logger.WithError(err).Warn("failed to get event from server")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
loader := gomatrixserverlib.NewEventsLoader(roomVer, r.KeyRing, backfillRequester, backfillRequester.ProvideEvents, false)
|
|
||||||
result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents)
|
|
||||||
if err != nil {
|
|
||||||
logger.WithError(err).Warn("failed to load and verify event")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
logger.Infof("returned %d PDUs which made events %+v", len(res.PDUs), result)
|
|
||||||
for _, res := range result {
|
|
||||||
if res.Error != nil {
|
|
||||||
logger.WithError(err).Warn("event failed PDU checks")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
missingMap[id] = res.Event
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var newEvents []gomatrixserverlib.HeaderedEvent
|
|
||||||
for _, ev := range missingMap {
|
|
||||||
if ev != nil {
|
|
||||||
newEvents = append(newEvents, *ev)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
util.GetLogger(ctx).Infof("Persisting %d new events", len(newEvents))
|
|
||||||
persistEvents(ctx, r.DB, newEvents)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Remove this when we have tests to assert correctness of this function
|
|
||||||
// nolint:gocyclo
|
|
||||||
func (r *RoomserverInternalAPI) scanEventTree(
|
|
||||||
ctx context.Context, front []string, visited map[string]bool, limit int,
|
|
||||||
serverName gomatrixserverlib.ServerName,
|
|
||||||
) ([]types.EventNID, error) {
|
|
||||||
var resultNIDs []types.EventNID
|
|
||||||
var err error
|
|
||||||
var allowed bool
|
|
||||||
var events []types.Event
|
|
||||||
var next []string
|
|
||||||
var pre string
|
|
||||||
|
|
||||||
// TODO: add tests for this function to ensure it meets the contract that callers expect (and doc what that is supposed to be)
|
|
||||||
// Currently, callers like PerformBackfill will call scanEventTree with a pre-populated `visited` map, assuming that by doing
|
|
||||||
// so means that the events in that map will NOT be returned from this function. That is not currently true, resulting in
|
|
||||||
// duplicate events being sent in response to /backfill requests.
|
|
||||||
initialIgnoreList := make(map[string]bool, len(visited))
|
|
||||||
for k, v := range visited {
|
|
||||||
initialIgnoreList[k] = v
|
|
||||||
}
|
|
||||||
|
|
||||||
resultNIDs = make([]types.EventNID, 0, limit)
|
|
||||||
|
|
||||||
var checkedServerInRoom bool
|
|
||||||
var isServerInRoom bool
|
|
||||||
|
|
||||||
// Loop through the event IDs to retrieve the requested events and go
|
|
||||||
// through the whole tree (up to the provided limit) using the events'
|
|
||||||
// "prev_event" key.
|
|
||||||
BFSLoop:
|
|
||||||
for len(front) > 0 {
|
|
||||||
// Prevent unnecessary allocations: reset the slice only when not empty.
|
|
||||||
if len(next) > 0 {
|
|
||||||
next = make([]string, 0)
|
|
||||||
}
|
|
||||||
// Retrieve the events to process from the database.
|
|
||||||
events, err = r.DB.EventsFromIDs(ctx, front)
|
|
||||||
if err != nil {
|
|
||||||
return resultNIDs, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !checkedServerInRoom && len(events) > 0 {
|
|
||||||
// It's nasty that we have to extract the room ID from an event, but many federation requests
|
|
||||||
// only talk in event IDs, no room IDs at all (!!!)
|
|
||||||
ev := events[0]
|
|
||||||
isServerInRoom, err = r.isServerCurrentlyInRoom(ctx, serverName, ev.RoomID())
|
|
||||||
if err != nil {
|
|
||||||
util.GetLogger(ctx).WithError(err).Error("Failed to check if server is currently in room, assuming not.")
|
|
||||||
}
|
|
||||||
checkedServerInRoom = true
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, ev := range events {
|
|
||||||
// Break out of the loop if the provided limit is reached.
|
|
||||||
if len(resultNIDs) == limit {
|
|
||||||
break BFSLoop
|
|
||||||
}
|
|
||||||
|
|
||||||
if !initialIgnoreList[ev.EventID()] {
|
|
||||||
// Update the list of events to retrieve.
|
|
||||||
resultNIDs = append(resultNIDs, ev.EventNID)
|
|
||||||
}
|
|
||||||
// Loop through the event's parents.
|
|
||||||
for _, pre = range ev.PrevEventIDs() {
|
|
||||||
// Only add an event to the list of next events to process if it
|
|
||||||
// hasn't been seen before.
|
|
||||||
if !visited[pre] {
|
|
||||||
visited[pre] = true
|
|
||||||
allowed, err = r.checkServerAllowedToSeeEvent(ctx, pre, serverName, isServerInRoom)
|
|
||||||
if err != nil {
|
|
||||||
util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).WithError(err).Error(
|
|
||||||
"Error checking if allowed to see event",
|
|
||||||
)
|
|
||||||
return resultNIDs, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the event hasn't been seen before and the HS
|
|
||||||
// requesting to retrieve it is allowed to do so, add it to
|
|
||||||
// the list of events to retrieve.
|
|
||||||
if allowed {
|
|
||||||
next = append(next, pre)
|
|
||||||
} else {
|
|
||||||
util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).Info("Not allowed to see event")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Repeat the same process with the parent events we just processed.
|
|
||||||
front = next
|
|
||||||
}
|
|
||||||
|
|
||||||
return resultNIDs, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// QueryStateAndAuthChain implements api.RoomserverInternalAPI
|
|
||||||
func (r *RoomserverInternalAPI) QueryStateAndAuthChain(
|
|
||||||
ctx context.Context,
|
|
||||||
request *api.QueryStateAndAuthChainRequest,
|
|
||||||
response *api.QueryStateAndAuthChainResponse,
|
|
||||||
) error {
|
|
||||||
info, err := r.DB.RoomInfo(ctx, request.RoomID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if info.IsStub {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
response.RoomExists = true
|
|
||||||
response.RoomVersion = info.RoomVersion
|
|
||||||
|
|
||||||
stateEvents, err := r.loadStateAtEventIDs(ctx, request.PrevEventIDs)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
response.PrevEventsExist = true
|
|
||||||
|
|
||||||
// add the auth event IDs for the current state events too
|
|
||||||
var authEventIDs []string
|
|
||||||
authEventIDs = append(authEventIDs, request.AuthEventIDs...)
|
|
||||||
for _, se := range stateEvents {
|
|
||||||
authEventIDs = append(authEventIDs, se.AuthEventIDs()...)
|
|
||||||
}
|
|
||||||
authEventIDs = util.UniqueStrings(authEventIDs) // de-dupe
|
|
||||||
|
|
||||||
authEvents, err := getAuthChain(ctx, r.DB.EventsFromIDs, authEventIDs)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if request.ResolveState {
|
|
||||||
if stateEvents, err = state.ResolveConflictsAdhoc(
|
|
||||||
info.RoomVersion, stateEvents, authEvents,
|
|
||||||
); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, event := range stateEvents {
|
|
||||||
response.StateEvents = append(response.StateEvents, event.Headered(info.RoomVersion))
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, event := range authEvents {
|
|
||||||
response.AuthChainEvents = append(response.AuthChainEvents, event.Headered(info.RoomVersion))
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *RoomserverInternalAPI) loadStateAtEventIDs(ctx context.Context, eventIDs []string) ([]gomatrixserverlib.Event, error) {
|
|
||||||
roomState := state.NewStateResolution(r.DB)
|
|
||||||
prevStates, err := r.DB.StateAtEventIDs(ctx, eventIDs)
|
|
||||||
if err != nil {
|
|
||||||
switch err.(type) {
|
|
||||||
case types.MissingEventError:
|
|
||||||
return nil, nil
|
|
||||||
default:
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Look up the currrent state for the requested tuples.
|
|
||||||
stateEntries, err := roomState.LoadCombinedStateAfterEvents(
|
|
||||||
ctx, prevStates,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return r.loadStateEvents(ctx, stateEntries)
|
|
||||||
}
|
|
||||||
|
|
||||||
type eventsFromIDs func(context.Context, []string) ([]types.Event, error)
|
|
||||||
|
|
||||||
// getAuthChain fetches the auth chain for the given auth events. An auth chain
|
|
||||||
// is the list of all events that are referenced in the auth_events section, and
|
|
||||||
// all their auth_events, recursively. The returned set of events contain the
|
|
||||||
// given events. Will *not* error if we don't have all auth events.
|
|
||||||
func getAuthChain(
|
|
||||||
ctx context.Context, fn eventsFromIDs, authEventIDs []string,
|
|
||||||
) ([]gomatrixserverlib.Event, error) {
|
|
||||||
// List of event IDs to fetch. On each pass, these events will be requested
|
|
||||||
// from the database and the `eventsToFetch` will be updated with any new
|
|
||||||
// events that we have learned about and need to find. When `eventsToFetch`
|
|
||||||
// is eventually empty, we should have reached the end of the chain.
|
|
||||||
eventsToFetch := authEventIDs
|
|
||||||
authEventsMap := make(map[string]gomatrixserverlib.Event)
|
|
||||||
|
|
||||||
for len(eventsToFetch) > 0 {
|
|
||||||
// Try to retrieve the events from the database.
|
|
||||||
events, err := fn(ctx, eventsToFetch)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// We've now fetched these events so clear out `eventsToFetch`. Soon we may
|
|
||||||
// add newly discovered events to this for the next pass.
|
|
||||||
eventsToFetch = eventsToFetch[:0]
|
|
||||||
|
|
||||||
for _, event := range events {
|
|
||||||
// Store the event in the event map - this prevents us from requesting it
|
|
||||||
// from the database again.
|
|
||||||
authEventsMap[event.EventID()] = event.Event
|
|
||||||
|
|
||||||
// Extract all of the auth events from the newly obtained event. If we
|
|
||||||
// don't already have a record of the event, record it in the list of
|
|
||||||
// events we want to request for the next pass.
|
|
||||||
for _, authEvent := range event.AuthEvents() {
|
|
||||||
if _, ok := authEventsMap[authEvent.EventID]; !ok {
|
|
||||||
eventsToFetch = append(eventsToFetch, authEvent.EventID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// We've now retrieved all of the events we can. Flatten them down into an
|
|
||||||
// array and return them.
|
|
||||||
var authEvents []gomatrixserverlib.Event
|
|
||||||
for _, event := range authEventsMap {
|
|
||||||
authEvents = append(authEvents, event)
|
|
||||||
}
|
|
||||||
|
|
||||||
return authEvents, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func persistEvents(ctx context.Context, db storage.Database, events []gomatrixserverlib.HeaderedEvent) (types.RoomNID, map[string]types.Event) {
|
|
||||||
var roomNID types.RoomNID
|
|
||||||
backfilledEventMap := make(map[string]types.Event)
|
|
||||||
for j, ev := range events {
|
|
||||||
nidMap, err := db.EventNIDs(ctx, ev.AuthEventIDs())
|
|
||||||
if err != nil { // this shouldn't happen as RequestBackfill already found them
|
|
||||||
logrus.WithError(err).WithField("auth_events", ev.AuthEventIDs()).Error("Failed to find one or more auth events")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
authNids := make([]types.EventNID, len(nidMap))
|
|
||||||
i := 0
|
|
||||||
for _, nid := range nidMap {
|
|
||||||
authNids[i] = nid
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
var stateAtEvent types.StateAtEvent
|
|
||||||
var redactedEventID string
|
|
||||||
var redactionEvent *gomatrixserverlib.Event
|
|
||||||
roomNID, stateAtEvent, redactionEvent, redactedEventID, err = db.StoreEvent(ctx, ev.Unwrap(), nil, authNids)
|
|
||||||
if err != nil {
|
|
||||||
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to persist event")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// If storing this event results in it being redacted, then do so.
|
|
||||||
// It's also possible for this event to be a redaction which results in another event being
|
|
||||||
// redacted, which we don't care about since we aren't returning it in this backfill.
|
|
||||||
if redactedEventID == ev.EventID() {
|
|
||||||
eventToRedact := ev.Unwrap()
|
|
||||||
redactedEvent, err := eventutil.RedactEvent(redactionEvent, &eventToRedact)
|
|
||||||
if err != nil {
|
|
||||||
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to redact event")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
ev = redactedEvent.Headered(ev.RoomVersion)
|
|
||||||
events[j] = ev
|
|
||||||
}
|
|
||||||
backfilledEventMap[ev.EventID()] = types.Event{
|
|
||||||
EventNID: stateAtEvent.StateEntry.EventNID,
|
|
||||||
Event: ev.Unwrap(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return roomNID, backfilledEventMap
|
|
||||||
}
|
|
||||||
|
|
||||||
// QueryRoomVersionCapabilities implements api.RoomserverInternalAPI
|
|
||||||
func (r *RoomserverInternalAPI) QueryRoomVersionCapabilities(
|
|
||||||
ctx context.Context,
|
|
||||||
request *api.QueryRoomVersionCapabilitiesRequest,
|
|
||||||
response *api.QueryRoomVersionCapabilitiesResponse,
|
|
||||||
) error {
|
|
||||||
response.DefaultRoomVersion = version.DefaultRoomVersion()
|
|
||||||
response.AvailableRoomVersions = make(map[gomatrixserverlib.RoomVersion]string)
|
|
||||||
for v, desc := range version.SupportedRoomVersions() {
|
|
||||||
if desc.Stable {
|
|
||||||
response.AvailableRoomVersions[v] = "stable"
|
|
||||||
} else {
|
|
||||||
response.AvailableRoomVersions[v] = "unstable"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// QueryRoomVersionCapabilities implements api.RoomserverInternalAPI
|
|
||||||
func (r *RoomserverInternalAPI) QueryRoomVersionForRoom(
|
|
||||||
ctx context.Context,
|
|
||||||
request *api.QueryRoomVersionForRoomRequest,
|
|
||||||
response *api.QueryRoomVersionForRoomResponse,
|
|
||||||
) error {
|
|
||||||
if roomVersion, ok := r.Cache.GetRoomVersion(request.RoomID); ok {
|
|
||||||
response.RoomVersion = roomVersion
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
roomVersion, err := r.DB.GetRoomVersionForRoom(ctx, request.RoomID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
response.RoomVersion = roomVersion
|
|
||||||
r.Cache.StoreRoomVersion(request.RoomID, response.RoomVersion)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *RoomserverInternalAPI) QueryPublishedRooms(
|
|
||||||
ctx context.Context,
|
|
||||||
req *api.QueryPublishedRoomsRequest,
|
|
||||||
res *api.QueryPublishedRoomsResponse,
|
|
||||||
) error {
|
|
||||||
rooms, err := r.DB.GetPublishedRooms(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
res.RoomIDs = rooms
|
|
||||||
return nil
|
|
||||||
}
|
|
602
roomserver/internal/query/query.go
Normal file
602
roomserver/internal/query/query.go
Normal file
|
@ -0,0 +1,602 @@
|
||||||
|
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package query
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
|
"github.com/matrix-org/dendrite/internal/caching"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/acls"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/internal/helpers"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/state"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/storage"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/version"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/matrix-org/util"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Queryer struct {
|
||||||
|
DB storage.Database
|
||||||
|
Cache caching.RoomServerCaches
|
||||||
|
ServerACLs *acls.ServerACLs
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryLatestEventsAndState implements api.RoomserverInternalAPI
|
||||||
|
func (r *Queryer) QueryLatestEventsAndState(
|
||||||
|
ctx context.Context,
|
||||||
|
request *api.QueryLatestEventsAndStateRequest,
|
||||||
|
response *api.QueryLatestEventsAndStateResponse,
|
||||||
|
) error {
|
||||||
|
return helpers.QueryLatestEventsAndState(ctx, r.DB, request, response)
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryStateAfterEvents implements api.RoomserverInternalAPI
|
||||||
|
func (r *Queryer) QueryStateAfterEvents(
|
||||||
|
ctx context.Context,
|
||||||
|
request *api.QueryStateAfterEventsRequest,
|
||||||
|
response *api.QueryStateAfterEventsResponse,
|
||||||
|
) error {
|
||||||
|
info, err := r.DB.RoomInfo(ctx, request.RoomID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if info == nil || info.IsStub {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
roomState := state.NewStateResolution(r.DB, *info)
|
||||||
|
response.RoomExists = true
|
||||||
|
response.RoomVersion = info.RoomVersion
|
||||||
|
|
||||||
|
prevStates, err := r.DB.StateAtEventIDs(ctx, request.PrevEventIDs)
|
||||||
|
if err != nil {
|
||||||
|
switch err.(type) {
|
||||||
|
case types.MissingEventError:
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
response.PrevEventsExist = true
|
||||||
|
|
||||||
|
// Look up the currrent state for the requested tuples.
|
||||||
|
stateEntries, err := roomState.LoadStateAfterEventsForStringTuples(
|
||||||
|
ctx, prevStates, request.StateToFetch,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
stateEvents, err := helpers.LoadStateEvents(ctx, r.DB, stateEntries)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, event := range stateEvents {
|
||||||
|
response.StateEvents = append(response.StateEvents, event.Headered(info.RoomVersion))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryEventsByID implements api.RoomserverInternalAPI
|
||||||
|
func (r *Queryer) QueryEventsByID(
|
||||||
|
ctx context.Context,
|
||||||
|
request *api.QueryEventsByIDRequest,
|
||||||
|
response *api.QueryEventsByIDResponse,
|
||||||
|
) error {
|
||||||
|
eventNIDMap, err := r.DB.EventNIDs(ctx, request.EventIDs)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var eventNIDs []types.EventNID
|
||||||
|
for _, nid := range eventNIDMap {
|
||||||
|
eventNIDs = append(eventNIDs, nid)
|
||||||
|
}
|
||||||
|
|
||||||
|
events, err := helpers.LoadEvents(ctx, r.DB, eventNIDs)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, event := range events {
|
||||||
|
roomVersion, verr := r.roomVersion(event.RoomID())
|
||||||
|
if verr != nil {
|
||||||
|
return verr
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Events = append(response.Events, event.Headered(roomVersion))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryMembershipForUser implements api.RoomserverInternalAPI
|
||||||
|
func (r *Queryer) QueryMembershipForUser(
|
||||||
|
ctx context.Context,
|
||||||
|
request *api.QueryMembershipForUserRequest,
|
||||||
|
response *api.QueryMembershipForUserResponse,
|
||||||
|
) error {
|
||||||
|
info, err := r.DB.RoomInfo(ctx, request.RoomID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
membershipEventNID, stillInRoom, err := r.DB.GetMembership(ctx, info.RoomNID, request.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if membershipEventNID == 0 {
|
||||||
|
response.HasBeenInRoom = false
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
response.IsInRoom = stillInRoom
|
||||||
|
response.HasBeenInRoom = true
|
||||||
|
|
||||||
|
evs, err := r.DB.Events(ctx, []types.EventNID{membershipEventNID})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if len(evs) != 1 {
|
||||||
|
return fmt.Errorf("failed to load membership event for event NID %d", membershipEventNID)
|
||||||
|
}
|
||||||
|
|
||||||
|
response.EventID = evs[0].EventID()
|
||||||
|
response.Membership, err = evs[0].Membership()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryMembershipsForRoom implements api.RoomserverInternalAPI
|
||||||
|
func (r *Queryer) QueryMembershipsForRoom(
|
||||||
|
ctx context.Context,
|
||||||
|
request *api.QueryMembershipsForRoomRequest,
|
||||||
|
response *api.QueryMembershipsForRoomResponse,
|
||||||
|
) error {
|
||||||
|
info, err := r.DB.RoomInfo(ctx, request.RoomID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
membershipEventNID, stillInRoom, err := r.DB.GetMembership(ctx, info.RoomNID, request.Sender)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if membershipEventNID == 0 {
|
||||||
|
response.HasBeenInRoom = false
|
||||||
|
response.JoinEvents = nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
response.HasBeenInRoom = true
|
||||||
|
response.JoinEvents = []gomatrixserverlib.ClientEvent{}
|
||||||
|
|
||||||
|
var events []types.Event
|
||||||
|
var stateEntries []types.StateEntry
|
||||||
|
if stillInRoom {
|
||||||
|
var eventNIDs []types.EventNID
|
||||||
|
eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, request.JoinedOnly, false)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
events, err = r.DB.Events(ctx, eventNIDs)
|
||||||
|
} else {
|
||||||
|
stateEntries, err = helpers.StateBeforeEvent(ctx, r.DB, *info, membershipEventNID)
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithField("membership_event_nid", membershipEventNID).WithError(err).Error("failed to load state before event")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
events, err = helpers.GetMembershipsAtState(ctx, r.DB, stateEntries, request.JoinedOnly)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, event := range events {
|
||||||
|
clientEvent := gomatrixserverlib.ToClientEvent(event.Event, gomatrixserverlib.FormatAll)
|
||||||
|
response.JoinEvents = append(response.JoinEvents, clientEvent)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryServerAllowedToSeeEvent implements api.RoomserverInternalAPI
|
||||||
|
func (r *Queryer) QueryServerAllowedToSeeEvent(
|
||||||
|
ctx context.Context,
|
||||||
|
request *api.QueryServerAllowedToSeeEventRequest,
|
||||||
|
response *api.QueryServerAllowedToSeeEventResponse,
|
||||||
|
) (err error) {
|
||||||
|
events, err := r.DB.EventsFromIDs(ctx, []string{request.EventID})
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(events) == 0 {
|
||||||
|
response.AllowedToSeeEvent = false // event doesn't exist so not allowed to see
|
||||||
|
return
|
||||||
|
}
|
||||||
|
roomID := events[0].RoomID()
|
||||||
|
isServerInRoom, err := helpers.IsServerCurrentlyInRoom(ctx, r.DB, request.ServerName, roomID)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
info, err := r.DB.RoomInfo(ctx, roomID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if info == nil {
|
||||||
|
return fmt.Errorf("QueryServerAllowedToSeeEvent: no room info for room %s", roomID)
|
||||||
|
}
|
||||||
|
response.AllowedToSeeEvent, err = helpers.CheckServerAllowedToSeeEvent(
|
||||||
|
ctx, r.DB, *info, request.EventID, request.ServerName, isServerInRoom,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryMissingEvents implements api.RoomserverInternalAPI
|
||||||
|
// nolint:gocyclo
|
||||||
|
func (r *Queryer) QueryMissingEvents(
|
||||||
|
ctx context.Context,
|
||||||
|
request *api.QueryMissingEventsRequest,
|
||||||
|
response *api.QueryMissingEventsResponse,
|
||||||
|
) error {
|
||||||
|
var front []string
|
||||||
|
eventsToFilter := make(map[string]bool, len(request.LatestEvents))
|
||||||
|
visited := make(map[string]bool, request.Limit) // request.Limit acts as a hint to size.
|
||||||
|
for _, id := range request.EarliestEvents {
|
||||||
|
visited[id] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, id := range request.LatestEvents {
|
||||||
|
if !visited[id] {
|
||||||
|
front = append(front, id)
|
||||||
|
eventsToFilter[id] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
events, err := r.DB.EventsFromIDs(ctx, front)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if len(events) == 0 {
|
||||||
|
return nil // we are missing the events being asked to search from, give up.
|
||||||
|
}
|
||||||
|
info, err := r.DB.RoomInfo(ctx, events[0].RoomID())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if info == nil || info.IsStub {
|
||||||
|
return fmt.Errorf("missing RoomInfo for room %s", events[0].RoomID())
|
||||||
|
}
|
||||||
|
|
||||||
|
resultNIDs, err := helpers.ScanEventTree(ctx, r.DB, *info, front, visited, request.Limit, request.ServerName)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
loadedEvents, err := helpers.LoadEvents(ctx, r.DB, resultNIDs)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Events = make([]gomatrixserverlib.HeaderedEvent, 0, len(loadedEvents)-len(eventsToFilter))
|
||||||
|
for _, event := range loadedEvents {
|
||||||
|
if !eventsToFilter[event.EventID()] {
|
||||||
|
roomVersion, verr := r.roomVersion(event.RoomID())
|
||||||
|
if verr != nil {
|
||||||
|
return verr
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Events = append(response.Events, event.Headered(roomVersion))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryStateAndAuthChain implements api.RoomserverInternalAPI
|
||||||
|
func (r *Queryer) QueryStateAndAuthChain(
|
||||||
|
ctx context.Context,
|
||||||
|
request *api.QueryStateAndAuthChainRequest,
|
||||||
|
response *api.QueryStateAndAuthChainResponse,
|
||||||
|
) error {
|
||||||
|
info, err := r.DB.RoomInfo(ctx, request.RoomID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if info == nil || info.IsStub {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
response.RoomExists = true
|
||||||
|
response.RoomVersion = info.RoomVersion
|
||||||
|
|
||||||
|
stateEvents, err := r.loadStateAtEventIDs(ctx, *info, request.PrevEventIDs)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
response.PrevEventsExist = true
|
||||||
|
|
||||||
|
// add the auth event IDs for the current state events too
|
||||||
|
var authEventIDs []string
|
||||||
|
authEventIDs = append(authEventIDs, request.AuthEventIDs...)
|
||||||
|
for _, se := range stateEvents {
|
||||||
|
authEventIDs = append(authEventIDs, se.AuthEventIDs()...)
|
||||||
|
}
|
||||||
|
authEventIDs = util.UniqueStrings(authEventIDs) // de-dupe
|
||||||
|
|
||||||
|
authEvents, err := getAuthChain(ctx, r.DB.EventsFromIDs, authEventIDs)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if request.ResolveState {
|
||||||
|
if stateEvents, err = state.ResolveConflictsAdhoc(
|
||||||
|
info.RoomVersion, stateEvents, authEvents,
|
||||||
|
); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, event := range stateEvents {
|
||||||
|
response.StateEvents = append(response.StateEvents, event.Headered(info.RoomVersion))
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, event := range authEvents {
|
||||||
|
response.AuthChainEvents = append(response.AuthChainEvents, event.Headered(info.RoomVersion))
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo types.RoomInfo, eventIDs []string) ([]gomatrixserverlib.Event, error) {
|
||||||
|
roomState := state.NewStateResolution(r.DB, roomInfo)
|
||||||
|
prevStates, err := r.DB.StateAtEventIDs(ctx, eventIDs)
|
||||||
|
if err != nil {
|
||||||
|
switch err.(type) {
|
||||||
|
case types.MissingEventError:
|
||||||
|
return nil, nil
|
||||||
|
default:
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Look up the currrent state for the requested tuples.
|
||||||
|
stateEntries, err := roomState.LoadCombinedStateAfterEvents(
|
||||||
|
ctx, prevStates,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return helpers.LoadStateEvents(ctx, r.DB, stateEntries)
|
||||||
|
}
|
||||||
|
|
||||||
|
type eventsFromIDs func(context.Context, []string) ([]types.Event, error)
|
||||||
|
|
||||||
|
// getAuthChain fetches the auth chain for the given auth events. An auth chain
|
||||||
|
// is the list of all events that are referenced in the auth_events section, and
|
||||||
|
// all their auth_events, recursively. The returned set of events contain the
|
||||||
|
// given events. Will *not* error if we don't have all auth events.
|
||||||
|
func getAuthChain(
|
||||||
|
ctx context.Context, fn eventsFromIDs, authEventIDs []string,
|
||||||
|
) ([]gomatrixserverlib.Event, error) {
|
||||||
|
// List of event IDs to fetch. On each pass, these events will be requested
|
||||||
|
// from the database and the `eventsToFetch` will be updated with any new
|
||||||
|
// events that we have learned about and need to find. When `eventsToFetch`
|
||||||
|
// is eventually empty, we should have reached the end of the chain.
|
||||||
|
eventsToFetch := authEventIDs
|
||||||
|
authEventsMap := make(map[string]gomatrixserverlib.Event)
|
||||||
|
|
||||||
|
for len(eventsToFetch) > 0 {
|
||||||
|
// Try to retrieve the events from the database.
|
||||||
|
events, err := fn(ctx, eventsToFetch)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// We've now fetched these events so clear out `eventsToFetch`. Soon we may
|
||||||
|
// add newly discovered events to this for the next pass.
|
||||||
|
eventsToFetch = eventsToFetch[:0]
|
||||||
|
|
||||||
|
for _, event := range events {
|
||||||
|
// Store the event in the event map - this prevents us from requesting it
|
||||||
|
// from the database again.
|
||||||
|
authEventsMap[event.EventID()] = event.Event
|
||||||
|
|
||||||
|
// Extract all of the auth events from the newly obtained event. If we
|
||||||
|
// don't already have a record of the event, record it in the list of
|
||||||
|
// events we want to request for the next pass.
|
||||||
|
for _, authEvent := range event.AuthEvents() {
|
||||||
|
if _, ok := authEventsMap[authEvent.EventID]; !ok {
|
||||||
|
eventsToFetch = append(eventsToFetch, authEvent.EventID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// We've now retrieved all of the events we can. Flatten them down into an
|
||||||
|
// array and return them.
|
||||||
|
var authEvents []gomatrixserverlib.Event
|
||||||
|
for _, event := range authEventsMap {
|
||||||
|
authEvents = append(authEvents, event)
|
||||||
|
}
|
||||||
|
|
||||||
|
return authEvents, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryRoomVersionCapabilities implements api.RoomserverInternalAPI
|
||||||
|
func (r *Queryer) QueryRoomVersionCapabilities(
|
||||||
|
ctx context.Context,
|
||||||
|
request *api.QueryRoomVersionCapabilitiesRequest,
|
||||||
|
response *api.QueryRoomVersionCapabilitiesResponse,
|
||||||
|
) error {
|
||||||
|
response.DefaultRoomVersion = version.DefaultRoomVersion()
|
||||||
|
response.AvailableRoomVersions = make(map[gomatrixserverlib.RoomVersion]string)
|
||||||
|
for v, desc := range version.SupportedRoomVersions() {
|
||||||
|
if desc.Stable {
|
||||||
|
response.AvailableRoomVersions[v] = "stable"
|
||||||
|
} else {
|
||||||
|
response.AvailableRoomVersions[v] = "unstable"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryRoomVersionCapabilities implements api.RoomserverInternalAPI
|
||||||
|
func (r *Queryer) QueryRoomVersionForRoom(
|
||||||
|
ctx context.Context,
|
||||||
|
request *api.QueryRoomVersionForRoomRequest,
|
||||||
|
response *api.QueryRoomVersionForRoomResponse,
|
||||||
|
) error {
|
||||||
|
if roomVersion, ok := r.Cache.GetRoomVersion(request.RoomID); ok {
|
||||||
|
response.RoomVersion = roomVersion
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
info, err := r.DB.RoomInfo(ctx, request.RoomID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if info == nil {
|
||||||
|
return fmt.Errorf("QueryRoomVersionForRoom: missing room info for room %s", request.RoomID)
|
||||||
|
}
|
||||||
|
response.RoomVersion = info.RoomVersion
|
||||||
|
r.Cache.StoreRoomVersion(request.RoomID, response.RoomVersion)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Queryer) roomVersion(roomID string) (gomatrixserverlib.RoomVersion, error) {
|
||||||
|
var res api.QueryRoomVersionForRoomResponse
|
||||||
|
err := r.QueryRoomVersionForRoom(context.Background(), &api.QueryRoomVersionForRoomRequest{
|
||||||
|
RoomID: roomID,
|
||||||
|
}, &res)
|
||||||
|
return res.RoomVersion, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Queryer) QueryPublishedRooms(
|
||||||
|
ctx context.Context,
|
||||||
|
req *api.QueryPublishedRoomsRequest,
|
||||||
|
res *api.QueryPublishedRoomsResponse,
|
||||||
|
) error {
|
||||||
|
rooms, err := r.DB.GetPublishedRooms(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
res.RoomIDs = rooms
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Queryer) QueryCurrentState(ctx context.Context, req *api.QueryCurrentStateRequest, res *api.QueryCurrentStateResponse) error {
|
||||||
|
res.StateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent)
|
||||||
|
for _, tuple := range req.StateTuples {
|
||||||
|
ev, err := r.DB.GetStateEvent(ctx, req.RoomID, tuple.EventType, tuple.StateKey)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if ev != nil {
|
||||||
|
res.StateEvents[tuple] = ev
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Queryer) QueryRoomsForUser(ctx context.Context, req *api.QueryRoomsForUserRequest, res *api.QueryRoomsForUserResponse) error {
|
||||||
|
roomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, req.WantMembership)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
res.RoomIDs = roomIDs
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Queryer) QueryKnownUsers(ctx context.Context, req *api.QueryKnownUsersRequest, res *api.QueryKnownUsersResponse) error {
|
||||||
|
users, err := r.DB.GetKnownUsers(ctx, req.UserID, req.SearchString, req.Limit)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, user := range users {
|
||||||
|
res.Users = append(res.Users, authtypes.FullyQualifiedProfile{
|
||||||
|
UserID: user,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Queryer) QueryBulkStateContent(ctx context.Context, req *api.QueryBulkStateContentRequest, res *api.QueryBulkStateContentResponse) error {
|
||||||
|
events, err := r.DB.GetBulkStateContent(ctx, req.RoomIDs, req.StateTuples, req.AllowWildcards)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
res.Rooms = make(map[string]map[gomatrixserverlib.StateKeyTuple]string)
|
||||||
|
for _, ev := range events {
|
||||||
|
if res.Rooms[ev.RoomID] == nil {
|
||||||
|
res.Rooms[ev.RoomID] = make(map[gomatrixserverlib.StateKeyTuple]string)
|
||||||
|
}
|
||||||
|
room := res.Rooms[ev.RoomID]
|
||||||
|
room[gomatrixserverlib.StateKeyTuple{
|
||||||
|
EventType: ev.EventType,
|
||||||
|
StateKey: ev.StateKey,
|
||||||
|
}] = ev.ContentValue
|
||||||
|
res.Rooms[ev.RoomID] = room
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Queryer) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse) error {
|
||||||
|
roomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, "join")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
roomIDs = append(roomIDs, req.IncludeRoomIDs...)
|
||||||
|
excludeMap := make(map[string]bool)
|
||||||
|
for _, roomID := range req.ExcludeRoomIDs {
|
||||||
|
excludeMap[roomID] = true
|
||||||
|
}
|
||||||
|
// filter out excluded rooms
|
||||||
|
j := 0
|
||||||
|
for i := range roomIDs {
|
||||||
|
// move elements to include to the beginning of the slice
|
||||||
|
// then trim elements on the right
|
||||||
|
if !excludeMap[roomIDs[i]] {
|
||||||
|
roomIDs[j] = roomIDs[i]
|
||||||
|
j++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
roomIDs = roomIDs[:j]
|
||||||
|
|
||||||
|
users, err := r.DB.JoinedUsersSetInRooms(ctx, roomIDs)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
res.UserIDsToCount = users
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Queryer) QueryServerBannedFromRoom(ctx context.Context, req *api.QueryServerBannedFromRoomRequest, res *api.QueryServerBannedFromRoomResponse) error {
|
||||||
|
if r.ServerACLs == nil {
|
||||||
|
return errors.New("no server ACL tracking")
|
||||||
|
}
|
||||||
|
res.Banned = r.ServerACLs.IsServerBannedFromRoom(req.ServerName, req.RoomID)
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
// Copyright 2017 Vector Creations Ltd
|
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||||
//
|
//
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
// you may not use this file except in compliance with the License.
|
// you may not use this file except in compliance with the License.
|
||||||
|
@ -12,7 +12,7 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
package internal
|
package query
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
|
@ -44,6 +44,12 @@ const (
|
||||||
RoomserverQueryRoomVersionCapabilitiesPath = "/roomserver/queryRoomVersionCapabilities"
|
RoomserverQueryRoomVersionCapabilitiesPath = "/roomserver/queryRoomVersionCapabilities"
|
||||||
RoomserverQueryRoomVersionForRoomPath = "/roomserver/queryRoomVersionForRoom"
|
RoomserverQueryRoomVersionForRoomPath = "/roomserver/queryRoomVersionForRoom"
|
||||||
RoomserverQueryPublishedRoomsPath = "/roomserver/queryPublishedRooms"
|
RoomserverQueryPublishedRoomsPath = "/roomserver/queryPublishedRooms"
|
||||||
|
RoomserverQueryCurrentStatePath = "/roomserver/queryCurrentState"
|
||||||
|
RoomserverQueryRoomsForUserPath = "/roomserver/queryRoomsForUser"
|
||||||
|
RoomserverQueryBulkStateContentPath = "/roomserver/queryBulkStateContent"
|
||||||
|
RoomserverQuerySharedUsersPath = "/roomserver/querySharedUsers"
|
||||||
|
RoomserverQueryKnownUsersPath = "/roomserver/queryKnownUsers"
|
||||||
|
RoomserverQueryServerBannedFromRoomPath = "/roomserver/queryServerBannedFromRoom"
|
||||||
)
|
)
|
||||||
|
|
||||||
type httpRoomserverInternalAPI struct {
|
type httpRoomserverInternalAPI struct {
|
||||||
|
@ -389,3 +395,69 @@ func (h *httpRoomserverInternalAPI) QueryRoomVersionForRoom(
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *httpRoomserverInternalAPI) QueryCurrentState(
|
||||||
|
ctx context.Context,
|
||||||
|
request *api.QueryCurrentStateRequest,
|
||||||
|
response *api.QueryCurrentStateResponse,
|
||||||
|
) error {
|
||||||
|
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryCurrentState")
|
||||||
|
defer span.Finish()
|
||||||
|
|
||||||
|
apiURL := h.roomserverURL + RoomserverQueryCurrentStatePath
|
||||||
|
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *httpRoomserverInternalAPI) QueryRoomsForUser(
|
||||||
|
ctx context.Context,
|
||||||
|
request *api.QueryRoomsForUserRequest,
|
||||||
|
response *api.QueryRoomsForUserResponse,
|
||||||
|
) error {
|
||||||
|
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryRoomsForUser")
|
||||||
|
defer span.Finish()
|
||||||
|
|
||||||
|
apiURL := h.roomserverURL + RoomserverQueryRoomsForUserPath
|
||||||
|
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *httpRoomserverInternalAPI) QueryBulkStateContent(
|
||||||
|
ctx context.Context,
|
||||||
|
request *api.QueryBulkStateContentRequest,
|
||||||
|
response *api.QueryBulkStateContentResponse,
|
||||||
|
) error {
|
||||||
|
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryBulkStateContent")
|
||||||
|
defer span.Finish()
|
||||||
|
|
||||||
|
apiURL := h.roomserverURL + RoomserverQueryBulkStateContentPath
|
||||||
|
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *httpRoomserverInternalAPI) QuerySharedUsers(
|
||||||
|
ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse,
|
||||||
|
) error {
|
||||||
|
span, ctx := opentracing.StartSpanFromContext(ctx, "QuerySharedUsers")
|
||||||
|
defer span.Finish()
|
||||||
|
|
||||||
|
apiURL := h.roomserverURL + RoomserverQuerySharedUsersPath
|
||||||
|
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *httpRoomserverInternalAPI) QueryKnownUsers(
|
||||||
|
ctx context.Context, req *api.QueryKnownUsersRequest, res *api.QueryKnownUsersResponse,
|
||||||
|
) error {
|
||||||
|
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKnownUsers")
|
||||||
|
defer span.Finish()
|
||||||
|
|
||||||
|
apiURL := h.roomserverURL + RoomserverQueryKnownUsersPath
|
||||||
|
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *httpRoomserverInternalAPI) QueryServerBannedFromRoom(
|
||||||
|
ctx context.Context, req *api.QueryServerBannedFromRoomRequest, res *api.QueryServerBannedFromRoomResponse,
|
||||||
|
) error {
|
||||||
|
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryServerBannedFromRoom")
|
||||||
|
defer span.Finish()
|
||||||
|
|
||||||
|
apiURL := h.roomserverURL + RoomserverQueryServerBannedFromRoomPath
|
||||||
|
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||||
|
}
|
||||||
|
|
|
@ -312,4 +312,82 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) {
|
||||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
internalAPIMux.Handle(RoomserverQueryCurrentStatePath,
|
||||||
|
httputil.MakeInternalAPI("queryCurrentState", func(req *http.Request) util.JSONResponse {
|
||||||
|
request := api.QueryCurrentStateRequest{}
|
||||||
|
response := api.QueryCurrentStateResponse{}
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||||
|
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||||
|
}
|
||||||
|
if err := r.QueryCurrentState(req.Context(), &request, &response); err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
internalAPIMux.Handle(RoomserverQueryRoomsForUserPath,
|
||||||
|
httputil.MakeInternalAPI("queryRoomsForUser", func(req *http.Request) util.JSONResponse {
|
||||||
|
request := api.QueryRoomsForUserRequest{}
|
||||||
|
response := api.QueryRoomsForUserResponse{}
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||||
|
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||||
|
}
|
||||||
|
if err := r.QueryRoomsForUser(req.Context(), &request, &response); err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
internalAPIMux.Handle(RoomserverQueryBulkStateContentPath,
|
||||||
|
httputil.MakeInternalAPI("queryBulkStateContent", func(req *http.Request) util.JSONResponse {
|
||||||
|
request := api.QueryBulkStateContentRequest{}
|
||||||
|
response := api.QueryBulkStateContentResponse{}
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||||
|
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||||
|
}
|
||||||
|
if err := r.QueryBulkStateContent(req.Context(), &request, &response); err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
internalAPIMux.Handle(RoomserverQuerySharedUsersPath,
|
||||||
|
httputil.MakeInternalAPI("querySharedUsers", func(req *http.Request) util.JSONResponse {
|
||||||
|
request := api.QuerySharedUsersRequest{}
|
||||||
|
response := api.QuerySharedUsersResponse{}
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||||
|
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||||
|
}
|
||||||
|
if err := r.QuerySharedUsers(req.Context(), &request, &response); err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
internalAPIMux.Handle(RoomserverQuerySharedUsersPath,
|
||||||
|
httputil.MakeInternalAPI("queryKnownUsers", func(req *http.Request) util.JSONResponse {
|
||||||
|
request := api.QueryKnownUsersRequest{}
|
||||||
|
response := api.QueryKnownUsersResponse{}
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||||
|
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||||
|
}
|
||||||
|
if err := r.QueryKnownUsers(req.Context(), &request, &response); err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
internalAPIMux.Handle(RoomserverQueryServerBannedFromRoomPath,
|
||||||
|
httputil.MakeInternalAPI("queryServerBannedFromRoom", func(req *http.Request) util.JSONResponse {
|
||||||
|
request := api.QueryServerBannedFromRoomRequest{}
|
||||||
|
response := api.QueryServerBannedFromRoomResponse{}
|
||||||
|
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||||
|
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||||
|
}
|
||||||
|
if err := r.QueryServerBannedFromRoom(req.Context(), &request, &response); err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||||
|
}),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
|
@ -38,7 +38,6 @@ func AddInternalRoutes(router *mux.Router, intAPI api.RoomserverInternalAPI) {
|
||||||
func NewInternalAPI(
|
func NewInternalAPI(
|
||||||
base *setup.BaseDendrite,
|
base *setup.BaseDendrite,
|
||||||
keyRing gomatrixserverlib.JSONVerifier,
|
keyRing gomatrixserverlib.JSONVerifier,
|
||||||
fedClient *gomatrixserverlib.FederationClient,
|
|
||||||
) api.RoomserverInternalAPI {
|
) api.RoomserverInternalAPI {
|
||||||
cfg := &base.Cfg.RoomServer
|
cfg := &base.Cfg.RoomServer
|
||||||
|
|
||||||
|
@ -47,14 +46,8 @@ func NewInternalAPI(
|
||||||
logrus.WithError(err).Panicf("failed to connect to room server db")
|
logrus.WithError(err).Panicf("failed to connect to room server db")
|
||||||
}
|
}
|
||||||
|
|
||||||
return &internal.RoomserverInternalAPI{
|
return internal.NewRoomserverAPI(
|
||||||
DB: roomserverDB,
|
cfg, roomserverDB, base.KafkaProducer, string(cfg.Matrix.Kafka.TopicFor(config.TopicOutputRoomEvent)),
|
||||||
Cfg: cfg,
|
base.Caches, keyRing,
|
||||||
Producer: base.KafkaProducer,
|
)
|
||||||
OutputRoomEventTopic: string(cfg.Matrix.Kafka.TopicFor(config.TopicOutputRoomEvent)),
|
|
||||||
Cache: base.Caches,
|
|
||||||
ServerName: cfg.Matrix.ServerName,
|
|
||||||
FedClient: fedClient,
|
|
||||||
KeyRing: keyRing,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -112,10 +112,9 @@ func mustSendEvents(t *testing.T, ver gomatrixserverlib.RoomVersion, events []js
|
||||||
Cfg: cfg,
|
Cfg: cfg,
|
||||||
}
|
}
|
||||||
|
|
||||||
rsAPI := NewInternalAPI(base, &test.NopJSONVerifier{}, nil)
|
rsAPI := NewInternalAPI(base, &test.NopJSONVerifier{})
|
||||||
hevents := mustLoadEvents(t, ver, events)
|
hevents := mustLoadEvents(t, ver, events)
|
||||||
_, err = api.SendEvents(ctx, rsAPI, hevents, testOrigin, nil)
|
if err = api.SendEvents(ctx, rsAPI, hevents, testOrigin, nil); err != nil {
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to SendEvents: %s", err)
|
t.Errorf("failed to SendEvents: %s", err)
|
||||||
}
|
}
|
||||||
return rsAPI, dp, hevents
|
return rsAPI, dp, hevents
|
||||||
|
|
|
@ -31,12 +31,14 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type StateResolution struct {
|
type StateResolution struct {
|
||||||
db storage.Database
|
db storage.Database
|
||||||
|
roomInfo types.RoomInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewStateResolution(db storage.Database) StateResolution {
|
func NewStateResolution(db storage.Database, roomInfo types.RoomInfo) StateResolution {
|
||||||
return StateResolution{
|
return StateResolution{
|
||||||
db: db,
|
db: db,
|
||||||
|
roomInfo: roomInfo,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -339,7 +341,7 @@ func (v StateResolution) loadStateAtSnapshotForNumericTuples(
|
||||||
// This is typically the state before an event.
|
// This is typically the state before an event.
|
||||||
// Returns a sorted list of state entries or an error if there was a problem talking to the database.
|
// Returns a sorted list of state entries or an error if there was a problem talking to the database.
|
||||||
func (v StateResolution) LoadStateAfterEventsForStringTuples(
|
func (v StateResolution) LoadStateAfterEventsForStringTuples(
|
||||||
ctx context.Context, roomNID types.RoomNID,
|
ctx context.Context,
|
||||||
prevStates []types.StateAtEvent,
|
prevStates []types.StateAtEvent,
|
||||||
stateKeyTuples []gomatrixserverlib.StateKeyTuple,
|
stateKeyTuples []gomatrixserverlib.StateKeyTuple,
|
||||||
) ([]types.StateEntry, error) {
|
) ([]types.StateEntry, error) {
|
||||||
|
@ -347,24 +349,18 @@ func (v StateResolution) LoadStateAfterEventsForStringTuples(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return v.loadStateAfterEventsForNumericTuples(ctx, roomNID, prevStates, numericTuples)
|
return v.loadStateAfterEventsForNumericTuples(ctx, prevStates, numericTuples)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v StateResolution) loadStateAfterEventsForNumericTuples(
|
func (v StateResolution) loadStateAfterEventsForNumericTuples(
|
||||||
ctx context.Context, roomNID types.RoomNID,
|
ctx context.Context,
|
||||||
prevStates []types.StateAtEvent,
|
prevStates []types.StateAtEvent,
|
||||||
stateKeyTuples []types.StateKeyTuple,
|
stateKeyTuples []types.StateKeyTuple,
|
||||||
) ([]types.StateEntry, error) {
|
) ([]types.StateEntry, error) {
|
||||||
roomVersion, err := v.db.GetRoomVersionForRoomNID(ctx, roomNID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(prevStates) == 1 {
|
if len(prevStates) == 1 {
|
||||||
// Fast path for a single event.
|
// Fast path for a single event.
|
||||||
prevState := prevStates[0]
|
prevState := prevStates[0]
|
||||||
var result []types.StateEntry
|
result, err := v.loadStateAtSnapshotForNumericTuples(
|
||||||
result, err = v.loadStateAtSnapshotForNumericTuples(
|
|
||||||
ctx, prevState.BeforeStateSnapshotNID, stateKeyTuples,
|
ctx, prevState.BeforeStateSnapshotNID, stateKeyTuples,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -403,7 +399,7 @@ func (v StateResolution) loadStateAfterEventsForNumericTuples(
|
||||||
|
|
||||||
// TODO: Add metrics for this as it could take a long time for big rooms
|
// TODO: Add metrics for this as it could take a long time for big rooms
|
||||||
// with large conflicts.
|
// with large conflicts.
|
||||||
fullState, _, _, err := v.calculateStateAfterManyEvents(ctx, roomVersion, prevStates)
|
fullState, _, _, err := v.calculateStateAfterManyEvents(ctx, v.roomInfo.RoomVersion, prevStates)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -527,7 +523,6 @@ func init() {
|
||||||
func (v StateResolution) CalculateAndStoreStateBeforeEvent(
|
func (v StateResolution) CalculateAndStoreStateBeforeEvent(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
event gomatrixserverlib.Event,
|
event gomatrixserverlib.Event,
|
||||||
roomNID types.RoomNID,
|
|
||||||
) (types.StateSnapshotNID, error) {
|
) (types.StateSnapshotNID, error) {
|
||||||
// Load the state at the prev events.
|
// Load the state at the prev events.
|
||||||
prevEventRefs := event.PrevEvents()
|
prevEventRefs := event.PrevEvents()
|
||||||
|
@ -542,14 +537,13 @@ func (v StateResolution) CalculateAndStoreStateBeforeEvent(
|
||||||
}
|
}
|
||||||
|
|
||||||
// The state before this event will be the state after the events that came before it.
|
// The state before this event will be the state after the events that came before it.
|
||||||
return v.CalculateAndStoreStateAfterEvents(ctx, roomNID, prevStates)
|
return v.CalculateAndStoreStateAfterEvents(ctx, prevStates)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CalculateAndStoreStateAfterEvents finds the room state after the given events.
|
// CalculateAndStoreStateAfterEvents finds the room state after the given events.
|
||||||
// Stores the resulting state in the database and returns a numeric ID for that snapshot.
|
// Stores the resulting state in the database and returns a numeric ID for that snapshot.
|
||||||
func (v StateResolution) CalculateAndStoreStateAfterEvents(
|
func (v StateResolution) CalculateAndStoreStateAfterEvents(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
roomNID types.RoomNID,
|
|
||||||
prevStates []types.StateAtEvent,
|
prevStates []types.StateAtEvent,
|
||||||
) (types.StateSnapshotNID, error) {
|
) (types.StateSnapshotNID, error) {
|
||||||
metrics := calculateStateMetrics{startTime: time.Now(), prevEventLength: len(prevStates)}
|
metrics := calculateStateMetrics{startTime: time.Now(), prevEventLength: len(prevStates)}
|
||||||
|
@ -558,7 +552,7 @@ func (v StateResolution) CalculateAndStoreStateAfterEvents(
|
||||||
// 2) There weren't any prev_events for this event so the state is
|
// 2) There weren't any prev_events for this event so the state is
|
||||||
// empty.
|
// empty.
|
||||||
metrics.algorithm = "empty_state"
|
metrics.algorithm = "empty_state"
|
||||||
stateNID, err := v.db.AddState(ctx, roomNID, nil, nil)
|
stateNID, err := v.db.AddState(ctx, v.roomInfo.RoomNID, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = fmt.Errorf("v.db.AddState: %w", err)
|
err = fmt.Errorf("v.db.AddState: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -590,7 +584,7 @@ func (v StateResolution) CalculateAndStoreStateAfterEvents(
|
||||||
// add the state event as a block of size one to the end of the blocks.
|
// add the state event as a block of size one to the end of the blocks.
|
||||||
metrics.algorithm = "single_delta"
|
metrics.algorithm = "single_delta"
|
||||||
stateNID, err := v.db.AddState(
|
stateNID, err := v.db.AddState(
|
||||||
ctx, roomNID, stateBlockNIDs, []types.StateEntry{prevState.StateEntry},
|
ctx, v.roomInfo.RoomNID, stateBlockNIDs, []types.StateEntry{prevState.StateEntry},
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = fmt.Errorf("v.db.AddState: %w", err)
|
err = fmt.Errorf("v.db.AddState: %w", err)
|
||||||
|
@ -601,7 +595,7 @@ func (v StateResolution) CalculateAndStoreStateAfterEvents(
|
||||||
// So fall through to calculateAndStoreStateAfterManyEvents
|
// So fall through to calculateAndStoreStateAfterManyEvents
|
||||||
}
|
}
|
||||||
|
|
||||||
stateNID, err := v.calculateAndStoreStateAfterManyEvents(ctx, roomNID, prevStates, metrics)
|
stateNID, err := v.calculateAndStoreStateAfterManyEvents(ctx, v.roomInfo.RoomNID, prevStates, metrics)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("v.calculateAndStoreStateAfterManyEvents: %w", err)
|
return 0, fmt.Errorf("v.calculateAndStoreStateAfterManyEvents: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -624,13 +618,8 @@ func (v StateResolution) calculateAndStoreStateAfterManyEvents(
|
||||||
prevStates []types.StateAtEvent,
|
prevStates []types.StateAtEvent,
|
||||||
metrics calculateStateMetrics,
|
metrics calculateStateMetrics,
|
||||||
) (types.StateSnapshotNID, error) {
|
) (types.StateSnapshotNID, error) {
|
||||||
roomVersion, err := v.db.GetRoomVersionForRoomNID(ctx, roomNID)
|
|
||||||
if err != nil {
|
|
||||||
return metrics.stop(0, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
state, algorithm, conflictLength, err :=
|
state, algorithm, conflictLength, err :=
|
||||||
v.calculateStateAfterManyEvents(ctx, roomVersion, prevStates)
|
v.calculateStateAfterManyEvents(ctx, v.roomInfo.RoomVersion, prevStates)
|
||||||
metrics.algorithm = algorithm
|
metrics.algorithm = algorithm
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return metrics.stop(0, err)
|
return metrics.stop(0, err)
|
||||||
|
|
|
@ -17,6 +17,7 @@ package storage
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/currentstateserver/storage/tables"
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
|
@ -66,8 +67,6 @@ type Database interface {
|
||||||
Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error)
|
Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error)
|
||||||
// Look up snapshot NID for an event ID string
|
// Look up snapshot NID for an event ID string
|
||||||
SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error)
|
SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error)
|
||||||
// Look up a room version from the room NID.
|
|
||||||
GetRoomVersionForRoomNID(ctx context.Context, roomNID types.RoomNID) (gomatrixserverlib.RoomVersion, error)
|
|
||||||
// Stores a matrix room event in the database. Returns the room NID, the state snapshot and the redacted event ID if any, or an error.
|
// Stores a matrix room event in the database. Returns the room NID, the state snapshot and the redacted event ID if any, or an error.
|
||||||
StoreEvent(
|
StoreEvent(
|
||||||
ctx context.Context, event gomatrixserverlib.Event, txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID,
|
ctx context.Context, event gomatrixserverlib.Event, txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID,
|
||||||
|
@ -91,7 +90,7 @@ type Database interface {
|
||||||
// The RoomRecentEventsUpdater must have Commit or Rollback called on it if this doesn't return an error.
|
// The RoomRecentEventsUpdater must have Commit or Rollback called on it if this doesn't return an error.
|
||||||
// Returns the latest events in the room and the last eventID sent to the log along with an updater.
|
// Returns the latest events in the room and the last eventID sent to the log along with an updater.
|
||||||
// If this returns an error then no further action is required.
|
// If this returns an error then no further action is required.
|
||||||
GetLatestEventsForUpdate(ctx context.Context, roomNID types.RoomNID) (*shared.LatestEventsUpdater, error)
|
GetLatestEventsForUpdate(ctx context.Context, roomInfo types.RoomInfo) (*shared.LatestEventsUpdater, error)
|
||||||
// Look up event ID by transaction's info.
|
// Look up event ID by transaction's info.
|
||||||
// This is used to determine if the room event is processed/processing already.
|
// This is used to determine if the room event is processed/processing already.
|
||||||
// Returns an empty string if no such event exists.
|
// Returns an empty string if no such event exists.
|
||||||
|
@ -136,10 +135,26 @@ type Database interface {
|
||||||
// not found.
|
// not found.
|
||||||
// Returns an error if the retrieval went wrong.
|
// Returns an error if the retrieval went wrong.
|
||||||
EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error)
|
EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error)
|
||||||
// Look up the room version for a given room.
|
|
||||||
GetRoomVersionForRoom(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error)
|
|
||||||
// Publish or unpublish a room from the room directory.
|
// Publish or unpublish a room from the room directory.
|
||||||
PublishRoom(ctx context.Context, roomID string, publish bool) error
|
PublishRoom(ctx context.Context, roomID string, publish bool) error
|
||||||
// Returns a list of room IDs for rooms which are published.
|
// Returns a list of room IDs for rooms which are published.
|
||||||
GetPublishedRooms(ctx context.Context) ([]string, error)
|
GetPublishedRooms(ctx context.Context) ([]string, error)
|
||||||
|
|
||||||
|
// TODO: factor out - from currentstateserver
|
||||||
|
|
||||||
|
// GetStateEvent returns the state event of a given type for a given room with a given state key
|
||||||
|
// If no event could be found, returns nil
|
||||||
|
// If there was an issue during the retrieval, returns an error
|
||||||
|
GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error)
|
||||||
|
// GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key).
|
||||||
|
GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error)
|
||||||
|
// GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match.
|
||||||
|
// If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned.
|
||||||
|
GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error)
|
||||||
|
// JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear.
|
||||||
|
JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error)
|
||||||
|
// GetKnownUsers searches all users that userID knows about.
|
||||||
|
GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error)
|
||||||
|
// GetKnownRooms returns a list of all rooms we know about.
|
||||||
|
GetKnownRooms(ctx context.Context) ([]string, error)
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,7 +18,9 @@ package postgres
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/lib/pq"
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||||
|
@ -62,6 +64,10 @@ CREATE TABLE IF NOT EXISTS roomserver_membership (
|
||||||
);
|
);
|
||||||
`
|
`
|
||||||
|
|
||||||
|
var selectJoinedUsersSetForRoomsSQL = "" +
|
||||||
|
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid = ANY($1) AND" +
|
||||||
|
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " GROUP BY target_nid"
|
||||||
|
|
||||||
// Insert a row in to membership table so that it can be locked by the
|
// Insert a row in to membership table so that it can be locked by the
|
||||||
// SELECT FOR UPDATE
|
// SELECT FOR UPDATE
|
||||||
const insertMembershipSQL = "" +
|
const insertMembershipSQL = "" +
|
||||||
|
@ -99,6 +105,19 @@ const updateMembershipSQL = "" +
|
||||||
"UPDATE roomserver_membership SET sender_nid = $3, membership_nid = $4, event_nid = $5" +
|
"UPDATE roomserver_membership SET sender_nid = $3, membership_nid = $4, event_nid = $5" +
|
||||||
" WHERE room_nid = $1 AND target_nid = $2"
|
" WHERE room_nid = $1 AND target_nid = $2"
|
||||||
|
|
||||||
|
const selectRoomsWithMembershipSQL = "" +
|
||||||
|
"SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2"
|
||||||
|
|
||||||
|
// selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is
|
||||||
|
// joined to. Since this information is used to populate the user directory, we will
|
||||||
|
// only return users that the user would ordinarily be able to see anyway.
|
||||||
|
var selectKnownUsersSQL = "" +
|
||||||
|
"SELECT DISTINCT event_state_key FROM roomserver_membership INNER JOIN roomserver_event_state_keys ON " +
|
||||||
|
"roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" +
|
||||||
|
" WHERE room_nid = ANY(" +
|
||||||
|
" SELECT DISTINCT room_nid FROM roomserver_membership WHERE target_nid=$1 AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) +
|
||||||
|
") AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " AND event_state_key LIKE $2 LIMIT $3"
|
||||||
|
|
||||||
type membershipStatements struct {
|
type membershipStatements struct {
|
||||||
insertMembershipStmt *sql.Stmt
|
insertMembershipStmt *sql.Stmt
|
||||||
selectMembershipForUpdateStmt *sql.Stmt
|
selectMembershipForUpdateStmt *sql.Stmt
|
||||||
|
@ -108,6 +127,9 @@ type membershipStatements struct {
|
||||||
selectMembershipsFromRoomStmt *sql.Stmt
|
selectMembershipsFromRoomStmt *sql.Stmt
|
||||||
selectLocalMembershipsFromRoomStmt *sql.Stmt
|
selectLocalMembershipsFromRoomStmt *sql.Stmt
|
||||||
updateMembershipStmt *sql.Stmt
|
updateMembershipStmt *sql.Stmt
|
||||||
|
selectRoomsWithMembershipStmt *sql.Stmt
|
||||||
|
selectJoinedUsersSetForRoomsStmt *sql.Stmt
|
||||||
|
selectKnownUsersStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPostgresMembershipTable(db *sql.DB) (tables.Membership, error) {
|
func NewPostgresMembershipTable(db *sql.DB) (tables.Membership, error) {
|
||||||
|
@ -126,6 +148,9 @@ func NewPostgresMembershipTable(db *sql.DB) (tables.Membership, error) {
|
||||||
{&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL},
|
{&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL},
|
||||||
{&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL},
|
{&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL},
|
||||||
{&s.updateMembershipStmt, updateMembershipSQL},
|
{&s.updateMembershipStmt, updateMembershipSQL},
|
||||||
|
{&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL},
|
||||||
|
{&s.selectJoinedUsersSetForRoomsStmt, selectJoinedUsersSetForRoomsSQL},
|
||||||
|
{&s.selectKnownUsersStmt, selectKnownUsersSQL},
|
||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -222,3 +247,61 @@ func (s *membershipStatements) UpdateMembership(
|
||||||
)
|
)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *membershipStatements) SelectRoomsWithMembership(
|
||||||
|
ctx context.Context, userID types.EventStateKeyNID, membershipState tables.MembershipState,
|
||||||
|
) ([]types.RoomNID, error) {
|
||||||
|
rows, err := s.selectRoomsWithMembershipStmt.QueryContext(ctx, membershipState, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "SelectRoomsWithMembership: rows.close() failed")
|
||||||
|
var roomNIDs []types.RoomNID
|
||||||
|
for rows.Next() {
|
||||||
|
var roomNID types.RoomNID
|
||||||
|
if err := rows.Scan(&roomNID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
roomNIDs = append(roomNIDs, roomNID)
|
||||||
|
}
|
||||||
|
return roomNIDs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) {
|
||||||
|
roomIDarray := make([]int64, len(roomNIDs))
|
||||||
|
for i := range roomNIDs {
|
||||||
|
roomIDarray[i] = int64(roomNIDs[i])
|
||||||
|
}
|
||||||
|
rows, err := s.selectJoinedUsersSetForRoomsStmt.QueryContext(ctx, pq.Int64Array(roomIDarray))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsersSetForRooms: rows.close() failed")
|
||||||
|
result := make(map[types.EventStateKeyNID]int)
|
||||||
|
for rows.Next() {
|
||||||
|
var userID types.EventStateKeyNID
|
||||||
|
var count int
|
||||||
|
if err := rows.Scan(&userID, &count); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result[userID] = count
|
||||||
|
}
|
||||||
|
return result, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) {
|
||||||
|
rows, err := s.selectKnownUsersStmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result := []string{}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "SelectKnownUsers: rows.close() failed")
|
||||||
|
for rows.Next() {
|
||||||
|
var userID string
|
||||||
|
if err := rows.Scan(&userID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result = append(result, userID)
|
||||||
|
}
|
||||||
|
return result, rows.Err()
|
||||||
|
}
|
||||||
|
|
|
@ -21,6 +21,7 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
|
|
||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||||
|
@ -68,24 +69,32 @@ const selectLatestEventNIDsForUpdateSQL = "" +
|
||||||
const updateLatestEventNIDsSQL = "" +
|
const updateLatestEventNIDsSQL = "" +
|
||||||
"UPDATE roomserver_rooms SET latest_event_nids = $2, last_event_sent_nid = $3, state_snapshot_nid = $4 WHERE room_nid = $1"
|
"UPDATE roomserver_rooms SET latest_event_nids = $2, last_event_sent_nid = $3, state_snapshot_nid = $4 WHERE room_nid = $1"
|
||||||
|
|
||||||
const selectRoomVersionForRoomIDSQL = "" +
|
|
||||||
"SELECT room_version FROM roomserver_rooms WHERE room_id = $1"
|
|
||||||
|
|
||||||
const selectRoomVersionForRoomNIDSQL = "" +
|
const selectRoomVersionForRoomNIDSQL = "" +
|
||||||
"SELECT room_version FROM roomserver_rooms WHERE room_nid = $1"
|
"SELECT room_version FROM roomserver_rooms WHERE room_nid = $1"
|
||||||
|
|
||||||
const selectRoomInfoSQL = "" +
|
const selectRoomInfoSQL = "" +
|
||||||
"SELECT room_version, room_nid, state_snapshot_nid, latest_event_nids FROM roomserver_rooms WHERE room_id = $1"
|
"SELECT room_version, room_nid, state_snapshot_nid, latest_event_nids FROM roomserver_rooms WHERE room_id = $1"
|
||||||
|
|
||||||
|
const selectRoomIDsSQL = "" +
|
||||||
|
"SELECT room_id FROM roomserver_rooms"
|
||||||
|
|
||||||
|
const bulkSelectRoomIDsSQL = "" +
|
||||||
|
"SELECT room_id FROM roomserver_rooms WHERE room_nid IN ($1)"
|
||||||
|
|
||||||
|
const bulkSelectRoomNIDsSQL = "" +
|
||||||
|
"SELECT room_nid FROM roomserver_rooms WHERE room_id IN ($1)"
|
||||||
|
|
||||||
type roomStatements struct {
|
type roomStatements struct {
|
||||||
insertRoomNIDStmt *sql.Stmt
|
insertRoomNIDStmt *sql.Stmt
|
||||||
selectRoomNIDStmt *sql.Stmt
|
selectRoomNIDStmt *sql.Stmt
|
||||||
selectLatestEventNIDsStmt *sql.Stmt
|
selectLatestEventNIDsStmt *sql.Stmt
|
||||||
selectLatestEventNIDsForUpdateStmt *sql.Stmt
|
selectLatestEventNIDsForUpdateStmt *sql.Stmt
|
||||||
updateLatestEventNIDsStmt *sql.Stmt
|
updateLatestEventNIDsStmt *sql.Stmt
|
||||||
selectRoomVersionForRoomIDStmt *sql.Stmt
|
|
||||||
selectRoomVersionForRoomNIDStmt *sql.Stmt
|
selectRoomVersionForRoomNIDStmt *sql.Stmt
|
||||||
selectRoomInfoStmt *sql.Stmt
|
selectRoomInfoStmt *sql.Stmt
|
||||||
|
selectRoomIDsStmt *sql.Stmt
|
||||||
|
bulkSelectRoomIDsStmt *sql.Stmt
|
||||||
|
bulkSelectRoomNIDsStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPostgresRoomsTable(db *sql.DB) (tables.Rooms, error) {
|
func NewPostgresRoomsTable(db *sql.DB) (tables.Rooms, error) {
|
||||||
|
@ -100,12 +109,30 @@ func NewPostgresRoomsTable(db *sql.DB) (tables.Rooms, error) {
|
||||||
{&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL},
|
{&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL},
|
||||||
{&s.selectLatestEventNIDsForUpdateStmt, selectLatestEventNIDsForUpdateSQL},
|
{&s.selectLatestEventNIDsForUpdateStmt, selectLatestEventNIDsForUpdateSQL},
|
||||||
{&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL},
|
{&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL},
|
||||||
{&s.selectRoomVersionForRoomIDStmt, selectRoomVersionForRoomIDSQL},
|
|
||||||
{&s.selectRoomVersionForRoomNIDStmt, selectRoomVersionForRoomNIDSQL},
|
{&s.selectRoomVersionForRoomNIDStmt, selectRoomVersionForRoomNIDSQL},
|
||||||
{&s.selectRoomInfoStmt, selectRoomInfoSQL},
|
{&s.selectRoomInfoStmt, selectRoomInfoSQL},
|
||||||
|
{&s.selectRoomIDsStmt, selectRoomIDsSQL},
|
||||||
|
{&s.bulkSelectRoomIDsStmt, bulkSelectRoomIDsSQL},
|
||||||
|
{&s.bulkSelectRoomNIDsStmt, bulkSelectRoomNIDsSQL},
|
||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) {
|
||||||
|
rows, err := s.selectRoomIDsStmt.QueryContext(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsStmt: rows.close() failed")
|
||||||
|
var roomIDs []string
|
||||||
|
for rows.Next() {
|
||||||
|
var roomID string
|
||||||
|
if err = rows.Scan(&roomID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
roomIDs = append(roomIDs, roomID)
|
||||||
|
}
|
||||||
|
return roomIDs, nil
|
||||||
|
}
|
||||||
func (s *roomStatements) InsertRoomNID(
|
func (s *roomStatements) InsertRoomNID(
|
||||||
ctx context.Context, txn *sql.Tx,
|
ctx context.Context, txn *sql.Tx,
|
||||||
roomID string, roomVersion gomatrixserverlib.RoomVersion,
|
roomID string, roomVersion gomatrixserverlib.RoomVersion,
|
||||||
|
@ -192,18 +219,6 @@ func (s *roomStatements) UpdateLatestEventNIDs(
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *roomStatements) SelectRoomVersionForRoomID(
|
|
||||||
ctx context.Context, txn *sql.Tx, roomID string,
|
|
||||||
) (gomatrixserverlib.RoomVersion, error) {
|
|
||||||
var roomVersion gomatrixserverlib.RoomVersion
|
|
||||||
stmt := sqlutil.TxStmt(txn, s.selectRoomVersionForRoomIDStmt)
|
|
||||||
err := stmt.QueryRowContext(ctx, roomID).Scan(&roomVersion)
|
|
||||||
if err == sql.ErrNoRows {
|
|
||||||
return roomVersion, errors.New("room not found")
|
|
||||||
}
|
|
||||||
return roomVersion, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *roomStatements) SelectRoomVersionForRoomNID(
|
func (s *roomStatements) SelectRoomVersionForRoomNID(
|
||||||
ctx context.Context, roomNID types.RoomNID,
|
ctx context.Context, roomNID types.RoomNID,
|
||||||
) (gomatrixserverlib.RoomVersion, error) {
|
) (gomatrixserverlib.RoomVersion, error) {
|
||||||
|
@ -214,3 +229,45 @@ func (s *roomStatements) SelectRoomVersionForRoomNID(
|
||||||
}
|
}
|
||||||
return roomVersion, err
|
return roomVersion, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) {
|
||||||
|
var array pq.Int64Array
|
||||||
|
for _, nid := range roomNIDs {
|
||||||
|
array = append(array, int64(nid))
|
||||||
|
}
|
||||||
|
rows, err := s.bulkSelectRoomIDsStmt.QueryContext(ctx, array)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomIDsStmt: rows.close() failed")
|
||||||
|
var roomIDs []string
|
||||||
|
for rows.Next() {
|
||||||
|
var roomID string
|
||||||
|
if err = rows.Scan(&roomID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
roomIDs = append(roomIDs, roomID)
|
||||||
|
}
|
||||||
|
return roomIDs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error) {
|
||||||
|
var array pq.StringArray
|
||||||
|
for _, roomID := range roomIDs {
|
||||||
|
array = append(array, roomID)
|
||||||
|
}
|
||||||
|
rows, err := s.bulkSelectRoomNIDsStmt.QueryContext(ctx, array)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomNIDsStmt: rows.close() failed")
|
||||||
|
var roomNIDs []types.RoomNID
|
||||||
|
for rows.Next() {
|
||||||
|
var roomNID types.RoomNID
|
||||||
|
if err = rows.Scan(&roomNID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
roomNIDs = append(roomNIDs, roomNID)
|
||||||
|
}
|
||||||
|
return roomNIDs, nil
|
||||||
|
}
|
||||||
|
|
|
@ -12,15 +12,15 @@ import (
|
||||||
type LatestEventsUpdater struct {
|
type LatestEventsUpdater struct {
|
||||||
transaction
|
transaction
|
||||||
d *Database
|
d *Database
|
||||||
roomNID types.RoomNID
|
roomInfo types.RoomInfo
|
||||||
latestEvents []types.StateAtEventAndReference
|
latestEvents []types.StateAtEventAndReference
|
||||||
lastEventIDSent string
|
lastEventIDSent string
|
||||||
currentStateSnapshotNID types.StateSnapshotNID
|
currentStateSnapshotNID types.StateSnapshotNID
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLatestEventsUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomNID types.RoomNID) (*LatestEventsUpdater, error) {
|
func NewLatestEventsUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomInfo types.RoomInfo) (*LatestEventsUpdater, error) {
|
||||||
eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err :=
|
eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err :=
|
||||||
d.RoomsTable.SelectLatestEventsNIDsForUpdate(ctx, txn, roomNID)
|
d.RoomsTable.SelectLatestEventsNIDsForUpdate(ctx, txn, roomInfo.RoomNID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
txn.Rollback() // nolint: errcheck
|
txn.Rollback() // nolint: errcheck
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -39,14 +39,13 @@ func NewLatestEventsUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomN
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return &LatestEventsUpdater{
|
return &LatestEventsUpdater{
|
||||||
transaction{ctx, txn}, d, roomNID, stateAndRefs, lastEventIDSent, currentStateSnapshotNID,
|
transaction{ctx, txn}, d, roomInfo, stateAndRefs, lastEventIDSent, currentStateSnapshotNID,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RoomVersion implements types.RoomRecentEventsUpdater
|
// RoomVersion implements types.RoomRecentEventsUpdater
|
||||||
func (u *LatestEventsUpdater) RoomVersion() (version gomatrixserverlib.RoomVersion) {
|
func (u *LatestEventsUpdater) RoomVersion() (version gomatrixserverlib.RoomVersion) {
|
||||||
version, _ = u.d.GetRoomVersionForRoomNID(u.ctx, u.roomNID)
|
return u.roomInfo.RoomVersion
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// LatestEvents implements types.RoomRecentEventsUpdater
|
// LatestEvents implements types.RoomRecentEventsUpdater
|
||||||
|
@ -118,5 +117,5 @@ func (u *LatestEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *LatestEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) {
|
func (u *LatestEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) {
|
||||||
return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomNID, targetUserNID, targetLocal)
|
return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal)
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,13 +5,16 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"sort"
|
||||||
|
|
||||||
|
csstables "github.com/matrix-org/dendrite/currentstateserver/storage/tables"
|
||||||
"github.com/matrix-org/dendrite/internal/caching"
|
"github.com/matrix-org/dendrite/internal/caching"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/matrix-org/util"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -229,30 +232,6 @@ func (d *Database) StateEntries(
|
||||||
return d.StateBlockTable.BulkSelectStateBlockEntries(ctx, stateBlockNIDs)
|
return d.StateBlockTable.BulkSelectStateBlockEntries(ctx, stateBlockNIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) GetRoomVersionForRoom(
|
|
||||||
ctx context.Context, roomID string,
|
|
||||||
) (gomatrixserverlib.RoomVersion, error) {
|
|
||||||
if roomVersion, ok := d.Cache.GetRoomVersion(roomID); ok {
|
|
||||||
return roomVersion, nil
|
|
||||||
}
|
|
||||||
return d.RoomsTable.SelectRoomVersionForRoomID(
|
|
||||||
ctx, nil, roomID,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Database) GetRoomVersionForRoomNID(
|
|
||||||
ctx context.Context, roomNID types.RoomNID,
|
|
||||||
) (gomatrixserverlib.RoomVersion, error) {
|
|
||||||
if roomID, ok := d.Cache.GetRoomServerRoomID(roomNID); ok {
|
|
||||||
if roomVersion, ok := d.Cache.GetRoomVersion(roomID); ok {
|
|
||||||
return roomVersion, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return d.RoomsTable.SelectRoomVersionForRoomNID(
|
|
||||||
ctx, roomNID,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error {
|
func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error {
|
||||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
return d.RoomAliasesTable.InsertRoomAlias(ctx, txn, alias, roomID, creatorUserID)
|
return d.RoomAliasesTable.InsertRoomAlias(ctx, txn, alias, roomID, creatorUserID)
|
||||||
|
@ -387,7 +366,7 @@ func (d *Database) MembershipUpdater(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) GetLatestEventsForUpdate(
|
func (d *Database) GetLatestEventsForUpdate(
|
||||||
ctx context.Context, roomNID types.RoomNID,
|
ctx context.Context, roomInfo types.RoomInfo,
|
||||||
) (*LatestEventsUpdater, error) {
|
) (*LatestEventsUpdater, error) {
|
||||||
txn, err := d.DB.Begin()
|
txn, err := d.DB.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -395,7 +374,7 @@ func (d *Database) GetLatestEventsForUpdate(
|
||||||
}
|
}
|
||||||
var updater *LatestEventsUpdater
|
var updater *LatestEventsUpdater
|
||||||
_ = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error {
|
_ = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error {
|
||||||
updater, err = NewLatestEventsUpdater(ctx, d, txn, roomNID)
|
updater, err = NewLatestEventsUpdater(ctx, d, txn, roomInfo)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
return updater, err
|
return updater, err
|
||||||
|
@ -735,3 +714,190 @@ func (d *Database) loadEvent(ctx context.Context, eventID string) *types.Event {
|
||||||
}
|
}
|
||||||
return &evs[0]
|
return &evs[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetStateEvent returns the current state event of a given type for a given room with a given state key
|
||||||
|
// If no event could be found, returns nil
|
||||||
|
// If there was an issue during the retrieval, returns an error
|
||||||
|
func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) {
|
||||||
|
roomInfo, err := d.RoomInfo(ctx, roomID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
eventTypeNID, err := d.EventTypesTable.SelectEventTypeNID(ctx, nil, evType)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
stateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, nil, stateKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
entries, err := d.loadStateAtSnapshot(ctx, roomInfo.StateSnapshotNID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// return the event requested
|
||||||
|
for _, e := range entries {
|
||||||
|
if e.EventTypeNID == eventTypeNID && e.EventStateKeyNID == stateKeyNID {
|
||||||
|
data, err := d.EventJSONTable.BulkSelectEventJSON(ctx, []types.EventNID{e.EventNID})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(data) == 0 {
|
||||||
|
return nil, fmt.Errorf("GetStateEvent: no json for event nid %d", e.EventNID)
|
||||||
|
}
|
||||||
|
ev, err := gomatrixserverlib.NewEventFromTrustedJSON(data[0].EventJSON, false, roomInfo.RoomVersion)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
h := ev.Headered(roomInfo.RoomVersion)
|
||||||
|
return &h, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("GetStateEvent: no event type '%s' with key '%s' exists in room %s", evType, stateKey, roomID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key).
|
||||||
|
func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error) {
|
||||||
|
var membershipState tables.MembershipState
|
||||||
|
switch membership {
|
||||||
|
case "join":
|
||||||
|
membershipState = tables.MembershipStateJoin
|
||||||
|
case "invite":
|
||||||
|
membershipState = tables.MembershipStateInvite
|
||||||
|
case "leave":
|
||||||
|
membershipState = tables.MembershipStateLeaveOrBan
|
||||||
|
case "ban":
|
||||||
|
membershipState = tables.MembershipStateLeaveOrBan
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("GetRoomsByMembership: invalid membership %s", membership)
|
||||||
|
}
|
||||||
|
stateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, nil, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("GetRoomsByMembership: cannot map user ID to state key NID: %w", err)
|
||||||
|
}
|
||||||
|
roomNIDs, err := d.MembershipTable.SelectRoomsWithMembership(ctx, stateKeyNID, membershipState)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
roomIDs, err := d.RoomsTable.BulkSelectRoomIDs(ctx, roomNIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(roomIDs) != len(roomNIDs) {
|
||||||
|
return nil, fmt.Errorf("GetRoomsByMembership: missing room IDs, got %d want %d", len(roomIDs), len(roomNIDs))
|
||||||
|
}
|
||||||
|
return roomIDs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match.
|
||||||
|
// If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned.
|
||||||
|
func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]csstables.StrippedEvent, error) {
|
||||||
|
return nil, fmt.Errorf("not implemented yet")
|
||||||
|
}
|
||||||
|
|
||||||
|
// JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear.
|
||||||
|
func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) {
|
||||||
|
roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, roomIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, roomNIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
stateKeyNIDs := make([]types.EventStateKeyNID, len(userNIDToCount))
|
||||||
|
i := 0
|
||||||
|
for nid := range userNIDToCount {
|
||||||
|
stateKeyNIDs[i] = nid
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
nidToUserID, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, stateKeyNIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(nidToUserID) != len(userNIDToCount) {
|
||||||
|
return nil, fmt.Errorf("found %d users but only have state key nids for %d of them", len(userNIDToCount), len(nidToUserID))
|
||||||
|
}
|
||||||
|
result := make(map[string]int, len(userNIDToCount))
|
||||||
|
for nid, count := range userNIDToCount {
|
||||||
|
result[nidToUserID[nid]] = count
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetKnownUsers searches all users that userID knows about.
|
||||||
|
func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) {
|
||||||
|
stateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, nil, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return d.MembershipTable.SelectKnownUsers(ctx, stateKeyNID, searchString, limit)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetKnownRooms returns a list of all rooms we know about.
|
||||||
|
func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) {
|
||||||
|
return d.RoomsTable.SelectRoomIDs(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// FIXME TODO: Remove all this - horrible dupe with roomserver/state. Can't use the original impl because of circular loops
|
||||||
|
// it should live in this package!
|
||||||
|
|
||||||
|
func (d *Database) loadStateAtSnapshot(
|
||||||
|
ctx context.Context, stateNID types.StateSnapshotNID,
|
||||||
|
) ([]types.StateEntry, error) {
|
||||||
|
stateBlockNIDLists, err := d.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// We've asked for exactly one snapshot from the db so we should have exactly one entry in the result.
|
||||||
|
stateBlockNIDList := stateBlockNIDLists[0]
|
||||||
|
|
||||||
|
stateEntryLists, err := d.StateEntries(ctx, stateBlockNIDList.StateBlockNIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
stateEntriesMap := stateEntryListMap(stateEntryLists)
|
||||||
|
|
||||||
|
// Combine all the state entries for this snapshot.
|
||||||
|
// The order of state block NIDs in the list tells us the order to combine them in.
|
||||||
|
var fullState []types.StateEntry
|
||||||
|
for _, stateBlockNID := range stateBlockNIDList.StateBlockNIDs {
|
||||||
|
entries, ok := stateEntriesMap.lookup(stateBlockNID)
|
||||||
|
if !ok {
|
||||||
|
// This should only get hit if the database is corrupt.
|
||||||
|
// It should be impossible for an event to reference a NID that doesn't exist
|
||||||
|
panic(fmt.Errorf("Corrupt DB: Missing state block numeric ID %d", stateBlockNID))
|
||||||
|
}
|
||||||
|
fullState = append(fullState, entries...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stable sort so that the most recent entry for each state key stays
|
||||||
|
// remains later in the list than the older entries for the same state key.
|
||||||
|
sort.Stable(stateEntryByStateKeySorter(fullState))
|
||||||
|
// Unique returns the last entry and hence the most recent entry for each state key.
|
||||||
|
fullState = fullState[:util.Unique(stateEntryByStateKeySorter(fullState))]
|
||||||
|
return fullState, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type stateEntryListMap []types.StateEntryList
|
||||||
|
|
||||||
|
func (m stateEntryListMap) lookup(stateBlockNID types.StateBlockNID) (stateEntries []types.StateEntry, ok bool) {
|
||||||
|
list := []types.StateEntryList(m)
|
||||||
|
i := sort.Search(len(list), func(i int) bool {
|
||||||
|
return list[i].StateBlockNID >= stateBlockNID
|
||||||
|
})
|
||||||
|
if i < len(list) && list[i].StateBlockNID == stateBlockNID {
|
||||||
|
ok = true
|
||||||
|
stateEntries = list[i].StateEntries
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
type stateEntryByStateKeySorter []types.StateEntry
|
||||||
|
|
||||||
|
func (s stateEntryByStateKeySorter) Len() int { return len(s) }
|
||||||
|
func (s stateEntryByStateKeySorter) Less(i, j int) bool {
|
||||||
|
return s[i].StateKeyTuple.LessThan(s[j].StateKeyTuple)
|
||||||
|
}
|
||||||
|
func (s stateEntryByStateKeySorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
||||||
|
|
|
@ -18,6 +18,8 @@ package sqlite3
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
@ -38,6 +40,10 @@ const membershipSchema = `
|
||||||
);
|
);
|
||||||
`
|
`
|
||||||
|
|
||||||
|
var selectJoinedUsersSetForRoomsSQL = "" +
|
||||||
|
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid = ANY($1) AND" +
|
||||||
|
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " GROUP BY target_nid"
|
||||||
|
|
||||||
// Insert a row in to membership table so that it can be locked by the
|
// Insert a row in to membership table so that it can be locked by the
|
||||||
// SELECT FOR UPDATE
|
// SELECT FOR UPDATE
|
||||||
const insertMembershipSQL = "" +
|
const insertMembershipSQL = "" +
|
||||||
|
@ -75,6 +81,19 @@ const updateMembershipSQL = "" +
|
||||||
"UPDATE roomserver_membership SET sender_nid = $1, membership_nid = $2, event_nid = $3" +
|
"UPDATE roomserver_membership SET sender_nid = $1, membership_nid = $2, event_nid = $3" +
|
||||||
" WHERE room_nid = $4 AND target_nid = $5"
|
" WHERE room_nid = $4 AND target_nid = $5"
|
||||||
|
|
||||||
|
const selectRoomsWithMembershipSQL = "" +
|
||||||
|
"SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2"
|
||||||
|
|
||||||
|
// selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is
|
||||||
|
// joined to. Since this information is used to populate the user directory, we will
|
||||||
|
// only return users that the user would ordinarily be able to see anyway.
|
||||||
|
var selectKnownUsersSQL = "" +
|
||||||
|
"SELECT DISTINCT event_state_key FROM roomserver_membership INNER JOIN roomserver_event_state_keys ON " +
|
||||||
|
"roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" +
|
||||||
|
" WHERE room_nid IN (" +
|
||||||
|
" SELECT DISTINCT room_nid FROM roomserver_membership WHERE target_nid=$1 AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) +
|
||||||
|
") AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " AND event_state_key LIKE $2 LIMIT $3"
|
||||||
|
|
||||||
type membershipStatements struct {
|
type membershipStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
insertMembershipStmt *sql.Stmt
|
insertMembershipStmt *sql.Stmt
|
||||||
|
@ -84,7 +103,9 @@ type membershipStatements struct {
|
||||||
selectLocalMembershipsFromRoomAndMembershipStmt *sql.Stmt
|
selectLocalMembershipsFromRoomAndMembershipStmt *sql.Stmt
|
||||||
selectMembershipsFromRoomStmt *sql.Stmt
|
selectMembershipsFromRoomStmt *sql.Stmt
|
||||||
selectLocalMembershipsFromRoomStmt *sql.Stmt
|
selectLocalMembershipsFromRoomStmt *sql.Stmt
|
||||||
|
selectRoomsWithMembershipStmt *sql.Stmt
|
||||||
updateMembershipStmt *sql.Stmt
|
updateMembershipStmt *sql.Stmt
|
||||||
|
selectKnownUsersStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) {
|
func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) {
|
||||||
|
@ -105,6 +126,8 @@ func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) {
|
||||||
{&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL},
|
{&s.selectMembershipsFromRoomStmt, selectMembershipsFromRoomSQL},
|
||||||
{&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL},
|
{&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL},
|
||||||
{&s.updateMembershipStmt, updateMembershipSQL},
|
{&s.updateMembershipStmt, updateMembershipSQL},
|
||||||
|
{&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL},
|
||||||
|
{&s.selectKnownUsersStmt, selectKnownUsersSQL},
|
||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -203,3 +226,62 @@ func (s *membershipStatements) UpdateMembership(
|
||||||
)
|
)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *membershipStatements) SelectRoomsWithMembership(
|
||||||
|
ctx context.Context, userID types.EventStateKeyNID, membershipState tables.MembershipState,
|
||||||
|
) ([]types.RoomNID, error) {
|
||||||
|
rows, err := s.selectRoomsWithMembershipStmt.QueryContext(ctx, membershipState, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "SelectRoomsWithMembership: rows.close() failed")
|
||||||
|
var roomNIDs []types.RoomNID
|
||||||
|
for rows.Next() {
|
||||||
|
var roomNID types.RoomNID
|
||||||
|
if err := rows.Scan(&roomNID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
roomNIDs = append(roomNIDs, roomNID)
|
||||||
|
}
|
||||||
|
return roomNIDs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) {
|
||||||
|
iRoomNIDs := make([]interface{}, len(roomNIDs))
|
||||||
|
for i, v := range roomNIDs {
|
||||||
|
iRoomNIDs[i] = v
|
||||||
|
}
|
||||||
|
query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomNIDs)), 1)
|
||||||
|
rows, err := s.db.QueryContext(ctx, query, iRoomNIDs...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsersSetForRooms: rows.close() failed")
|
||||||
|
result := make(map[types.EventStateKeyNID]int)
|
||||||
|
for rows.Next() {
|
||||||
|
var userID types.EventStateKeyNID
|
||||||
|
var count int
|
||||||
|
if err := rows.Scan(&userID, &count); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result[userID] = count
|
||||||
|
}
|
||||||
|
return result, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) {
|
||||||
|
rows, err := s.selectKnownUsersStmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result := []string{}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "SelectKnownUsers: rows.close() failed")
|
||||||
|
for rows.Next() {
|
||||||
|
var userID string
|
||||||
|
if err := rows.Scan(&userID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result = append(result, userID)
|
||||||
|
}
|
||||||
|
return result, rows.Err()
|
||||||
|
}
|
||||||
|
|
|
@ -21,7 +21,9 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||||
|
@ -58,15 +60,21 @@ const selectLatestEventNIDsForUpdateSQL = "" +
|
||||||
const updateLatestEventNIDsSQL = "" +
|
const updateLatestEventNIDsSQL = "" +
|
||||||
"UPDATE roomserver_rooms SET latest_event_nids = $1, last_event_sent_nid = $2, state_snapshot_nid = $3 WHERE room_nid = $4"
|
"UPDATE roomserver_rooms SET latest_event_nids = $1, last_event_sent_nid = $2, state_snapshot_nid = $3 WHERE room_nid = $4"
|
||||||
|
|
||||||
const selectRoomVersionForRoomIDSQL = "" +
|
|
||||||
"SELECT room_version FROM roomserver_rooms WHERE room_id = $1"
|
|
||||||
|
|
||||||
const selectRoomVersionForRoomNIDSQL = "" +
|
const selectRoomVersionForRoomNIDSQL = "" +
|
||||||
"SELECT room_version FROM roomserver_rooms WHERE room_nid = $1"
|
"SELECT room_version FROM roomserver_rooms WHERE room_nid = $1"
|
||||||
|
|
||||||
const selectRoomInfoSQL = "" +
|
const selectRoomInfoSQL = "" +
|
||||||
"SELECT room_version, room_nid, state_snapshot_nid, latest_event_nids FROM roomserver_rooms WHERE room_id = $1"
|
"SELECT room_version, room_nid, state_snapshot_nid, latest_event_nids FROM roomserver_rooms WHERE room_id = $1"
|
||||||
|
|
||||||
|
const selectRoomIDsSQL = "" +
|
||||||
|
"SELECT room_id FROM roomserver_rooms"
|
||||||
|
|
||||||
|
const bulkSelectRoomIDsSQL = "" +
|
||||||
|
"SELECT room_id FROM roomserver_rooms WHERE room_nid IN ($1)"
|
||||||
|
|
||||||
|
const bulkSelectRoomNIDsSQL = "" +
|
||||||
|
"SELECT room_nid FROM roomserver_rooms WHERE room_id IN ($1)"
|
||||||
|
|
||||||
type roomStatements struct {
|
type roomStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
insertRoomNIDStmt *sql.Stmt
|
insertRoomNIDStmt *sql.Stmt
|
||||||
|
@ -74,9 +82,9 @@ type roomStatements struct {
|
||||||
selectLatestEventNIDsStmt *sql.Stmt
|
selectLatestEventNIDsStmt *sql.Stmt
|
||||||
selectLatestEventNIDsForUpdateStmt *sql.Stmt
|
selectLatestEventNIDsForUpdateStmt *sql.Stmt
|
||||||
updateLatestEventNIDsStmt *sql.Stmt
|
updateLatestEventNIDsStmt *sql.Stmt
|
||||||
selectRoomVersionForRoomIDStmt *sql.Stmt
|
|
||||||
selectRoomVersionForRoomNIDStmt *sql.Stmt
|
selectRoomVersionForRoomNIDStmt *sql.Stmt
|
||||||
selectRoomInfoStmt *sql.Stmt
|
selectRoomInfoStmt *sql.Stmt
|
||||||
|
selectRoomIDsStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) {
|
func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) {
|
||||||
|
@ -93,12 +101,29 @@ func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) {
|
||||||
{&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL},
|
{&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL},
|
||||||
{&s.selectLatestEventNIDsForUpdateStmt, selectLatestEventNIDsForUpdateSQL},
|
{&s.selectLatestEventNIDsForUpdateStmt, selectLatestEventNIDsForUpdateSQL},
|
||||||
{&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL},
|
{&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL},
|
||||||
{&s.selectRoomVersionForRoomIDStmt, selectRoomVersionForRoomIDSQL},
|
|
||||||
{&s.selectRoomVersionForRoomNIDStmt, selectRoomVersionForRoomNIDSQL},
|
{&s.selectRoomVersionForRoomNIDStmt, selectRoomVersionForRoomNIDSQL},
|
||||||
{&s.selectRoomInfoStmt, selectRoomInfoSQL},
|
{&s.selectRoomInfoStmt, selectRoomInfoSQL},
|
||||||
|
{&s.selectRoomIDsStmt, selectRoomIDsSQL},
|
||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) {
|
||||||
|
rows, err := s.selectRoomIDsStmt.QueryContext(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsStmt: rows.close() failed")
|
||||||
|
var roomIDs []string
|
||||||
|
for rows.Next() {
|
||||||
|
var roomID string
|
||||||
|
if err = rows.Scan(&roomID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
roomIDs = append(roomIDs, roomID)
|
||||||
|
}
|
||||||
|
return roomIDs, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *roomStatements) SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) {
|
func (s *roomStatements) SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) {
|
||||||
var info types.RoomInfo
|
var info types.RoomInfo
|
||||||
var latestNIDsJSON string
|
var latestNIDsJSON string
|
||||||
|
@ -198,18 +223,6 @@ func (s *roomStatements) UpdateLatestEventNIDs(
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *roomStatements) SelectRoomVersionForRoomID(
|
|
||||||
ctx context.Context, txn *sql.Tx, roomID string,
|
|
||||||
) (gomatrixserverlib.RoomVersion, error) {
|
|
||||||
var roomVersion gomatrixserverlib.RoomVersion
|
|
||||||
stmt := sqlutil.TxStmt(txn, s.selectRoomVersionForRoomIDStmt)
|
|
||||||
err := stmt.QueryRowContext(ctx, roomID).Scan(&roomVersion)
|
|
||||||
if err == sql.ErrNoRows {
|
|
||||||
return roomVersion, errors.New("room not found")
|
|
||||||
}
|
|
||||||
return roomVersion, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *roomStatements) SelectRoomVersionForRoomNID(
|
func (s *roomStatements) SelectRoomVersionForRoomNID(
|
||||||
ctx context.Context, roomNID types.RoomNID,
|
ctx context.Context, roomNID types.RoomNID,
|
||||||
) (gomatrixserverlib.RoomVersion, error) {
|
) (gomatrixserverlib.RoomVersion, error) {
|
||||||
|
@ -220,3 +233,47 @@ func (s *roomStatements) SelectRoomVersionForRoomNID(
|
||||||
}
|
}
|
||||||
return roomVersion, err
|
return roomVersion, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) {
|
||||||
|
iRoomNIDs := make([]interface{}, len(roomNIDs))
|
||||||
|
for i, v := range roomNIDs {
|
||||||
|
iRoomNIDs[i] = v
|
||||||
|
}
|
||||||
|
sqlQuery := strings.Replace(bulkSelectRoomIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
|
||||||
|
rows, err := s.db.QueryContext(ctx, sqlQuery, iRoomNIDs...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomIDsStmt: rows.close() failed")
|
||||||
|
var roomIDs []string
|
||||||
|
for rows.Next() {
|
||||||
|
var roomID string
|
||||||
|
if err = rows.Scan(&roomID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
roomIDs = append(roomIDs, roomID)
|
||||||
|
}
|
||||||
|
return roomIDs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error) {
|
||||||
|
iRoomIDs := make([]interface{}, len(roomIDs))
|
||||||
|
for i, v := range roomIDs {
|
||||||
|
iRoomIDs[i] = v
|
||||||
|
}
|
||||||
|
sqlQuery := strings.Replace(bulkSelectRoomNIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomIDs)), 1)
|
||||||
|
rows, err := s.db.QueryContext(ctx, sqlQuery, iRoomIDs...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomNIDsStmt: rows.close() failed")
|
||||||
|
var roomNIDs []types.RoomNID
|
||||||
|
for rows.Next() {
|
||||||
|
var roomNID types.RoomNID
|
||||||
|
if err = rows.Scan(&roomNID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
roomNIDs = append(roomNIDs, roomNID)
|
||||||
|
}
|
||||||
|
return roomNIDs, nil
|
||||||
|
}
|
||||||
|
|
|
@ -150,7 +150,7 @@ func (d *Database) SupportsConcurrentRoomInputs() bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) GetLatestEventsForUpdate(
|
func (d *Database) GetLatestEventsForUpdate(
|
||||||
ctx context.Context, roomNID types.RoomNID,
|
ctx context.Context, roomInfo types.RoomInfo,
|
||||||
) (*shared.LatestEventsUpdater, error) {
|
) (*shared.LatestEventsUpdater, error) {
|
||||||
// TODO: Do not use transactions. We should be holding open this transaction but we cannot have
|
// TODO: Do not use transactions. We should be holding open this transaction but we cannot have
|
||||||
// multiple write transactions on sqlite. The code will perform additional
|
// multiple write transactions on sqlite. The code will perform additional
|
||||||
|
@ -158,7 +158,7 @@ func (d *Database) GetLatestEventsForUpdate(
|
||||||
// 'database is locked' errors. As sqlite doesn't support multi-process on the
|
// 'database is locked' errors. As sqlite doesn't support multi-process on the
|
||||||
// same DB anyway, and we only execute updates sequentially, the only worries
|
// same DB anyway, and we only execute updates sequentially, the only worries
|
||||||
// are for rolling back when things go wrong. (atomicity)
|
// are for rolling back when things go wrong. (atomicity)
|
||||||
return shared.NewLatestEventsUpdater(ctx, &d.Database, nil, roomNID)
|
return shared.NewLatestEventsUpdater(ctx, &d.Database, nil, roomInfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) MembershipUpdater(
|
func (d *Database) MembershipUpdater(
|
||||||
|
|
|
@ -63,9 +63,11 @@ type Rooms interface {
|
||||||
SelectLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.StateSnapshotNID, error)
|
SelectLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.StateSnapshotNID, error)
|
||||||
SelectLatestEventsNIDsForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error)
|
SelectLatestEventsNIDsForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error)
|
||||||
UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error
|
UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error
|
||||||
SelectRoomVersionForRoomID(ctx context.Context, txn *sql.Tx, roomID string) (gomatrixserverlib.RoomVersion, error)
|
|
||||||
SelectRoomVersionForRoomNID(ctx context.Context, roomNID types.RoomNID) (gomatrixserverlib.RoomVersion, error)
|
SelectRoomVersionForRoomNID(ctx context.Context, roomNID types.RoomNID) (gomatrixserverlib.RoomVersion, error)
|
||||||
SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error)
|
SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error)
|
||||||
|
SelectRoomIDs(ctx context.Context) ([]string, error)
|
||||||
|
BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error)
|
||||||
|
BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Transactions interface {
|
type Transactions interface {
|
||||||
|
@ -121,6 +123,11 @@ type Membership interface {
|
||||||
SelectMembershipsFromRoom(ctx context.Context, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error)
|
SelectMembershipsFromRoom(ctx context.Context, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error)
|
||||||
SelectMembershipsFromRoomAndMembership(ctx context.Context, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error)
|
SelectMembershipsFromRoomAndMembership(ctx context.Context, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error)
|
||||||
UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID) error
|
UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID) error
|
||||||
|
SelectRoomsWithMembership(ctx context.Context, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error)
|
||||||
|
// SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms, along with the
|
||||||
|
// counts of how many rooms they are joined.
|
||||||
|
SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error)
|
||||||
|
SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Published interface {
|
type Published interface {
|
||||||
|
|
Loading…
Reference in a new issue