Support for server ACLs (#1261)

* First pass at server ACLs (not efficient)

* Use transaction origin, update whitelist

* Fix federation API test

It's sufficient for us to return nothing in response to current state, so that the server ACL check returns no ACLs.

* More efficient server ACLs - hopefully

* Fix queries

* Fix queries

* Avoid panics by nil pointers

* Bug fixes

* Fix state event type

* Fix mutex

* Update logging

* Ignore port when matching servername

* Use read mutex

* Fix bugs

* Fix sync API test

* Comments

* Add tests, tweaks to behaviour

* Fix test output
This commit is contained in:
Neil Alexander 2020-08-11 18:19:11 +01:00 committed by GitHub
parent 8b6ab272fb
commit bcdf9577a3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
20 changed files with 581 additions and 16 deletions

View file

@ -0,0 +1,156 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package acls
import (
"context"
"encoding/json"
"fmt"
"net"
"regexp"
"strings"
"sync"
"github.com/matrix-org/dendrite/currentstateserver/storage"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
)
type ServerACLs struct {
acls map[string]*serverACL // room ID -> ACL
aclsMutex sync.RWMutex // protects the above
}
func NewServerACLs(db storage.Database) *ServerACLs {
ctx := context.TODO()
acls := &ServerACLs{
acls: make(map[string]*serverACL),
}
// Look up all of the rooms that the current state server knows about.
rooms, err := db.GetKnownRooms(ctx)
if err != nil {
logrus.WithError(err).Fatalf("Failed to get known rooms")
}
// For each room, let's see if we have a server ACL state event. If we
// do then we'll process it into memory so that we have the regexes to
// hand.
for _, room := range rooms {
state, err := db.GetStateEvent(ctx, room, "m.room.server_acl", "")
if err != nil {
logrus.WithError(err).Errorf("Failed to get server ACLs for room %q", room)
continue
}
if state != nil {
acls.OnServerACLUpdate(&state.Event)
}
}
return acls
}
type ServerACL struct {
Allowed []string `json:"allow"`
Denied []string `json:"deny"`
AllowIPLiterals bool `json:"allow_ip_literals"`
}
type serverACL struct {
ServerACL
allowedRegexes []*regexp.Regexp
deniedRegexes []*regexp.Regexp
}
func compileACLRegex(orig string) (*regexp.Regexp, error) {
escaped := regexp.QuoteMeta(orig)
escaped = strings.Replace(escaped, "\\?", ".", -1)
escaped = strings.Replace(escaped, "\\*", ".*", -1)
return regexp.Compile(escaped)
}
func (s *ServerACLs) OnServerACLUpdate(state *gomatrixserverlib.Event) {
acls := &serverACL{}
if err := json.Unmarshal(state.Content(), &acls.ServerACL); err != nil {
logrus.WithError(err).Errorf("Failed to unmarshal state content for server ACLs")
return
}
// The spec calls only for * (zero or more chars) and ? (exactly one char)
// to be supported as wildcard components, so we will escape all of the regex
// special characters and then replace * and ? with their regex counterparts.
// https://matrix.org/docs/spec/client_server/r0.6.1#m-room-server-acl
for _, orig := range acls.Allowed {
if expr, err := compileACLRegex(orig); err != nil {
logrus.WithError(err).Errorf("Failed to compile allowed regex")
} else {
acls.allowedRegexes = append(acls.allowedRegexes, expr)
}
}
for _, orig := range acls.Denied {
if expr, err := compileACLRegex(orig); err != nil {
logrus.WithError(err).Errorf("Failed to compile denied regex")
} else {
acls.deniedRegexes = append(acls.deniedRegexes, expr)
}
}
logrus.WithFields(logrus.Fields{
"allow_ip_literals": acls.AllowIPLiterals,
"num_allowed": len(acls.allowedRegexes),
"num_denied": len(acls.deniedRegexes),
}).Debugf("Updating server ACLs for %q", state.RoomID())
s.aclsMutex.Lock()
defer s.aclsMutex.Unlock()
s.acls[state.RoomID()] = acls
}
func (s *ServerACLs) IsServerBannedFromRoom(serverName gomatrixserverlib.ServerName, roomID string) bool {
s.aclsMutex.RLock()
// First of all check if we have an ACL for this room. If we don't then
// no servers are banned from the room.
acls, ok := s.acls[roomID]
if !ok {
s.aclsMutex.RUnlock()
return false
}
s.aclsMutex.RUnlock()
// Split the host and port apart. This is because the spec calls on us to
// validate the hostname only in cases where the port is also present.
if serverNameOnly, _, err := net.SplitHostPort(string(serverName)); err == nil {
serverName = gomatrixserverlib.ServerName(serverNameOnly)
}
// Check if the hostname is an IPv4 or IPv6 literal. We cheat here by adding
// a /0 prefix length just to trick ParseCIDR into working. If we find that
// the server is an IP literal and we don't allow those then stop straight
// away.
if _, _, err := net.ParseCIDR(fmt.Sprintf("%s/0", serverName)); err == nil {
if !acls.AllowIPLiterals {
return true
}
}
// Check if the hostname matches one of the denied regexes. If it does then
// the server is banned from the room.
for _, expr := range acls.deniedRegexes {
if expr.MatchString(string(serverName)) {
return true
}
}
// Check if the hostname matches one of the allowed regexes. If it does then
// the server is NOT banned from the room.
for _, expr := range acls.allowedRegexes {
if expr.MatchString(string(serverName)) {
return false
}
}
// If we've got to this point then we haven't matched any regexes or an IP
// hostname if disallowed. The spec calls for default-deny here.
return true
}

View file

@ -0,0 +1,105 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package acls
import (
"regexp"
"testing"
)
func TestOpenACLsWithBlacklist(t *testing.T) {
roomID := "!test:test.com"
allowRegex, err := compileACLRegex("*")
if err != nil {
t.Fatalf(err.Error())
}
denyRegex, err := compileACLRegex("foo.com")
if err != nil {
t.Fatalf(err.Error())
}
acls := ServerACLs{
acls: make(map[string]*serverACL),
}
acls.acls[roomID] = &serverACL{
ServerACL: ServerACL{
AllowIPLiterals: true,
},
allowedRegexes: []*regexp.Regexp{allowRegex},
deniedRegexes: []*regexp.Regexp{denyRegex},
}
if acls.IsServerBannedFromRoom("1.2.3.4", roomID) {
t.Fatal("Expected 1.2.3.4 to be allowed but wasn't")
}
if acls.IsServerBannedFromRoom("1.2.3.4:2345", roomID) {
t.Fatal("Expected 1.2.3.4:2345 to be allowed but wasn't")
}
if !acls.IsServerBannedFromRoom("foo.com", roomID) {
t.Fatal("Expected foo.com to be banned but wasn't")
}
if !acls.IsServerBannedFromRoom("foo.com:3456", roomID) {
t.Fatal("Expected foo.com:3456 to be banned but wasn't")
}
if acls.IsServerBannedFromRoom("bar.com", roomID) {
t.Fatal("Expected bar.com to be allowed but wasn't")
}
if acls.IsServerBannedFromRoom("bar.com:4567", roomID) {
t.Fatal("Expected bar.com:4567 to be allowed but wasn't")
}
}
func TestDefaultACLsWithWhitelist(t *testing.T) {
roomID := "!test:test.com"
allowRegex, err := compileACLRegex("foo.com")
if err != nil {
t.Fatalf(err.Error())
}
acls := ServerACLs{
acls: make(map[string]*serverACL),
}
acls.acls[roomID] = &serverACL{
ServerACL: ServerACL{
AllowIPLiterals: false,
},
allowedRegexes: []*regexp.Regexp{allowRegex},
deniedRegexes: []*regexp.Regexp{},
}
if !acls.IsServerBannedFromRoom("1.2.3.4", roomID) {
t.Fatal("Expected 1.2.3.4 to be banned but wasn't")
}
if !acls.IsServerBannedFromRoom("1.2.3.4:2345", roomID) {
t.Fatal("Expected 1.2.3.4:2345 to be banned but wasn't")
}
if acls.IsServerBannedFromRoom("foo.com", roomID) {
t.Fatal("Expected foo.com to be allowed but wasn't")
}
if acls.IsServerBannedFromRoom("foo.com:3456", roomID) {
t.Fatal("Expected foo.com:3456 to be allowed but wasn't")
}
if !acls.IsServerBannedFromRoom("bar.com", roomID) {
t.Fatal("Expected bar.com to be allowed but wasn't")
}
if !acls.IsServerBannedFromRoom("baz.com", roomID) {
t.Fatal("Expected baz.com to be allowed but wasn't")
}
if !acls.IsServerBannedFromRoom("qux.com:4567", roomID) {
t.Fatal("Expected qux.com:4567 to be allowed but wasn't")
}
}

View file

@ -36,6 +36,8 @@ type CurrentStateInternalAPI interface {
QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error
// QueryKnownUsers returns a list of users that we know about from our joined rooms. // QueryKnownUsers returns a list of users that we know about from our joined rooms.
QueryKnownUsers(ctx context.Context, req *QueryKnownUsersRequest, res *QueryKnownUsersResponse) error QueryKnownUsers(ctx context.Context, req *QueryKnownUsersRequest, res *QueryKnownUsersResponse) error
// QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs.
QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error
} }
type QuerySharedUsersRequest struct { type QuerySharedUsersRequest struct {
@ -101,6 +103,15 @@ type QueryKnownUsersResponse struct {
Users []authtypes.FullyQualifiedProfile `json:"profiles"` Users []authtypes.FullyQualifiedProfile `json:"profiles"`
} }
type QueryServerBannedFromRoomRequest struct {
ServerName gomatrixserverlib.ServerName `json:"server_name"`
RoomID string `json:"room_id"`
}
type QueryServerBannedFromRoomResponse struct {
Banned bool `json:"banned"`
}
// MarshalJSON stringifies the StateKeyTuple keys so they can be sent over the wire in HTTP API mode. // MarshalJSON stringifies the StateKeyTuple keys so they can be sent over the wire in HTTP API mode.
func (r *QueryCurrentStateResponse) MarshalJSON() ([]byte, error) { func (r *QueryCurrentStateResponse) MarshalJSON() ([]byte, error) {
se := make(map[string]*gomatrixserverlib.HeaderedEvent, len(r.StateEvents)) se := make(map[string]*gomatrixserverlib.HeaderedEvent, len(r.StateEvents))

View file

@ -1,3 +1,17 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package api package api
import ( import (
@ -25,6 +39,20 @@ func GetEvent(ctx context.Context, stateAPI CurrentStateInternalAPI, roomID stri
return nil return nil
} }
// IsServerBannedFromRoom returns whether the server is banned from a room by server ACLs.
func IsServerBannedFromRoom(ctx context.Context, stateAPI CurrentStateInternalAPI, roomID string, serverName gomatrixserverlib.ServerName) bool {
req := &QueryServerBannedFromRoomRequest{
ServerName: serverName,
RoomID: roomID,
}
res := &QueryServerBannedFromRoomResponse{}
if err := stateAPI.QueryServerBannedFromRoom(ctx, req, res); err != nil {
util.GetLogger(ctx).WithError(err).Error("Failed to QueryServerBannedFromRoom")
return true
}
return res.Banned
}
// PopulatePublicRooms extracts PublicRoom information for all the provided room IDs. The IDs are not checked to see if they are visible in the // PopulatePublicRooms extracts PublicRoom information for all the provided room IDs. The IDs are not checked to see if they are visible in the
// published room directory. // published room directory.
// due to lots of switches // due to lots of switches

View file

@ -19,6 +19,7 @@ import (
"encoding/json" "encoding/json"
"github.com/Shopify/sarama" "github.com/Shopify/sarama"
"github.com/matrix-org/dendrite/currentstateserver/acls"
"github.com/matrix-org/dendrite/currentstateserver/storage" "github.com/matrix-org/dendrite/currentstateserver/storage"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
@ -30,9 +31,10 @@ import (
type OutputRoomEventConsumer struct { type OutputRoomEventConsumer struct {
rsConsumer *internal.ContinualConsumer rsConsumer *internal.ContinualConsumer
db storage.Database db storage.Database
acls *acls.ServerACLs
} }
func NewOutputRoomEventConsumer(topicName string, kafkaConsumer sarama.Consumer, store storage.Database) *OutputRoomEventConsumer { func NewOutputRoomEventConsumer(topicName string, kafkaConsumer sarama.Consumer, store storage.Database, acls *acls.ServerACLs) *OutputRoomEventConsumer {
consumer := &internal.ContinualConsumer{ consumer := &internal.ContinualConsumer{
Topic: topicName, Topic: topicName,
Consumer: kafkaConsumer, Consumer: kafkaConsumer,
@ -41,6 +43,7 @@ func NewOutputRoomEventConsumer(topicName string, kafkaConsumer sarama.Consumer,
s := &OutputRoomEventConsumer{ s := &OutputRoomEventConsumer{
rsConsumer: consumer, rsConsumer: consumer,
db: store, db: store,
acls: acls,
} }
consumer.ProcessMessage = s.onMessage consumer.ProcessMessage = s.onMessage
@ -76,6 +79,10 @@ func (c *OutputRoomEventConsumer) onNewRoomEvent(
) error { ) error {
ev := msg.Event ev := msg.Event
if ev.Type() == "m.room.server_acl" && ev.StateKeyEquals("") {
defer c.acls.OnServerACLUpdate(&ev.Event)
}
addsStateEvents := msg.AddsState() addsStateEvents := msg.AddsState()
ev, err := c.updateStateEvent(ev) ev, err := c.updateStateEvent(ev)

View file

@ -17,6 +17,7 @@ package currentstateserver
import ( import (
"github.com/Shopify/sarama" "github.com/Shopify/sarama"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/matrix-org/dendrite/currentstateserver/acls"
"github.com/matrix-org/dendrite/currentstateserver/api" "github.com/matrix-org/dendrite/currentstateserver/api"
"github.com/matrix-org/dendrite/currentstateserver/consumers" "github.com/matrix-org/dendrite/currentstateserver/consumers"
"github.com/matrix-org/dendrite/currentstateserver/internal" "github.com/matrix-org/dendrite/currentstateserver/internal"
@ -39,13 +40,15 @@ func NewInternalAPI(cfg *config.CurrentStateServer, consumer sarama.Consumer) ap
if err != nil { if err != nil {
logrus.WithError(err).Panicf("failed to open database") logrus.WithError(err).Panicf("failed to open database")
} }
serverACLs := acls.NewServerACLs(csDB)
roomConsumer := consumers.NewOutputRoomEventConsumer( roomConsumer := consumers.NewOutputRoomEventConsumer(
cfg.Matrix.Kafka.TopicFor(config.TopicOutputRoomEvent), consumer, csDB, cfg.Matrix.Kafka.TopicFor(config.TopicOutputRoomEvent), consumer, csDB, serverACLs,
) )
if err = roomConsumer.Start(); err != nil { if err = roomConsumer.Start(); err != nil {
logrus.WithError(err).Panicf("failed to start room server consumer") logrus.WithError(err).Panicf("failed to start room server consumer")
} }
return &internal.CurrentStateInternalAPI{ return &internal.CurrentStateInternalAPI{
DB: csDB, DB: csDB,
ServerACLs: serverACLs,
} }
} }

View file

@ -16,8 +16,10 @@ package internal
import ( import (
"context" "context"
"errors"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/currentstateserver/acls"
"github.com/matrix-org/dendrite/currentstateserver/api" "github.com/matrix-org/dendrite/currentstateserver/api"
"github.com/matrix-org/dendrite/currentstateserver/storage" "github.com/matrix-org/dendrite/currentstateserver/storage"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -25,6 +27,7 @@ import (
type CurrentStateInternalAPI struct { type CurrentStateInternalAPI struct {
DB storage.Database DB storage.Database
ServerACLs *acls.ServerACLs
} }
func (a *CurrentStateInternalAPI) QueryCurrentState(ctx context.Context, req *api.QueryCurrentStateRequest, res *api.QueryCurrentStateResponse) error { func (a *CurrentStateInternalAPI) QueryCurrentState(ctx context.Context, req *api.QueryCurrentStateRequest, res *api.QueryCurrentStateResponse) error {
@ -112,3 +115,11 @@ func (a *CurrentStateInternalAPI) QuerySharedUsers(ctx context.Context, req *api
res.UserIDsToCount = users res.UserIDsToCount = users
return nil return nil
} }
func (a *CurrentStateInternalAPI) QueryServerBannedFromRoom(ctx context.Context, req *api.QueryServerBannedFromRoomRequest, res *api.QueryServerBannedFromRoomResponse) error {
if a.ServerACLs == nil {
return errors.New("no server ACL tracking")
}
res.Banned = a.ServerACLs.IsServerBannedFromRoom(req.ServerName, req.RoomID)
return nil
}

View file

@ -31,6 +31,7 @@ const (
QueryBulkStateContentPath = "/currentstateserver/queryBulkStateContent" QueryBulkStateContentPath = "/currentstateserver/queryBulkStateContent"
QuerySharedUsersPath = "/currentstateserver/querySharedUsers" QuerySharedUsersPath = "/currentstateserver/querySharedUsers"
QueryKnownUsersPath = "/currentstateserver/queryKnownUsers" QueryKnownUsersPath = "/currentstateserver/queryKnownUsers"
QueryServerBannedFromRoomPath = "/currentstateserver/queryServerBannedFromRoom"
) )
// NewCurrentStateAPIClient creates a CurrentStateInternalAPI implemented by talking to a HTTP POST API. // NewCurrentStateAPIClient creates a CurrentStateInternalAPI implemented by talking to a HTTP POST API.
@ -108,3 +109,13 @@ func (h *httpCurrentStateInternalAPI) QueryKnownUsers(
apiURL := h.apiURL + QueryKnownUsersPath apiURL := h.apiURL + QueryKnownUsersPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res) return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
} }
func (h *httpCurrentStateInternalAPI) QueryServerBannedFromRoom(
ctx context.Context, req *api.QueryServerBannedFromRoomRequest, res *api.QueryServerBannedFromRoomResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryServerBannedFromRoom")
defer span.Finish()
apiURL := h.apiURL + QueryServerBannedFromRoomPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}

View file

@ -90,4 +90,17 @@ func AddRoutes(internalAPIMux *mux.Router, intAPI api.CurrentStateInternalAPI) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response} return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}), }),
) )
internalAPIMux.Handle(QueryServerBannedFromRoomPath,
httputil.MakeInternalAPI("queryServerBannedFromRoom", func(req *http.Request) util.JSONResponse {
request := api.QueryServerBannedFromRoomRequest{}
response := api.QueryServerBannedFromRoomResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := intAPI.QueryServerBannedFromRoom(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
} }

View file

@ -41,4 +41,6 @@ type Database interface {
JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error)
// GetKnownUsers searches all users that userID knows about. // GetKnownUsers searches all users that userID knows about.
GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error)
// GetKnownRooms returns a list of all rooms we know about.
GetKnownRooms(ctx context.Context) ([]string, error)
} }

View file

@ -82,6 +82,9 @@ const selectJoinedUsersSetForRoomsSQL = "" +
"SELECT state_key, COUNT(room_id) FROM currentstate_current_room_state WHERE room_id = ANY($1) AND" + "SELECT state_key, COUNT(room_id) FROM currentstate_current_room_state WHERE room_id = ANY($1) AND" +
" type = 'm.room.member' and content_value = 'join' GROUP BY state_key" " type = 'm.room.member' and content_value = 'join' GROUP BY state_key"
const selectKnownRoomsSQL = "" +
"SELECT DISTINCT room_id FROM currentstate_current_room_state"
// selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is // selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is
// joined to. Since this information is used to populate the user directory, we will // joined to. Since this information is used to populate the user directory, we will
// only return users that the user would ordinarily be able to see anyway. // only return users that the user would ordinarily be able to see anyway.
@ -99,6 +102,7 @@ type currentRoomStateStatements struct {
selectBulkStateContentStmt *sql.Stmt selectBulkStateContentStmt *sql.Stmt
selectBulkStateContentWildStmt *sql.Stmt selectBulkStateContentWildStmt *sql.Stmt
selectJoinedUsersSetForRoomsStmt *sql.Stmt selectJoinedUsersSetForRoomsStmt *sql.Stmt
selectKnownRoomsStmt *sql.Stmt
selectKnownUsersStmt *sql.Stmt selectKnownUsersStmt *sql.Stmt
} }
@ -132,6 +136,9 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro
if s.selectJoinedUsersSetForRoomsStmt, err = db.Prepare(selectJoinedUsersSetForRoomsSQL); err != nil { if s.selectJoinedUsersSetForRoomsStmt, err = db.Prepare(selectJoinedUsersSetForRoomsSQL); err != nil {
return nil, err return nil, err
} }
if s.selectKnownRoomsStmt, err = db.Prepare(selectKnownRoomsSQL); err != nil {
return nil, err
}
if s.selectKnownUsersStmt, err = db.Prepare(selectKnownUsersSQL); err != nil { if s.selectKnownUsersStmt, err = db.Prepare(selectKnownUsersSQL); err != nil {
return nil, err return nil, err
} }
@ -325,3 +332,20 @@ func (s *currentRoomStateStatements) SelectKnownUsers(ctx context.Context, userI
} }
return result, rows.Err() return result, rows.Err()
} }
func (s *currentRoomStateStatements) SelectKnownRooms(ctx context.Context) ([]string, error) {
rows, err := s.selectKnownRoomsStmt.QueryContext(ctx)
if err != nil {
return nil, err
}
result := []string{}
defer internal.CloseAndLogIfError(ctx, rows, "SelectKnownRooms: rows.close() failed")
for rows.Next() {
var roomID string
if err := rows.Scan(&roomID); err != nil {
return nil, err
}
result = append(result, roomID)
}
return result, rows.Err()
}

View file

@ -93,3 +93,7 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string)
func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) { func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) {
return d.CurrentRoomState.SelectKnownUsers(ctx, userID, searchString, limit) return d.CurrentRoomState.SelectKnownUsers(ctx, userID, searchString, limit)
} }
func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) {
return d.CurrentRoomState.SelectKnownRooms(ctx)
}

View file

@ -70,6 +70,9 @@ const selectBulkStateContentWildSQL = "" +
const selectJoinedUsersSetForRoomsSQL = "" + const selectJoinedUsersSetForRoomsSQL = "" +
"SELECT state_key, COUNT(room_id) FROM currentstate_current_room_state WHERE room_id IN ($1) AND type = 'm.room.member' and content_value = 'join' GROUP BY state_key" "SELECT state_key, COUNT(room_id) FROM currentstate_current_room_state WHERE room_id IN ($1) AND type = 'm.room.member' and content_value = 'join' GROUP BY state_key"
const selectKnownRoomsSQL = "" +
"SELECT DISTINCT room_id FROM currentstate_current_room_state"
// selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is // selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is
// joined to. Since this information is used to populate the user directory, we will // joined to. Since this information is used to populate the user directory, we will
// only return users that the user would ordinarily be able to see anyway. // only return users that the user would ordinarily be able to see anyway.
@ -86,6 +89,7 @@ type currentRoomStateStatements struct {
selectRoomIDsWithMembershipStmt *sql.Stmt selectRoomIDsWithMembershipStmt *sql.Stmt
selectStateEventStmt *sql.Stmt selectStateEventStmt *sql.Stmt
selectJoinedUsersSetForRoomsStmt *sql.Stmt selectJoinedUsersSetForRoomsStmt *sql.Stmt
selectKnownRoomsStmt *sql.Stmt
selectKnownUsersStmt *sql.Stmt selectKnownUsersStmt *sql.Stmt
} }
@ -113,6 +117,9 @@ func NewSqliteCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error)
if s.selectJoinedUsersSetForRoomsStmt, err = db.Prepare(selectJoinedUsersSetForRoomsSQL); err != nil { if s.selectJoinedUsersSetForRoomsStmt, err = db.Prepare(selectJoinedUsersSetForRoomsSQL); err != nil {
return nil, err return nil, err
} }
if s.selectKnownRoomsStmt, err = db.Prepare(selectKnownRoomsSQL); err != nil {
return nil, err
}
if s.selectKnownUsersStmt, err = db.Prepare(selectKnownUsersSQL); err != nil { if s.selectKnownUsersStmt, err = db.Prepare(selectKnownUsersSQL); err != nil {
return nil, err return nil, err
} }
@ -345,3 +352,20 @@ func (s *currentRoomStateStatements) SelectKnownUsers(ctx context.Context, userI
} }
return result, rows.Err() return result, rows.Err()
} }
func (s *currentRoomStateStatements) SelectKnownRooms(ctx context.Context) ([]string, error) {
rows, err := s.selectKnownRoomsStmt.QueryContext(ctx)
if err != nil {
return nil, err
}
result := []string{}
defer internal.CloseAndLogIfError(ctx, rows, "SelectKnownRooms: rows.close() failed")
for rows.Next() {
var roomID string
if err := rows.Scan(&roomID); err != nil {
return nil, err
}
result = append(result, roomID)
}
return result, rows.Err()
}

View file

@ -41,6 +41,8 @@ type CurrentRoomState interface {
SelectJoinedUsersSetForRooms(ctx context.Context, roomIDs []string) (map[string]int, error) SelectJoinedUsersSetForRooms(ctx context.Context, roomIDs []string) (map[string]int, error)
// SelectKnownUsers searches all users that userID knows about. // SelectKnownUsers searches all users that userID knows about.
SelectKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) SelectKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error)
// SelectKnownRooms returns all rooms that we know about.
SelectKnownRooms(ctx context.Context) ([]string, error)
} }
// StrippedEvent represents a stripped event for returning extracted content values. // StrippedEvent represents a stripped event for returning extracted content values.

View file

@ -41,7 +41,6 @@ func AddPublicRoutes(
stateAPI currentstateAPI.CurrentStateInternalAPI, stateAPI currentstateAPI.CurrentStateInternalAPI,
keyAPI keyserverAPI.KeyInternalAPI, keyAPI keyserverAPI.KeyInternalAPI,
) { ) {
routing.Setup( routing.Setup(
router, cfg, rsAPI, router, cfg, rsAPI,
eduAPI, federationSenderAPI, keyRing, eduAPI, federationSenderAPI, keyRing,

View file

@ -82,7 +82,7 @@ func Setup(
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
return Send( return Send(
httpReq, request, gomatrixserverlib.TransactionID(vars["txnID"]), httpReq, request, gomatrixserverlib.TransactionID(vars["txnID"]),
cfg, rsAPI, eduAPI, keyAPI, keys, federation, cfg, rsAPI, eduAPI, keyAPI, stateAPI, keys, federation,
) )
}, },
)).Methods(http.MethodPut, http.MethodOptions) )).Methods(http.MethodPut, http.MethodOptions)
@ -90,6 +90,12 @@ func Setup(
v1fedmux.Handle("/invite/{roomID}/{eventID}", httputil.MakeFedAPI( v1fedmux.Handle("/invite/{roomID}/{eventID}", httputil.MakeFedAPI(
"federation_invite", cfg.Matrix.ServerName, keys, wakeup, "federation_invite", cfg.Matrix.ServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
if currentstateAPI.IsServerBannedFromRoom(httpReq.Context(), stateAPI, vars["roomID"], request.Origin()) {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("Forbidden by server ACLs"),
}
}
res := InviteV1( res := InviteV1(
httpReq, request, vars["roomID"], vars["eventID"], httpReq, request, vars["roomID"], vars["eventID"],
cfg, rsAPI, keys, cfg, rsAPI, keys,
@ -106,6 +112,12 @@ func Setup(
v2fedmux.Handle("/invite/{roomID}/{eventID}", httputil.MakeFedAPI( v2fedmux.Handle("/invite/{roomID}/{eventID}", httputil.MakeFedAPI(
"federation_invite", cfg.Matrix.ServerName, keys, wakeup, "federation_invite", cfg.Matrix.ServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
if currentstateAPI.IsServerBannedFromRoom(httpReq.Context(), stateAPI, vars["roomID"], request.Origin()) {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("Forbidden by server ACLs"),
}
}
return InviteV2( return InviteV2(
httpReq, request, vars["roomID"], vars["eventID"], httpReq, request, vars["roomID"], vars["eventID"],
cfg, rsAPI, keys, cfg, rsAPI, keys,
@ -140,6 +152,12 @@ func Setup(
v1fedmux.Handle("/state/{roomID}", httputil.MakeFedAPI( v1fedmux.Handle("/state/{roomID}", httputil.MakeFedAPI(
"federation_get_state", cfg.Matrix.ServerName, keys, wakeup, "federation_get_state", cfg.Matrix.ServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
if currentstateAPI.IsServerBannedFromRoom(httpReq.Context(), stateAPI, vars["roomID"], request.Origin()) {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("Forbidden by server ACLs"),
}
}
return GetState( return GetState(
httpReq.Context(), request, rsAPI, vars["roomID"], httpReq.Context(), request, rsAPI, vars["roomID"],
) )
@ -149,6 +167,12 @@ func Setup(
v1fedmux.Handle("/state_ids/{roomID}", httputil.MakeFedAPI( v1fedmux.Handle("/state_ids/{roomID}", httputil.MakeFedAPI(
"federation_get_state_ids", cfg.Matrix.ServerName, keys, wakeup, "federation_get_state_ids", cfg.Matrix.ServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
if currentstateAPI.IsServerBannedFromRoom(httpReq.Context(), stateAPI, vars["roomID"], request.Origin()) {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("Forbidden by server ACLs"),
}
}
return GetStateIDs( return GetStateIDs(
httpReq.Context(), request, rsAPI, vars["roomID"], httpReq.Context(), request, rsAPI, vars["roomID"],
) )
@ -158,6 +182,12 @@ func Setup(
v1fedmux.Handle("/event_auth/{roomID}/{eventID}", httputil.MakeFedAPI( v1fedmux.Handle("/event_auth/{roomID}/{eventID}", httputil.MakeFedAPI(
"federation_get_event_auth", cfg.Matrix.ServerName, keys, wakeup, "federation_get_event_auth", cfg.Matrix.ServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
if currentstateAPI.IsServerBannedFromRoom(httpReq.Context(), stateAPI, vars["roomID"], request.Origin()) {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("Forbidden by server ACLs"),
}
}
return GetEventAuth( return GetEventAuth(
httpReq.Context(), request, rsAPI, vars["roomID"], vars["eventID"], httpReq.Context(), request, rsAPI, vars["roomID"], vars["eventID"],
) )
@ -194,6 +224,12 @@ func Setup(
v1fedmux.Handle("/make_join/{roomID}/{eventID}", httputil.MakeFedAPI( v1fedmux.Handle("/make_join/{roomID}/{eventID}", httputil.MakeFedAPI(
"federation_make_join", cfg.Matrix.ServerName, keys, wakeup, "federation_make_join", cfg.Matrix.ServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
if currentstateAPI.IsServerBannedFromRoom(httpReq.Context(), stateAPI, vars["roomID"], request.Origin()) {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("Forbidden by server ACLs"),
}
}
roomID := vars["roomID"] roomID := vars["roomID"]
eventID := vars["eventID"] eventID := vars["eventID"]
queryVars := httpReq.URL.Query() queryVars := httpReq.URL.Query()
@ -219,6 +255,12 @@ func Setup(
v1fedmux.Handle("/send_join/{roomID}/{eventID}", httputil.MakeFedAPI( v1fedmux.Handle("/send_join/{roomID}/{eventID}", httputil.MakeFedAPI(
"federation_send_join", cfg.Matrix.ServerName, keys, wakeup, "federation_send_join", cfg.Matrix.ServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
if currentstateAPI.IsServerBannedFromRoom(httpReq.Context(), stateAPI, vars["roomID"], request.Origin()) {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("Forbidden by server ACLs"),
}
}
roomID := vars["roomID"] roomID := vars["roomID"]
eventID := vars["eventID"] eventID := vars["eventID"]
res := SendJoin( res := SendJoin(
@ -245,6 +287,12 @@ func Setup(
v2fedmux.Handle("/send_join/{roomID}/{eventID}", httputil.MakeFedAPI( v2fedmux.Handle("/send_join/{roomID}/{eventID}", httputil.MakeFedAPI(
"federation_send_join", cfg.Matrix.ServerName, keys, wakeup, "federation_send_join", cfg.Matrix.ServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
if currentstateAPI.IsServerBannedFromRoom(httpReq.Context(), stateAPI, vars["roomID"], request.Origin()) {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("Forbidden by server ACLs"),
}
}
roomID := vars["roomID"] roomID := vars["roomID"]
eventID := vars["eventID"] eventID := vars["eventID"]
return SendJoin( return SendJoin(
@ -256,6 +304,12 @@ func Setup(
v1fedmux.Handle("/make_leave/{roomID}/{eventID}", httputil.MakeFedAPI( v1fedmux.Handle("/make_leave/{roomID}/{eventID}", httputil.MakeFedAPI(
"federation_make_leave", cfg.Matrix.ServerName, keys, wakeup, "federation_make_leave", cfg.Matrix.ServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
if currentstateAPI.IsServerBannedFromRoom(httpReq.Context(), stateAPI, vars["roomID"], request.Origin()) {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("Forbidden by server ACLs"),
}
}
roomID := vars["roomID"] roomID := vars["roomID"]
eventID := vars["eventID"] eventID := vars["eventID"]
return MakeLeave( return MakeLeave(
@ -264,9 +318,47 @@ func Setup(
}, },
)).Methods(http.MethodGet) )).Methods(http.MethodGet)
v1fedmux.Handle("/send_leave/{roomID}/{eventID}", httputil.MakeFedAPI(
"federation_send_leave", cfg.Matrix.ServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
if currentstateAPI.IsServerBannedFromRoom(httpReq.Context(), stateAPI, vars["roomID"], request.Origin()) {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("Forbidden by server ACLs"),
}
}
roomID := vars["roomID"]
eventID := vars["eventID"]
res := SendLeave(
httpReq, request, cfg, rsAPI, keys, roomID, eventID,
)
// not all responses get wrapped in [code, body]
var body interface{}
body = []interface{}{
res.Code, res.JSON,
}
jerr, ok := res.JSON.(*jsonerror.MatrixError)
if ok {
body = jerr
}
return util.JSONResponse{
Headers: res.Headers,
Code: res.Code,
JSON: body,
}
},
)).Methods(http.MethodPut)
v2fedmux.Handle("/send_leave/{roomID}/{eventID}", httputil.MakeFedAPI( v2fedmux.Handle("/send_leave/{roomID}/{eventID}", httputil.MakeFedAPI(
"federation_send_leave", cfg.Matrix.ServerName, keys, wakeup, "federation_send_leave", cfg.Matrix.ServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
if currentstateAPI.IsServerBannedFromRoom(httpReq.Context(), stateAPI, vars["roomID"], request.Origin()) {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("Forbidden by server ACLs"),
}
}
roomID := vars["roomID"] roomID := vars["roomID"]
eventID := vars["eventID"] eventID := vars["eventID"]
return SendLeave( return SendLeave(
@ -285,6 +377,12 @@ func Setup(
v1fedmux.Handle("/get_missing_events/{roomID}", httputil.MakeFedAPI( v1fedmux.Handle("/get_missing_events/{roomID}", httputil.MakeFedAPI(
"federation_get_missing_events", cfg.Matrix.ServerName, keys, wakeup, "federation_get_missing_events", cfg.Matrix.ServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
if currentstateAPI.IsServerBannedFromRoom(httpReq.Context(), stateAPI, vars["roomID"], request.Origin()) {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("Forbidden by server ACLs"),
}
}
return GetMissingEvents(httpReq, request, rsAPI, vars["roomID"]) return GetMissingEvents(httpReq, request, rsAPI, vars["roomID"])
}, },
)).Methods(http.MethodPost) )).Methods(http.MethodPost)
@ -292,6 +390,12 @@ func Setup(
v1fedmux.Handle("/backfill/{roomID}", httputil.MakeFedAPI( v1fedmux.Handle("/backfill/{roomID}", httputil.MakeFedAPI(
"federation_backfill", cfg.Matrix.ServerName, keys, wakeup, "federation_backfill", cfg.Matrix.ServerName, keys, wakeup,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse {
if currentstateAPI.IsServerBannedFromRoom(httpReq.Context(), stateAPI, vars["roomID"], request.Origin()) {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: jsonerror.Forbidden("Forbidden by server ACLs"),
}
}
return Backfill(httpReq, request, rsAPI, vars["roomID"], cfg) return Backfill(httpReq, request, rsAPI, vars["roomID"], cfg)
}, },
)).Methods(http.MethodGet) )).Methods(http.MethodGet)

View file

@ -21,6 +21,7 @@ import (
"net/http" "net/http"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api"
eduserverAPI "github.com/matrix-org/dendrite/eduserver/api" eduserverAPI "github.com/matrix-org/dendrite/eduserver/api"
"github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/internal/config"
keyapi "github.com/matrix-org/dendrite/keyserver/api" keyapi "github.com/matrix-org/dendrite/keyserver/api"
@ -39,6 +40,7 @@ func Send(
rsAPI api.RoomserverInternalAPI, rsAPI api.RoomserverInternalAPI,
eduAPI eduserverAPI.EDUServerInputAPI, eduAPI eduserverAPI.EDUServerInputAPI,
keyAPI keyapi.KeyInternalAPI, keyAPI keyapi.KeyInternalAPI,
stateAPI currentstateAPI.CurrentStateInternalAPI,
keys gomatrixserverlib.JSONVerifier, keys gomatrixserverlib.JSONVerifier,
federation *gomatrixserverlib.FederationClient, federation *gomatrixserverlib.FederationClient,
) util.JSONResponse { ) util.JSONResponse {
@ -46,6 +48,7 @@ func Send(
context: httpReq.Context(), context: httpReq.Context(),
rsAPI: rsAPI, rsAPI: rsAPI,
eduAPI: eduAPI, eduAPI: eduAPI,
stateAPI: stateAPI,
keys: keys, keys: keys,
federation: federation, federation: federation,
haveEvents: make(map[string]*gomatrixserverlib.HeaderedEvent), haveEvents: make(map[string]*gomatrixserverlib.HeaderedEvent),
@ -104,6 +107,7 @@ type txnReq struct {
rsAPI api.RoomserverInternalAPI rsAPI api.RoomserverInternalAPI
eduAPI eduserverAPI.EDUServerInputAPI eduAPI eduserverAPI.EDUServerInputAPI
keyAPI keyapi.KeyInternalAPI keyAPI keyapi.KeyInternalAPI
stateAPI currentstateAPI.CurrentStateInternalAPI
keys gomatrixserverlib.JSONVerifier keys gomatrixserverlib.JSONVerifier
federation txnFederationClient federation txnFederationClient
// local cache of events for auth checks, etc - this may include events // local cache of events for auth checks, etc - this may include events
@ -164,6 +168,12 @@ func (t *txnReq) processTransaction() (*gomatrixserverlib.RespSend, *util.JSONRe
util.GetLogger(t.context).WithError(err).Warnf("Transaction: Failed to parse event JSON of event %s", string(pdu)) util.GetLogger(t.context).WithError(err).Warnf("Transaction: Failed to parse event JSON of event %s", string(pdu))
continue continue
} }
if currentstateAPI.IsServerBannedFromRoom(t.context, t.stateAPI, event.RoomID(), t.Origin) {
results[event.EventID()] = gomatrixserverlib.PDUResult{
Error: "Forbidden by server ACLs",
}
continue
}
if err = gomatrixserverlib.VerifyAllEventSignatures(t.context, []gomatrixserverlib.Event{event}, t.keys); err != nil { if err = gomatrixserverlib.VerifyAllEventSignatures(t.context, []gomatrixserverlib.Event{event}, t.keys); err != nil {
util.GetLogger(t.context).WithError(err).Warnf("Transaction: Couldn't validate signature of event %q", event.EventID()) util.GetLogger(t.context).WithError(err).Warnf("Transaction: Couldn't validate signature of event %q", event.EventID())
results[event.EventID()] = gomatrixserverlib.PDUResult{ results[event.EventID()] = gomatrixserverlib.PDUResult{

View file

@ -8,6 +8,7 @@ import (
"testing" "testing"
"time" "time"
currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api"
eduAPI "github.com/matrix-org/dendrite/eduserver/api" eduAPI "github.com/matrix-org/dendrite/eduserver/api"
fsAPI "github.com/matrix-org/dendrite/federationsender/api" fsAPI "github.com/matrix-org/dendrite/federationsender/api"
"github.com/matrix-org/dendrite/internal/test" "github.com/matrix-org/dendrite/internal/test"
@ -294,6 +295,33 @@ func (t *testRoomserverAPI) RemoveRoomAlias(
return fmt.Errorf("not implemented") return fmt.Errorf("not implemented")
} }
type testStateAPI struct {
}
func (t *testStateAPI) QueryCurrentState(ctx context.Context, req *currentstateAPI.QueryCurrentStateRequest, res *currentstateAPI.QueryCurrentStateResponse) error {
return nil
}
func (t *testStateAPI) QueryRoomsForUser(ctx context.Context, req *currentstateAPI.QueryRoomsForUserRequest, res *currentstateAPI.QueryRoomsForUserResponse) error {
return fmt.Errorf("not implemented")
}
func (t *testStateAPI) QueryBulkStateContent(ctx context.Context, req *currentstateAPI.QueryBulkStateContentRequest, res *currentstateAPI.QueryBulkStateContentResponse) error {
return fmt.Errorf("not implemented")
}
func (t *testStateAPI) QuerySharedUsers(ctx context.Context, req *currentstateAPI.QuerySharedUsersRequest, res *currentstateAPI.QuerySharedUsersResponse) error {
return fmt.Errorf("not implemented")
}
func (t *testStateAPI) QueryKnownUsers(ctx context.Context, req *currentstateAPI.QueryKnownUsersRequest, res *currentstateAPI.QueryKnownUsersResponse) error {
return fmt.Errorf("not implemented")
}
func (t *testStateAPI) QueryServerBannedFromRoom(ctx context.Context, req *currentstateAPI.QueryServerBannedFromRoomRequest, res *currentstateAPI.QueryServerBannedFromRoomResponse) error {
return nil
}
type txnFedClient struct { type txnFedClient struct {
state map[string]gomatrixserverlib.RespState // event_id to response state map[string]gomatrixserverlib.RespState // event_id to response
stateIDs map[string]gomatrixserverlib.RespStateIDs // event_id to response stateIDs map[string]gomatrixserverlib.RespStateIDs // event_id to response
@ -338,11 +366,12 @@ func (c *txnFedClient) LookupMissingEvents(ctx context.Context, s gomatrixserver
return c.getMissingEvents(missing) return c.getMissingEvents(missing)
} }
func mustCreateTransaction(rsAPI api.RoomserverInternalAPI, fedClient txnFederationClient, pdus []json.RawMessage) *txnReq { func mustCreateTransaction(rsAPI api.RoomserverInternalAPI, stateAPI currentstateAPI.CurrentStateInternalAPI, fedClient txnFederationClient, pdus []json.RawMessage) *txnReq {
t := &txnReq{ t := &txnReq{
context: context.Background(), context: context.Background(),
rsAPI: rsAPI, rsAPI: rsAPI,
eduAPI: &testEDUProducer{}, eduAPI: &testEDUProducer{},
stateAPI: stateAPI,
keys: &test.NopJSONVerifier{}, keys: &test.NopJSONVerifier{},
federation: fedClient, federation: fedClient,
haveEvents: make(map[string]*gomatrixserverlib.HeaderedEvent), haveEvents: make(map[string]*gomatrixserverlib.HeaderedEvent),
@ -422,10 +451,11 @@ func TestBasicTransaction(t *testing.T) {
} }
}, },
} }
stateAPI := &testStateAPI{}
pdus := []json.RawMessage{ pdus := []json.RawMessage{
testData[len(testData)-1], // a message event testData[len(testData)-1], // a message event
} }
txn := mustCreateTransaction(rsAPI, &txnFedClient{}, pdus) txn := mustCreateTransaction(rsAPI, stateAPI, &txnFedClient{}, pdus)
mustProcessTransaction(t, txn, nil) mustProcessTransaction(t, txn, nil)
assertInputRoomEvents(t, rsAPI.inputRoomEvents, []gomatrixserverlib.HeaderedEvent{testEvents[len(testEvents)-1]}) assertInputRoomEvents(t, rsAPI.inputRoomEvents, []gomatrixserverlib.HeaderedEvent{testEvents[len(testEvents)-1]})
} }
@ -444,10 +474,11 @@ func TestTransactionFailAuthChecks(t *testing.T) {
} }
}, },
} }
stateAPI := &testStateAPI{}
pdus := []json.RawMessage{ pdus := []json.RawMessage{
testData[len(testData)-1], // a message event testData[len(testData)-1], // a message event
} }
txn := mustCreateTransaction(rsAPI, &txnFedClient{}, pdus) txn := mustCreateTransaction(rsAPI, stateAPI, &txnFedClient{}, pdus)
mustProcessTransaction(t, txn, []string{ mustProcessTransaction(t, txn, []string{
// expect the event to have an error // expect the event to have an error
testEvents[len(testEvents)-1].EventID(), testEvents[len(testEvents)-1].EventID(),
@ -502,6 +533,8 @@ func TestTransactionFetchMissingPrevEvents(t *testing.T) {
}, },
} }
stateAPI := &testStateAPI{}
cli := &txnFedClient{ cli := &txnFedClient{
getMissingEvents: func(missing gomatrixserverlib.MissingEvents) (res gomatrixserverlib.RespMissingEvents, err error) { getMissingEvents: func(missing gomatrixserverlib.MissingEvents) (res gomatrixserverlib.RespMissingEvents, err error) {
if !reflect.DeepEqual(missing.EarliestEvents, []string{haveEvent.EventID()}) { if !reflect.DeepEqual(missing.EarliestEvents, []string{haveEvent.EventID()}) {
@ -521,7 +554,7 @@ func TestTransactionFetchMissingPrevEvents(t *testing.T) {
pdus := []json.RawMessage{ pdus := []json.RawMessage{
inputEvent.JSON(), inputEvent.JSON(),
} }
txn := mustCreateTransaction(rsAPI, cli, pdus) txn := mustCreateTransaction(rsAPI, stateAPI, cli, pdus)
mustProcessTransaction(t, txn, nil) mustProcessTransaction(t, txn, nil)
assertInputRoomEvents(t, rsAPI.inputRoomEvents, []gomatrixserverlib.HeaderedEvent{prevEvent, inputEvent}) assertInputRoomEvents(t, rsAPI.inputRoomEvents, []gomatrixserverlib.HeaderedEvent{prevEvent, inputEvent})
} }
@ -671,10 +704,12 @@ func TestTransactionFetchMissingStateByStateIDs(t *testing.T) {
}, },
} }
stateAPI := &testStateAPI{}
pdus := []json.RawMessage{ pdus := []json.RawMessage{
eventD.JSON(), eventD.JSON(),
} }
txn := mustCreateTransaction(rsAPI, cli, pdus) txn := mustCreateTransaction(rsAPI, stateAPI, cli, pdus)
mustProcessTransaction(t, txn, nil) mustProcessTransaction(t, txn, nil)
assertInputRoomEvents(t, rsAPI.inputRoomEvents, []gomatrixserverlib.HeaderedEvent{eventB, eventC, eventD}) assertInputRoomEvents(t, rsAPI.inputRoomEvents, []gomatrixserverlib.HeaderedEvent{eventB, eventC, eventD})
} }

View file

@ -114,6 +114,10 @@ func (s *mockCurrentStateAPI) QuerySharedUsers(ctx context.Context, req *api.Que
return nil return nil
} }
func (t *mockCurrentStateAPI) QueryServerBannedFromRoom(ctx context.Context, req *api.QueryServerBannedFromRoomRequest, res *api.QueryServerBannedFromRoomResponse) error {
return nil
}
type wantCatchup struct { type wantCatchup struct {
hasNew bool hasNew bool
changed []string changed []string

View file

@ -436,3 +436,15 @@ User directory correctly update on display name change
User in shared private room does appear in user directory User in shared private room does appear in user directory
User in dir while user still shares private rooms User in dir while user still shares private rooms
Can get 'm.room.name' state for a departed room (SPEC-216) Can get 'm.room.name' state for a departed room (SPEC-216)
Banned servers cannot send events
Banned servers cannot /make_join
Banned servers cannot /send_join
Banned servers cannot /make_leave
Banned servers cannot /send_leave
Banned servers cannot /invite
Banned servers cannot get room state
Banned servers cannot /event_auth
Banned servers cannot get missing events
Banned servers cannot get room state ids
Banned servers cannot backfill
Inbound /v1/send_leave rejects leaves from other servers