Merge branch 'master' into kegan/history-vis

This commit is contained in:
Neil Alexander 2020-09-23 11:10:02 +01:00
commit 7b712865c6
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
65 changed files with 1457 additions and 245 deletions

View file

@ -215,7 +215,8 @@ func writeToRoomServer(input []string, roomserverURL string) error {
if err != nil { if err != nil {
return err return err
} }
return x.InputRoomEvents(context.Background(), &request, &response) x.InputRoomEvents(context.Background(), &request, &response)
return response.Err()
} }
// testRoomserver is used to run integration tests against a single roomserver. // testRoomserver is used to run integration tests against a single roomserver.

View file

@ -38,9 +38,6 @@ global:
# The path to the signing private key file, used to sign requests and events. # The path to the signing private key file, used to sign requests and events.
private_key: matrix_key.pem private_key: matrix_key.pem
# A unique identifier for this private key. Must start with the prefix "ed25519:".
key_id: ed25519:auto
# How long a remote server can cache our server signing key before requesting it # How long a remote server can cache our server signing key before requesting it
# again. Increasing this number will reduce the number of requests made by other # again. Increasing this number will reduce the number of requests made by other
# servers for our key but increases the period that a compromised key will be # servers for our key but increases the period that a compromised key will be

View file

@ -19,11 +19,14 @@ import (
"net/http" "net/http"
"time" "time"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/jsonerror"
federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api"
"github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/internal/config"
"github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/ed25519" "golang.org/x/crypto/ed25519"
) )
@ -160,3 +163,62 @@ func localKeys(cfg *config.FederationAPI, validUntil time.Time) (*gomatrixserver
return &keys, nil return &keys, nil
} }
func NotaryKeys(
httpReq *http.Request, cfg *config.FederationAPI,
fsAPI federationSenderAPI.FederationSenderInternalAPI,
req *gomatrixserverlib.PublicKeyNotaryLookupRequest,
) util.JSONResponse {
if req == nil {
req = &gomatrixserverlib.PublicKeyNotaryLookupRequest{}
if reqErr := httputil.UnmarshalJSONRequest(httpReq, &req); reqErr != nil {
return *reqErr
}
}
var response struct {
ServerKeys []json.RawMessage `json:"server_keys"`
}
response.ServerKeys = []json.RawMessage{}
for serverName := range req.ServerKeys {
var keys *gomatrixserverlib.ServerKeys
if serverName == cfg.Matrix.ServerName {
if k, err := localKeys(cfg, time.Now().Add(cfg.Matrix.KeyValidityPeriod)); err == nil {
keys = k
} else {
return util.ErrorResponse(err)
}
} else {
if k, err := fsAPI.GetServerKeys(httpReq.Context(), serverName); err == nil {
keys = &k
} else {
return util.ErrorResponse(err)
}
}
if keys == nil {
continue
}
j, err := json.Marshal(keys)
if err != nil {
logrus.WithError(err).Errorf("Failed to marshal %q response", serverName)
return jsonerror.InternalServerError()
}
js, err := gomatrixserverlib.SignJSON(
string(cfg.Matrix.ServerName), cfg.Matrix.KeyID, cfg.Matrix.PrivateKey, j,
)
if err != nil {
logrus.WithError(err).Errorf("Failed to sign %q response", serverName)
return jsonerror.InternalServerError()
}
response.ServerKeys = append(response.ServerKeys, js)
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: response,
}
}

View file

@ -61,6 +61,26 @@ func Setup(
return LocalKeys(cfg) return LocalKeys(cfg)
}) })
notaryKeys := httputil.MakeExternalAPI("notarykeys", func(req *http.Request) util.JSONResponse {
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
if err != nil {
return util.ErrorResponse(err)
}
var pkReq *gomatrixserverlib.PublicKeyNotaryLookupRequest
serverName := gomatrixserverlib.ServerName(vars["serverName"])
keyID := gomatrixserverlib.KeyID(vars["keyID"])
if serverName != "" && keyID != "" {
pkReq = &gomatrixserverlib.PublicKeyNotaryLookupRequest{
ServerKeys: map[gomatrixserverlib.ServerName]map[gomatrixserverlib.KeyID]gomatrixserverlib.PublicKeyNotaryQueryCriteria{
serverName: {
keyID: gomatrixserverlib.PublicKeyNotaryQueryCriteria{},
},
},
}
}
return NotaryKeys(req, cfg, fsAPI, pkReq)
})
// Ignore the {keyID} argument as we only have a single server key so we always // Ignore the {keyID} argument as we only have a single server key so we always
// return that key. // return that key.
// Even if we had more than one server key, we would probably still ignore the // Even if we had more than one server key, we would probably still ignore the
@ -68,6 +88,8 @@ func Setup(
v2keysmux.Handle("/server/{keyID}", localKeys).Methods(http.MethodGet) v2keysmux.Handle("/server/{keyID}", localKeys).Methods(http.MethodGet)
v2keysmux.Handle("/server/", localKeys).Methods(http.MethodGet) v2keysmux.Handle("/server/", localKeys).Methods(http.MethodGet)
v2keysmux.Handle("/server", localKeys).Methods(http.MethodGet) v2keysmux.Handle("/server", localKeys).Methods(http.MethodGet)
v2keysmux.Handle("/query", notaryKeys).Methods(http.MethodPost)
v2keysmux.Handle("/query/{serverName}/{keyID}", notaryKeys).Methods(http.MethodGet)
v1fedmux.Handle("/send/{txnID}", httputil.MakeFedAPI( v1fedmux.Handle("/send/{txnID}", httputil.MakeFedAPI(
"federation_send", cfg.Matrix.ServerName, keys, wakeup, "federation_send", cfg.Matrix.ServerName, keys, wakeup,

View file

@ -372,12 +372,9 @@ func (t *txnReq) processEvent(ctx context.Context, e gomatrixserverlib.Event, is
return t.processEventWithMissingState(ctx, e, stateResp.RoomVersion, isInboundTxn) return t.processEventWithMissingState(ctx, e, stateResp.RoomVersion, isInboundTxn)
} }
// Check that the event is allowed by the state at the event. // pass the event to the roomserver which will do auth checks
if err := checkAllowedByState(e, gomatrixserverlib.UnwrapEventHeaders(stateResp.StateEvents)); err != nil { // If the event fail auth checks, gmsl.NotAllowed error will be returned which we be silently
return err // discarded by the caller of this function
}
// pass the event to the roomserver
return api.SendEvents( return api.SendEvents(
context.Background(), context.Background(),
t.rsAPI, t.rsAPI,

View file

@ -89,12 +89,11 @@ func (t *testRoomserverAPI) InputRoomEvents(
ctx context.Context, ctx context.Context,
request *api.InputRoomEventsRequest, request *api.InputRoomEventsRequest,
response *api.InputRoomEventsResponse, response *api.InputRoomEventsResponse,
) error { ) {
t.inputRoomEvents = append(t.inputRoomEvents, request.InputRoomEvents...) t.inputRoomEvents = append(t.inputRoomEvents, request.InputRoomEvents...)
for _, ire := range request.InputRoomEvents { for _, ire := range request.InputRoomEvents {
fmt.Println("InputRoomEvents: ", ire.Event.EventID()) fmt.Println("InputRoomEvents: ", ire.Event.EventID())
} }
return nil
} }
func (t *testRoomserverAPI) PerformInvite( func (t *testRoomserverAPI) PerformInvite(
@ -461,7 +460,8 @@ func TestBasicTransaction(t *testing.T) {
assertInputRoomEvents(t, rsAPI.inputRoomEvents, []gomatrixserverlib.HeaderedEvent{testEvents[len(testEvents)-1]}) assertInputRoomEvents(t, rsAPI.inputRoomEvents, []gomatrixserverlib.HeaderedEvent{testEvents[len(testEvents)-1]})
} }
// The purpose of this test is to check that if the event received fails auth checks the transaction is failed. // The purpose of this test is to check that if the event received fails auth checks the event is still sent to the roomserver
// as it does the auth check.
func TestTransactionFailAuthChecks(t *testing.T) { func TestTransactionFailAuthChecks(t *testing.T) {
rsAPI := &testRoomserverAPI{ rsAPI := &testRoomserverAPI{
queryStateAfterEvents: func(req *api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse { queryStateAfterEvents: func(req *api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse {
@ -479,11 +479,9 @@ func TestTransactionFailAuthChecks(t *testing.T) {
testData[len(testData)-1], // a message event testData[len(testData)-1], // a message event
} }
txn := mustCreateTransaction(rsAPI, &txnFedClient{}, pdus) txn := mustCreateTransaction(rsAPI, &txnFedClient{}, pdus)
mustProcessTransaction(t, txn, []string{ mustProcessTransaction(t, txn, []string{})
// expect the event to have an error // expect message to be sent to the roomserver
testEvents[len(testEvents)-1].EventID(), assertInputRoomEvents(t, rsAPI.inputRoomEvents, []gomatrixserverlib.HeaderedEvent{testEvents[len(testEvents)-1]})
})
assertInputRoomEvents(t, rsAPI.inputRoomEvents, nil) // expect no messages to be sent to the roomserver
} }
// The purpose of this test is to make sure that when an event is received for which we do not know the prev_events, // The purpose of this test is to make sure that when an event is received for which we do not know the prev_events,

View file

@ -20,6 +20,8 @@ type FederationClient interface {
ClaimKeys(ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (res gomatrixserverlib.RespClaimKeys, err error) ClaimKeys(ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (res gomatrixserverlib.RespClaimKeys, err error)
QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error) QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error)
GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error)
GetServerKeys(ctx context.Context, matrixServer gomatrixserverlib.ServerName) (gomatrixserverlib.ServerKeys, error)
LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error)
} }
// FederationClientError is returned from FederationClient methods in the event of a problem. // FederationClientError is returned from FederationClient methods in the event of a problem.

View file

@ -2,6 +2,7 @@ package internal
import ( import (
"context" "context"
"sync"
"time" "time"
"github.com/matrix-org/dendrite/federationsender/api" "github.com/matrix-org/dendrite/federationsender/api"
@ -23,6 +24,7 @@ type FederationSenderInternalAPI struct {
federation *gomatrixserverlib.FederationClient federation *gomatrixserverlib.FederationClient
keyRing *gomatrixserverlib.KeyRing keyRing *gomatrixserverlib.KeyRing
queues *queue.OutgoingQueues queues *queue.OutgoingQueues
joins sync.Map // joins currently in progress
} }
func NewFederationSenderInternalAPI( func NewFederationSenderInternalAPI(
@ -187,3 +189,27 @@ func (a *FederationSenderInternalAPI) GetEvent(
} }
return ires.(gomatrixserverlib.Transaction), nil return ires.(gomatrixserverlib.Transaction), nil
} }
func (a *FederationSenderInternalAPI) GetServerKeys(
ctx context.Context, s gomatrixserverlib.ServerName,
) (gomatrixserverlib.ServerKeys, error) {
ires, err := a.doRequest(s, func() (interface{}, error) {
return a.federation.GetServerKeys(ctx, s)
})
if err != nil {
return gomatrixserverlib.ServerKeys{}, err
}
return ires.(gomatrixserverlib.ServerKeys), nil
}
func (a *FederationSenderInternalAPI) LookupServerKeys(
ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
) ([]gomatrixserverlib.ServerKeys, error) {
ires, err := a.doRequest(s, func() (interface{}, error) {
return a.federation.LookupServerKeys(ctx, s, keyRequests)
})
if err != nil {
return []gomatrixserverlib.ServerKeys{}, err
}
return ires.([]gomatrixserverlib.ServerKeys), nil
}

View file

@ -37,12 +37,32 @@ func (r *FederationSenderInternalAPI) PerformDirectoryLookup(
return nil return nil
} }
type federatedJoin struct {
UserID string
RoomID string
}
// PerformJoinRequest implements api.FederationSenderInternalAPI // PerformJoinRequest implements api.FederationSenderInternalAPI
func (r *FederationSenderInternalAPI) PerformJoin( func (r *FederationSenderInternalAPI) PerformJoin(
ctx context.Context, ctx context.Context,
request *api.PerformJoinRequest, request *api.PerformJoinRequest,
response *api.PerformJoinResponse, response *api.PerformJoinResponse,
) { ) {
// Check that a join isn't already in progress for this user/room.
j := federatedJoin{request.UserID, request.RoomID}
if _, found := r.joins.Load(j); found {
response.LastError = &gomatrix.HTTPError{
Code: 429,
Message: `{
"errcode": "M_LIMIT_EXCEEDED",
"error": "There is already a federated join to this room in progress. Please wait for it to finish."
}`, // TODO: Why do none of our error types play nicely with each other?
}
return
}
r.joins.Store(j, nil)
defer r.joins.Delete(j)
// Look up the supported room versions. // Look up the supported room versions.
var supportedVersions []gomatrixserverlib.RoomVersion var supportedVersions []gomatrixserverlib.RoomVersion
for version := range version.SupportedRoomVersions() { for version := range version.SupportedRoomVersions() {
@ -98,7 +118,10 @@ func (r *FederationSenderInternalAPI) PerformJoin(
response.LastError = &gomatrix.HTTPError{ response.LastError = &gomatrix.HTTPError{
Code: 0, Code: 0,
WrappedError: nil, WrappedError: nil,
Message: lastErr.Error(), Message: "Unknown HTTP error",
}
if lastErr != nil {
response.LastError.Message = lastErr.Error()
} }
} }
@ -183,27 +206,47 @@ func (r *FederationSenderInternalAPI) performJoinUsingServer(
} }
r.statistics.ForServer(serverName).Success() r.statistics.ForServer(serverName).Success()
// Check that the send_join response was valid. // Process the join response in a goroutine. The idea here is
joinCtx := perform.JoinContext(r.federation, r.keyRing) // that we'll try and wait for as long as possible for the work
respState, err := joinCtx.CheckSendJoinResponse( // to complete, but if the client does give up waiting, we'll
ctx, event, serverName, respMakeJoin, respSendJoin, // still continue to process the join anyway so that we don't
) // waste the effort.
if err != nil { var cancel context.CancelFunc
return fmt.Errorf("joinCtx.CheckSendJoinResponse: %w", err) ctx, cancel = context.WithCancel(context.Background())
} go func() {
defer cancel()
// If we successfully performed a send_join above then the other // Check that the send_join response was valid.
// server now thinks we're a part of the room. Send the newly joinCtx := perform.JoinContext(r.federation, r.keyRing)
// returned state to the roomserver to update our local view. respState, err := joinCtx.CheckSendJoinResponse(
if err = roomserverAPI.SendEventWithState( ctx, event, serverName, respMakeJoin, respSendJoin,
ctx, r.rsAPI, )
respState, if err != nil {
event.Headered(respMakeJoin.RoomVersion), logrus.WithFields(logrus.Fields{
nil, "room_id": roomID,
); err != nil { "user_id": userID,
return fmt.Errorf("r.producer.SendEventWithState: %w", err) }).WithError(err).Error("Failed to process room join response")
} return
}
// If we successfully performed a send_join above then the other
// server now thinks we're a part of the room. Send the newly
// returned state to the roomserver to update our local view.
if err = roomserverAPI.SendEventWithRewrite(
ctx, r.rsAPI,
respState,
event.Headered(respMakeJoin.RoomVersion),
nil,
); err != nil {
logrus.WithFields(logrus.Fields{
"room_id": roomID,
"user_id": userID,
}).WithError(err).Error("Failed to send room join response to roomserver")
return
}
}()
<-ctx.Done()
return nil return nil
} }

View file

@ -23,13 +23,15 @@ const (
FederationSenderPerformServersAlivePath = "/federationsender/performServersAlive" FederationSenderPerformServersAlivePath = "/federationsender/performServersAlive"
FederationSenderPerformBroadcastEDUPath = "/federationsender/performBroadcastEDU" FederationSenderPerformBroadcastEDUPath = "/federationsender/performBroadcastEDU"
FederationSenderGetUserDevicesPath = "/federationsender/client/getUserDevices" FederationSenderGetUserDevicesPath = "/federationsender/client/getUserDevices"
FederationSenderClaimKeysPath = "/federationsender/client/claimKeys" FederationSenderClaimKeysPath = "/federationsender/client/claimKeys"
FederationSenderQueryKeysPath = "/federationsender/client/queryKeys" FederationSenderQueryKeysPath = "/federationsender/client/queryKeys"
FederationSenderBackfillPath = "/federationsender/client/backfill" FederationSenderBackfillPath = "/federationsender/client/backfill"
FederationSenderLookupStatePath = "/federationsender/client/lookupState" FederationSenderLookupStatePath = "/federationsender/client/lookupState"
FederationSenderLookupStateIDsPath = "/federationsender/client/lookupStateIDs" FederationSenderLookupStateIDsPath = "/federationsender/client/lookupStateIDs"
FederationSenderGetEventPath = "/federationsender/client/getEvent" FederationSenderGetEventPath = "/federationsender/client/getEvent"
FederationSenderGetServerKeysPath = "/federationsender/client/getServerKeys"
FederationSenderLookupServerKeysPath = "/federationsender/client/lookupServerKeys"
) )
// NewFederationSenderClient creates a FederationSenderInternalAPI implemented by talking to a HTTP POST API. // NewFederationSenderClient creates a FederationSenderInternalAPI implemented by talking to a HTTP POST API.
@ -358,3 +360,59 @@ func (h *httpFederationSenderInternalAPI) GetEvent(
} }
return *response.Res, nil return *response.Res, nil
} }
type getServerKeys struct {
S gomatrixserverlib.ServerName
ServerKeys gomatrixserverlib.ServerKeys
Err *api.FederationClientError
}
func (h *httpFederationSenderInternalAPI) GetServerKeys(
ctx context.Context, s gomatrixserverlib.ServerName,
) (gomatrixserverlib.ServerKeys, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "GetServerKeys")
defer span.Finish()
request := getServerKeys{
S: s,
}
var response getServerKeys
apiURL := h.federationSenderURL + FederationSenderGetServerKeysPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
if err != nil {
return gomatrixserverlib.ServerKeys{}, err
}
if response.Err != nil {
return gomatrixserverlib.ServerKeys{}, response.Err
}
return response.ServerKeys, nil
}
type lookupServerKeys struct {
S gomatrixserverlib.ServerName
KeyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp
ServerKeys []gomatrixserverlib.ServerKeys
Err *api.FederationClientError
}
func (h *httpFederationSenderInternalAPI) LookupServerKeys(
ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
) ([]gomatrixserverlib.ServerKeys, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "LookupServerKeys")
defer span.Finish()
request := lookupServerKeys{
S: s,
KeyRequests: keyRequests,
}
var response lookupServerKeys
apiURL := h.federationSenderURL + FederationSenderLookupServerKeysPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
if err != nil {
return []gomatrixserverlib.ServerKeys{}, err
}
if response.Err != nil {
return []gomatrixserverlib.ServerKeys{}, response.Err
}
return response.ServerKeys, nil
}

View file

@ -263,4 +263,48 @@ func AddRoutes(intAPI api.FederationSenderInternalAPI, internalAPIMux *mux.Route
return util.JSONResponse{Code: http.StatusOK, JSON: request} return util.JSONResponse{Code: http.StatusOK, JSON: request}
}), }),
) )
internalAPIMux.Handle(
FederationSenderGetServerKeysPath,
httputil.MakeInternalAPI("GetServerKeys", func(req *http.Request) util.JSONResponse {
var request getServerKeys
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
res, err := intAPI.GetServerKeys(req.Context(), request.S)
if err != nil {
ferr, ok := err.(*api.FederationClientError)
if ok {
request.Err = ferr
} else {
request.Err = &api.FederationClientError{
Err: err.Error(),
}
}
}
request.ServerKeys = res
return util.JSONResponse{Code: http.StatusOK, JSON: request}
}),
)
internalAPIMux.Handle(
FederationSenderLookupServerKeysPath,
httputil.MakeInternalAPI("LookupServerKeys", func(req *http.Request) util.JSONResponse {
var request lookupServerKeys
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
res, err := intAPI.LookupServerKeys(req.Context(), request.S, request.KeyRequests)
if err != nil {
ferr, ok := err.(*api.FederationClientError)
if ok {
request.Err = ferr
} else {
request.Err = &api.FederationClientError{
Err: err.Error(),
}
}
}
request.ServerKeys = res
return util.JSONResponse{Code: http.StatusOK, JSON: request}
}),
)
} }

View file

@ -231,13 +231,24 @@ func (oq *destinationQueue) backgroundSend() {
// If we are backing off this server then wait for the // If we are backing off this server then wait for the
// backoff duration to complete first, or until explicitly // backoff duration to complete first, or until explicitly
// told to retry. // told to retry.
if _, giveUp := oq.statistics.BackoffIfRequired(oq.backingOff, oq.interruptBackoff); giveUp { until, blacklisted := oq.statistics.BackoffInfo()
if blacklisted {
// It's been suggested that we should give up because the backoff // It's been suggested that we should give up because the backoff
// has exceeded a maximum allowable value. Clean up the in-memory // has exceeded a maximum allowable value. Clean up the in-memory
// buffers at this point. The PDU clean-up is already on a defer. // buffers at this point. The PDU clean-up is already on a defer.
log.Warnf("Blacklisting %q due to exceeding backoff threshold", oq.destination) log.Warnf("Blacklisting %q due to exceeding backoff threshold", oq.destination)
return return
} }
if until != nil && until.After(time.Now()) {
// We haven't backed off yet, so wait for the suggested amount of
// time.
duration := time.Until(*until)
log.Warnf("Backing off %q for %s", oq.destination, duration)
select {
case <-time.After(duration):
case <-oq.interruptBackoff:
}
}
// If we have pending PDUs or EDUs then construct a transaction. // If we have pending PDUs or EDUs then construct a transaction.
if pendingPDUs || pendingEDUs { if pendingPDUs || pendingEDUs {

View file

@ -44,6 +44,7 @@ func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerS
server = &ServerStatistics{ server = &ServerStatistics{
statistics: s, statistics: s,
serverName: serverName, serverName: serverName,
interrupt: make(chan struct{}),
} }
s.servers[serverName] = server s.servers[serverName] = server
s.mutex.Unlock() s.mutex.Unlock()
@ -68,6 +69,7 @@ type ServerStatistics struct {
backoffStarted atomic.Bool // is the backoff started backoffStarted atomic.Bool // is the backoff started
backoffUntil atomic.Value // time.Time until this backoff interval ends backoffUntil atomic.Value // time.Time until this backoff interval ends
backoffCount atomic.Uint32 // number of times BackoffDuration has been called backoffCount atomic.Uint32 // number of times BackoffDuration has been called
interrupt chan struct{} // interrupts the backoff goroutine
successCounter atomic.Uint32 // how many times have we succeeded? successCounter atomic.Uint32 // how many times have we succeeded?
} }
@ -76,15 +78,24 @@ func (s *ServerStatistics) duration(count uint32) time.Duration {
return time.Second * time.Duration(math.Exp2(float64(count))) return time.Second * time.Duration(math.Exp2(float64(count)))
} }
// cancel will interrupt the currently active backoff.
func (s *ServerStatistics) cancel() {
s.blacklisted.Store(false)
s.backoffUntil.Store(time.Time{})
select {
case s.interrupt <- struct{}{}:
default:
}
}
// Success updates the server statistics with a new successful // Success updates the server statistics with a new successful
// attempt, which increases the sent counter and resets the idle and // attempt, which increases the sent counter and resets the idle and
// failure counters. If a host was blacklisted at this point then // failure counters. If a host was blacklisted at this point then
// we will unblacklist it. // we will unblacklist it.
func (s *ServerStatistics) Success() { func (s *ServerStatistics) Success() {
s.successCounter.Add(1) s.cancel()
s.backoffStarted.Store(false) s.successCounter.Inc()
s.backoffCount.Store(0) s.backoffCount.Store(0)
s.blacklisted.Store(false)
if s.statistics.DB != nil { if s.statistics.DB != nil {
if err := s.statistics.DB.RemoveServerFromBlacklist(s.serverName); err != nil { if err := s.statistics.DB.RemoveServerFromBlacklist(s.serverName); err != nil {
logrus.WithError(err).Errorf("Failed to remove %q from blacklist", s.serverName) logrus.WithError(err).Errorf("Failed to remove %q from blacklist", s.serverName)
@ -99,10 +110,30 @@ func (s *ServerStatistics) Success() {
// whether we have blacklisted and therefore to give up. // whether we have blacklisted and therefore to give up.
func (s *ServerStatistics) Failure() (time.Time, bool) { func (s *ServerStatistics) Failure() (time.Time, bool) {
// If we aren't already backing off, this call will start // If we aren't already backing off, this call will start
// a new backoff period. Reset the counter to 0 so that // a new backoff period. Increase the failure counter and
// we backoff only for short periods of time to start with. // start a goroutine which will wait out the backoff and
// unset the backoffStarted flag when done.
if s.backoffStarted.CAS(false, true) { if s.backoffStarted.CAS(false, true) {
s.backoffCount.Store(0) if s.backoffCount.Inc() >= s.statistics.FailuresUntilBlacklist {
s.blacklisted.Store(true)
if s.statistics.DB != nil {
if err := s.statistics.DB.AddServerToBlacklist(s.serverName); err != nil {
logrus.WithError(err).Errorf("Failed to add %q to blacklist", s.serverName)
}
}
return time.Time{}, true
}
go func() {
until, ok := s.backoffUntil.Load().(time.Time)
if ok {
select {
case <-time.After(time.Until(until)):
case <-s.interrupt:
}
}
s.backoffStarted.Store(false)
}()
} }
// Check if we have blacklisted this node. // Check if we have blacklisted this node.
@ -136,53 +167,6 @@ func (s *ServerStatistics) BackoffInfo() (*time.Time, bool) {
return nil, s.blacklisted.Load() return nil, s.blacklisted.Load()
} }
// BackoffIfRequired will block for as long as the current
// backoff requires, if needed. Otherwise it will do nothing.
// Returns the amount of time to backoff for and whether to give up or not.
func (s *ServerStatistics) BackoffIfRequired(backingOff atomic.Bool, interrupt <-chan bool) (time.Duration, bool) {
if started := s.backoffStarted.Load(); !started {
return 0, false
}
// Work out if we should be blacklisting at this point.
count := s.backoffCount.Inc()
if count >= s.statistics.FailuresUntilBlacklist {
// We've exceeded the maximum amount of times we're willing
// to back off, which is probably in the region of hours by
// now. Mark the host as blacklisted and tell the caller to
// give up.
s.blacklisted.Store(true)
if s.statistics.DB != nil {
if err := s.statistics.DB.AddServerToBlacklist(s.serverName); err != nil {
logrus.WithError(err).Errorf("Failed to add %q to blacklist", s.serverName)
}
}
return 0, true
}
// Work out when we should wait until.
duration := s.duration(count)
until := time.Now().Add(duration)
s.backoffUntil.Store(until)
// Notify the destination queue that we're backing off now.
backingOff.Store(true)
defer backingOff.Store(false)
// Work out how long we should be backing off for.
logrus.Warnf("Backing off %q for %s", s.serverName, duration)
// Wait for either an interruption or for the backoff to
// complete.
select {
case <-interrupt:
logrus.Debugf("Interrupting backoff for %q", s.serverName)
case <-time.After(duration):
}
return duration, false
}
// Blacklisted returns true if the server is blacklisted and false // Blacklisted returns true if the server is blacklisted and false
// otherwise. // otherwise.
func (s *ServerStatistics) Blacklisted() bool { func (s *ServerStatistics) Blacklisted() bool {

View file

@ -4,8 +4,6 @@ import (
"math" "math"
"testing" "testing"
"time" "time"
"go.uber.org/atomic"
) )
func TestBackoff(t *testing.T) { func TestBackoff(t *testing.T) {
@ -27,34 +25,30 @@ func TestBackoff(t *testing.T) {
server.Failure() server.Failure()
t.Logf("Backoff counter: %d", server.backoffCount.Load()) t.Logf("Backoff counter: %d", server.backoffCount.Load())
backingOff := atomic.Bool{}
// Now we're going to simulate backing off a few times to see // Now we're going to simulate backing off a few times to see
// what happens. // what happens.
for i := uint32(1); i <= 10; i++ { for i := uint32(1); i <= 10; i++ {
// Interrupt the backoff - it doesn't really matter if it
// completes but we will find out how long the backoff should
// have been.
interrupt := make(chan bool, 1)
close(interrupt)
// Get the duration.
duration, blacklist := server.BackoffIfRequired(backingOff, interrupt)
// Register another failure for good measure. This should have no // Register another failure for good measure. This should have no
// side effects since a backoff is already in progress. If it does // side effects since a backoff is already in progress. If it does
// then we'll fail. // then we'll fail.
until, blacklisted := server.Failure() until, blacklisted := server.Failure()
if time.Until(until) > duration {
t.Fatal("Failure produced unexpected side effect when it shouldn't have") // Get the duration.
} _, blacklist := server.BackoffInfo()
duration := time.Until(until).Round(time.Second)
// Unset the backoff, or otherwise our next call will think that
// there's a backoff in progress and return the same result.
server.cancel()
server.backoffStarted.Store(false)
// Check if we should be blacklisted by now. // Check if we should be blacklisted by now.
if i >= stats.FailuresUntilBlacklist { if i >= stats.FailuresUntilBlacklist {
if !blacklist { if !blacklist {
t.Fatalf("Backoff %d should have resulted in blacklist but didn't", i) t.Fatalf("Backoff %d should have resulted in blacklist but didn't", i)
} else if blacklist != blacklisted { } else if blacklist != blacklisted {
t.Fatalf("BackoffIfRequired and Failure returned different blacklist values") t.Fatalf("BackoffInfo and Failure returned different blacklist values")
} else { } else {
t.Logf("Backoff %d is blacklisted as expected", i) t.Logf("Backoff %d is blacklisted as expected", i)
continue continue

3
go.mod
View file

@ -1,6 +1,7 @@
module github.com/matrix-org/dendrite module github.com/matrix-org/dendrite
require ( require (
github.com/DATA-DOG/go-sqlmock v1.5.0
github.com/Shopify/sarama v1.27.0 github.com/Shopify/sarama v1.27.0
github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd // indirect github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd // indirect
github.com/gologme/log v1.2.0 github.com/gologme/log v1.2.0
@ -21,7 +22,7 @@ require (
github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4
github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3 github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3
github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd
github.com/matrix-org/gomatrixserverlib v0.0.0-20200907151926-38f437f2b2a6 github.com/matrix-org/gomatrixserverlib v0.0.0-20200922152606-4aa1159e672b
github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91 github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
github.com/mattn/go-sqlite3 v1.14.2 github.com/mattn/go-sqlite3 v1.14.2

6
go.sum
View file

@ -13,6 +13,8 @@ github.com/Arceliar/phony v0.0.0-20191006174943-d0c68492aca0 h1:p3puK8Sl2xK+2Fnn
github.com/Arceliar/phony v0.0.0-20191006174943-d0c68492aca0/go.mod h1:6Lkn+/zJilRMsKmbmG1RPoamiArC6HS73xbwRyp3UyI= github.com/Arceliar/phony v0.0.0-20191006174943-d0c68492aca0/go.mod h1:6Lkn+/zJilRMsKmbmG1RPoamiArC6HS73xbwRyp3UyI=
github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60=
github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM=
github.com/Kubuxu/go-os-helper v0.0.1/go.mod h1:N8B+I7vPCT80IcP58r50u4+gEEcsZETFUpAzWW2ep1Y= github.com/Kubuxu/go-os-helper v0.0.1/go.mod h1:N8B+I7vPCT80IcP58r50u4+gEEcsZETFUpAzWW2ep1Y=
github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU=
github.com/PuerkitoBio/goquery v1.5.1/go.mod h1:GsLWisAFVj4WgDibEWF4pvYnkVQBpKBKeU+7zCJoLcc= github.com/PuerkitoBio/goquery v1.5.1/go.mod h1:GsLWisAFVj4WgDibEWF4pvYnkVQBpKBKeU+7zCJoLcc=
@ -567,8 +569,8 @@ github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 h1:Hr3zjRsq2bh
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd h1:xVrqJK3xHREMNjwjljkAUaadalWc0rRbmVuQatzmgwg= github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd h1:xVrqJK3xHREMNjwjljkAUaadalWc0rRbmVuQatzmgwg=
github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20200907151926-38f437f2b2a6 h1:43gla6bLt4opWY1mQkAasF/LUCipZl7x2d44TY0wf40= github.com/matrix-org/gomatrixserverlib v0.0.0-20200922152606-4aa1159e672b h1:I8H9ftkT1K/OA2urt/dfXAYpO3pOiMQL5bvoWm4i0RA=
github.com/matrix-org/gomatrixserverlib v0.0.0-20200907151926-38f437f2b2a6/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU= github.com/matrix-org/gomatrixserverlib v0.0.0-20200922152606-4aa1159e672b/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU=
github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91 h1:HJ6U3S3ljJqNffYMcIeAncp5qT/i+ZMiJ2JC2F0aXP4= github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91 h1:HJ6U3S3ljJqNffYMcIeAncp5qT/i+ZMiJ2JC2F0aXP4=
github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91/go.mod h1:sjyPyRxKM5uw1nD2cJ6O2OxI6GOqyVBfNXqKjBZTBZE= github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91/go.mod h1:sjyPyRxKM5uw1nD2cJ6O2OxI6GOqyVBfNXqKjBZTBZE=
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 h1:ntrLa/8xVzeSs8vHFHK25k0C+NV74sYMJnNSg5NoSRo= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 h1:ntrLa/8xVzeSs8vHFHK25k0C+NV74sYMJnNSg5NoSRo=

View file

@ -36,6 +36,9 @@ import (
jaegermetrics "github.com/uber/jaeger-lib/metrics" jaegermetrics "github.com/uber/jaeger-lib/metrics"
) )
// keyIDRegexp defines allowable characters in Key IDs.
var keyIDRegexp = regexp.MustCompile("^ed25519:[a-zA-Z0-9_]+$")
// Version is the current version of the config format. // Version is the current version of the config format.
// This will change whenever we make breaking changes to the config format. // This will change whenever we make breaking changes to the config format.
const Version = 1 const Version = 1
@ -459,6 +462,9 @@ func readKeyPEM(path string, data []byte) (gomatrixserverlib.KeyID, ed25519.Priv
if !strings.HasPrefix(keyID, "ed25519:") { if !strings.HasPrefix(keyID, "ed25519:") {
return "", nil, fmt.Errorf("key ID %q doesn't start with \"ed25519:\" in %q", keyID, path) return "", nil, fmt.Errorf("key ID %q doesn't start with \"ed25519:\" in %q", keyID, path)
} }
if !keyIDRegexp.MatchString(keyID) {
return "", nil, fmt.Errorf("key ID %q in %q contains illegal characters (use a-z, A-Z, 0-9 and _ only)", keyID, path)
}
_, privKey, err := ed25519.GenerateKey(bytes.NewReader(keyBlock.Bytes)) _, privKey, err := ed25519.GenerateKey(bytes.NewReader(keyBlock.Bytes))
if err != nil { if err != nil {
return "", nil, err return "", nil, err

View file

@ -20,7 +20,7 @@ type Global struct {
// An arbitrary string used to uniquely identify the PrivateKey. Must start with the // An arbitrary string used to uniquely identify the PrivateKey. Must start with the
// prefix "ed25519:". // prefix "ed25519:".
KeyID gomatrixserverlib.KeyID `yaml:"key_id"` KeyID gomatrixserverlib.KeyID `yaml:"-"`
// How long a remote server can cache our server key for before requesting it again. // How long a remote server can cache our server key for before requesting it again.
// Increasing this number will reduce the number of requests made by remote servers // Increasing this number will reduce the number of requests made by remote servers

View file

@ -15,10 +15,14 @@
package sqlutil package sqlutil
import ( import (
"context"
"database/sql" "database/sql"
"errors" "errors"
"fmt" "fmt"
"runtime" "runtime"
"strings"
"github.com/matrix-org/util"
) )
// ErrUserExists is returned if a username already exists in the database. // ErrUserExists is returned if a username already exists in the database.
@ -107,3 +111,44 @@ func SQLiteDriverName() string {
} }
return "sqlite3" return "sqlite3"
} }
func minOfInts(a, b int) int {
if a <= b {
return a
}
return b
}
// QueryProvider defines the interface for querys used by RunLimitedVariablesQuery.
type QueryProvider interface {
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
}
// SQLite3MaxVariables is the default maximum number of host parameters in a single SQL statement
// SQLlite can handle. See https://www.sqlite.org/limits.html for more information.
const SQLite3MaxVariables = 999
// RunLimitedVariablesQuery split up a query with more variables than the used database can handle in multiple queries.
func RunLimitedVariablesQuery(ctx context.Context, query string, qp QueryProvider, variables []interface{}, limit uint, rowHandler func(*sql.Rows) error) error {
var start int
for start < len(variables) {
n := minOfInts(len(variables)-start, int(limit))
nextQuery := strings.Replace(query, "($1)", QueryVariadic(n), 1)
rows, err := qp.QueryContext(ctx, nextQuery, variables[start:start+n]...)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("QueryContext returned an error")
return err
}
err = rowHandler(rows)
if closeErr := rows.Close(); closeErr != nil {
util.GetLogger(ctx).WithError(closeErr).Error("RunLimitedVariablesQuery: failed to close rows")
return err
}
if err != nil {
util.GetLogger(ctx).WithError(err).Error("RunLimitedVariablesQuery: rowHandler returned error")
return err
}
start = start + n
}
return nil
}

View file

@ -0,0 +1,173 @@
package sqlutil
import (
"context"
"database/sql"
"reflect"
"testing"
sqlmock "github.com/DATA-DOG/go-sqlmock"
)
func TestShouldReturnCorrectAmountOfResulstIfFewerVariablesThanLimit(t *testing.T) {
db, mock, err := sqlmock.New()
assertNoError(t, err, "Failed to make DB")
limit := uint(4)
r := mock.NewRows([]string{"id"}).
AddRow(1).
AddRow(2).
AddRow(3)
mock.ExpectQuery(`SELECT id WHERE id IN \(\$1, \$2, \$3\)`).WillReturnRows(r)
// nolint:goconst
q := "SELECT id WHERE id IN ($1)"
v := []int{1, 2, 3}
iKeyIDs := make([]interface{}, len(v))
for i, d := range v {
iKeyIDs[i] = d
}
ctx := context.Background()
var result = make([]int, 0)
err = RunLimitedVariablesQuery(ctx, q, db, iKeyIDs, limit, func(rows *sql.Rows) error {
for rows.Next() {
var id int
err = rows.Scan(&id)
assertNoError(t, err, "rows.Scan returned an error")
result = append(result, id)
}
return nil
})
assertNoError(t, err, "Call returned an error")
if len(result) != len(v) {
t.Fatalf("Result should be 3 long")
}
}
func TestShouldReturnCorrectAmountOfResulstIfEqualVariablesAsLimit(t *testing.T) {
db, mock, err := sqlmock.New()
assertNoError(t, err, "Failed to make DB")
limit := uint(4)
r := mock.NewRows([]string{"id"}).
AddRow(1).
AddRow(2).
AddRow(3).
AddRow(4)
mock.ExpectQuery(`SELECT id WHERE id IN \(\$1, \$2, \$3, \$4\)`).WillReturnRows(r)
// nolint:goconst
q := "SELECT id WHERE id IN ($1)"
v := []int{1, 2, 3, 4}
iKeyIDs := make([]interface{}, len(v))
for i, d := range v {
iKeyIDs[i] = d
}
ctx := context.Background()
var result = make([]int, 0)
err = RunLimitedVariablesQuery(ctx, q, db, iKeyIDs, limit, func(rows *sql.Rows) error {
for rows.Next() {
var id int
err = rows.Scan(&id)
assertNoError(t, err, "rows.Scan returned an error")
result = append(result, id)
}
return nil
})
assertNoError(t, err, "Call returned an error")
if len(result) != len(v) {
t.Fatalf("Result should be 4 long")
}
}
func TestShouldReturnCorrectAmountOfResultsIfMoreVariablesThanLimit(t *testing.T) {
db, mock, err := sqlmock.New()
assertNoError(t, err, "Failed to make DB")
limit := uint(4)
r1 := mock.NewRows([]string{"id"}).
AddRow(1).
AddRow(2).
AddRow(3).
AddRow(4)
r2 := mock.NewRows([]string{"id"}).
AddRow(5)
mock.ExpectQuery(`SELECT id WHERE id IN \(\$1, \$2, \$3, \$4\)`).WillReturnRows(r1)
mock.ExpectQuery(`SELECT id WHERE id IN \(\$1\)`).WillReturnRows(r2)
// nolint:goconst
q := "SELECT id WHERE id IN ($1)"
v := []int{1, 2, 3, 4, 5}
iKeyIDs := make([]interface{}, len(v))
for i, d := range v {
iKeyIDs[i] = d
}
ctx := context.Background()
var result = make([]int, 0)
err = RunLimitedVariablesQuery(ctx, q, db, iKeyIDs, limit, func(rows *sql.Rows) error {
for rows.Next() {
var id int
err = rows.Scan(&id)
assertNoError(t, err, "rows.Scan returned an error")
result = append(result, id)
}
return nil
})
assertNoError(t, err, "Call returned an error")
if len(result) != len(v) {
t.Fatalf("Result should be 5 long")
}
if !reflect.DeepEqual(v, result) {
t.Fatalf("Result is not as expected: got %v want %v", v, result)
}
}
func TestShouldReturnErrorIfRowsScanReturnsError(t *testing.T) {
db, mock, err := sqlmock.New()
assertNoError(t, err, "Failed to make DB")
limit := uint(4)
// adding a string ID should result in rows.Scan returning an error
r := mock.NewRows([]string{"id"}).
AddRow("hej").
AddRow(2).
AddRow(3)
mock.ExpectQuery(`SELECT id WHERE id IN \(\$1, \$2, \$3\)`).WillReturnRows(r)
// nolint:goconst
q := "SELECT id WHERE id IN ($1)"
v := []int{-1, -2, 3}
iKeyIDs := make([]interface{}, len(v))
for i, d := range v {
iKeyIDs[i] = d
}
ctx := context.Background()
var result = make([]uint, 0)
err = RunLimitedVariablesQuery(ctx, q, db, iKeyIDs, limit, func(rows *sql.Rows) error {
for rows.Next() {
var id uint
err = rows.Scan(&id)
if err != nil {
return err
}
result = append(result, id)
}
return nil
})
if err == nil {
t.Fatalf("Call did not return an error")
}
}
func assertNoError(t *testing.T, err error, msg string) {
t.Helper()
if err == nil {
return
}
t.Fatalf(msg)
}

View file

@ -25,6 +25,7 @@ import (
"math/big" "math/big"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"time" "time"
"github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/internal/config"
@ -146,10 +147,14 @@ func NewMatrixKey(matrixKeyPath string) (err error) {
err = keyOut.Close() err = keyOut.Close()
})() })()
keyID := base64.RawURLEncoding.EncodeToString(data[:])
keyID = strings.ReplaceAll(keyID, "-", "")
keyID = strings.ReplaceAll(keyID, "_", "")
err = pem.Encode(keyOut, &pem.Block{ err = pem.Encode(keyOut, &pem.Block{
Type: "MATRIX PRIVATE KEY", Type: "MATRIX PRIVATE KEY",
Headers: map[string]string{ Headers: map[string]string{
"Key-ID": "ed25519:" + base64.RawStdEncoding.EncodeToString(data[:3]), "Key-ID": fmt.Sprintf("ed25519:%s", keyID[:6]),
}, },
Bytes: data[3:], Bytes: data[3:],
}) })

View file

@ -16,7 +16,7 @@ type RoomserverInternalAPI interface {
ctx context.Context, ctx context.Context,
request *InputRoomEventsRequest, request *InputRoomEventsRequest,
response *InputRoomEventsResponse, response *InputRoomEventsResponse,
) error )
PerformInvite( PerformInvite(
ctx context.Context, ctx context.Context,

View file

@ -23,10 +23,9 @@ func (t *RoomserverInternalAPITrace) InputRoomEvents(
ctx context.Context, ctx context.Context,
req *InputRoomEventsRequest, req *InputRoomEventsRequest,
res *InputRoomEventsResponse, res *InputRoomEventsResponse,
) error { ) {
err := t.Impl.InputRoomEvents(ctx, req, res) t.Impl.InputRoomEvents(ctx, req, res)
util.GetLogger(ctx).WithError(err).Infof("InputRoomEvents req=%+v res=%+v", js(req), js(res)) util.GetLogger(ctx).Infof("InputRoomEvents req=%+v res=%+v", js(req), js(res))
return err
} }
func (t *RoomserverInternalAPITrace) PerformInvite( func (t *RoomserverInternalAPITrace) PerformInvite(

View file

@ -16,6 +16,8 @@
package api package api
import ( import (
"fmt"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -33,6 +35,10 @@ const (
// KindBackfill event extend the contiguous graph going backwards. // KindBackfill event extend the contiguous graph going backwards.
// They always have state. // They always have state.
KindBackfill = 3 KindBackfill = 3
// KindRewrite events are used when rewriting the head of the room
// graph with entirely new state. The output events generated will
// be state events rather than timeline events.
KindRewrite = 4
) )
// DoNotSendToOtherServers tells us not to send the event to other matrix // DoNotSendToOtherServers tells us not to send the event to other matrix
@ -83,4 +89,18 @@ type InputRoomEventsRequest struct {
// InputRoomEventsResponse is a response to InputRoomEvents // InputRoomEventsResponse is a response to InputRoomEvents
type InputRoomEventsResponse struct { type InputRoomEventsResponse struct {
ErrMsg string // set if there was any error
NotAllowed bool // true if an event in the input was not allowed.
}
func (r *InputRoomEventsResponse) Err() error {
if r.ErrMsg == "" {
return nil
}
if r.NotAllowed {
return &gomatrixserverlib.NotAllowed{
Message: r.ErrMsg,
}
}
return fmt.Errorf("InputRoomEventsResponse: %s", r.ErrMsg)
} }

View file

@ -68,6 +68,17 @@ type OutputEvent struct {
NewPeek *OutputNewPeek `json:"new_peek,omitempty"` NewPeek *OutputNewPeek `json:"new_peek,omitempty"`
} }
// Type of the OutputNewRoomEvent.
type OutputRoomEventType int
const (
// The event is a timeline event and likely just happened.
OutputRoomTimeline OutputRoomEventType = iota
// The event is a state event and quite possibly happened in the past.
OutputRoomState
)
// An OutputNewRoomEvent is written when the roomserver receives a new event. // An OutputNewRoomEvent is written when the roomserver receives a new event.
// It contains the full matrix room event and enough information for a // It contains the full matrix room event and enough information for a
// consumer to construct the current state of the room and the state before the // consumer to construct the current state of the room and the state before the
@ -80,6 +91,9 @@ type OutputEvent struct {
type OutputNewRoomEvent struct { type OutputNewRoomEvent struct {
// The Event. // The Event.
Event gomatrixserverlib.HeaderedEvent `json:"event"` Event gomatrixserverlib.HeaderedEvent `json:"event"`
// Does the event completely rewrite the room state? If so, then AddsStateEventIDs
// will contain the entire room state.
RewritesState bool `json:"rewrites_state"`
// The latest events in the room after this event. // The latest events in the room after this event.
// This can be used to set the prev events for new events in the room. // This can be used to set the prev events for new events in the room.
// This also can be used to get the full current state after this event. // This also can be used to get the full current state after this event.

View file

@ -80,13 +80,107 @@ func SendEventWithState(
return SendInputRoomEvents(ctx, rsAPI, ires) return SendInputRoomEvents(ctx, rsAPI, ires)
} }
// SendEventWithRewrite writes an event with KindNew to the roomserver along
// with a number of rewrite and outlier events for state and auth events
// respectively.
func SendEventWithRewrite(
ctx context.Context, rsAPI RoomserverInternalAPI, state *gomatrixserverlib.RespState,
event gomatrixserverlib.HeaderedEvent, haveEventIDs map[string]bool,
) error {
isCurrentState := map[string]struct{}{}
for _, se := range state.StateEvents {
isCurrentState[se.EventID()] = struct{}{}
}
authAndStateEvents, err := state.Events()
if err != nil {
return err
}
var ires []InputRoomEvent
var stateIDs []string
// This function generates three things:
// A - A set of "rewrite" events, which will form the newly rewritten
// state before the event, which includes every rewrite event that
// came before it in its state
// B - A set of "outlier" events, which are auth events but not part
// of the rewritten state
// C - A "new" event, which include all of the rewrite events in its
// state
for _, authOrStateEvent := range authAndStateEvents {
if authOrStateEvent.StateKey() == nil {
continue
}
if haveEventIDs[authOrStateEvent.EventID()] {
continue
}
if event.StateKey() == nil {
continue
}
// We will handle an event as if it's an outlier if one of the
// following conditions is true:
storeAsOutlier := false
if _, ok := isCurrentState[authOrStateEvent.EventID()]; !ok {
// The event is an auth event and isn't a part of the state set.
// We'll send it as an outlier because we need it to be stored
// in case something is referring to it as an auth event.
storeAsOutlier = true
}
if storeAsOutlier {
ires = append(ires, InputRoomEvent{
Kind: KindOutlier,
Event: authOrStateEvent.Headered(event.RoomVersion),
AuthEventIDs: authOrStateEvent.AuthEventIDs(),
})
continue
}
// If the event isn't an outlier then we'll instead send it as a
// rewrite event, so that it'll form part of the rewritten state.
// These events will go through the membership and latest event
// updaters and we will generate output events, but they will be
// flagged as non-current (i.e. didn't just happen) events.
// Each of these rewrite events includes all of the rewrite events
// that came before in their StateEventIDs.
ires = append(ires, InputRoomEvent{
Kind: KindRewrite,
Event: authOrStateEvent.Headered(event.RoomVersion),
AuthEventIDs: authOrStateEvent.AuthEventIDs(),
HasState: true,
StateEventIDs: stateIDs,
})
// Add the event ID into the StateEventIDs of all subsequent
// rewrite events, and the new event.
stateIDs = append(stateIDs, authOrStateEvent.EventID())
}
// Send the final event as a new event, which will generate
// a timeline output event for it. All of the rewrite events
// that came before will be sent as StateEventIDs, forming a
// new clean state before the event.
ires = append(ires, InputRoomEvent{
Kind: KindNew,
Event: event,
AuthEventIDs: event.AuthEventIDs(),
HasState: true,
StateEventIDs: stateIDs,
})
return SendInputRoomEvents(ctx, rsAPI, ires)
}
// SendInputRoomEvents to the roomserver. // SendInputRoomEvents to the roomserver.
func SendInputRoomEvents( func SendInputRoomEvents(
ctx context.Context, rsAPI RoomserverInternalAPI, ires []InputRoomEvent, ctx context.Context, rsAPI RoomserverInternalAPI, ires []InputRoomEvent,
) error { ) error {
request := InputRoomEventsRequest{InputRoomEvents: ires} request := InputRoomEventsRequest{InputRoomEvents: ires}
var response InputRoomEventsResponse var response InputRoomEventsResponse
return rsAPI.InputRoomEvents(ctx, &request, &response) rsAPI.InputRoomEvents(ctx, &request, &response)
return response.Err()
} }
// SendInvite event to the roomserver. // SendInvite event to the roomserver.

View file

@ -271,5 +271,6 @@ func (r *RoomserverInternalAPI) sendUpdatedAliasesEvent(
var inputRes api.InputRoomEventsResponse var inputRes api.InputRoomEventsResponse
// Send the request // Send the request
return r.InputRoomEvents(ctx, &inputReq, &inputRes) r.InputRoomEvents(ctx, &inputReq, &inputRes)
return inputRes.Err()
} }

View file

@ -16,13 +16,78 @@ package helpers
import ( import (
"context" "context"
"fmt"
"sort" "sort"
"github.com/matrix-org/dendrite/roomserver/state"
"github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
// CheckForSoftFail returns true if the event should be soft-failed
// and false otherwise. The return error value should be checked before
// the soft-fail bool.
func CheckForSoftFail(
ctx context.Context,
db storage.Database,
event gomatrixserverlib.HeaderedEvent,
stateEventIDs []string,
) (bool, error) {
rewritesState := len(stateEventIDs) > 1
var authStateEntries []types.StateEntry
var err error
if rewritesState {
authStateEntries, err = db.StateEntriesForEventIDs(ctx, stateEventIDs)
if err != nil {
return true, fmt.Errorf("StateEntriesForEventIDs failed: %w", err)
}
} else {
// Work out if the room exists.
var roomInfo *types.RoomInfo
roomInfo, err = db.RoomInfo(ctx, event.RoomID())
if err != nil {
return false, fmt.Errorf("db.RoomNID: %w", err)
}
if roomInfo == nil || roomInfo.IsStub {
return false, nil
}
// Then get the state entries for the current state snapshot.
// We'll use this to check if the event is allowed right now.
roomState := state.NewStateResolution(db, *roomInfo)
authStateEntries, err = roomState.LoadStateAtSnapshot(ctx, roomInfo.StateSnapshotNID)
if err != nil {
return true, fmt.Errorf("roomState.LoadStateAtSnapshot: %w", err)
}
}
// As a special case, it's possible that the room will have no
// state because we haven't received a m.room.create event yet.
// If we're now processing the first create event then never
// soft-fail it.
if len(authStateEntries) == 0 && event.Type() == gomatrixserverlib.MRoomCreate {
return false, nil
}
// Work out which of the state events we actually need.
stateNeeded := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{event.Unwrap()})
// Load the actual auth events from the database.
authEvents, err := loadAuthEvents(ctx, db, stateNeeded, authStateEntries)
if err != nil {
return true, fmt.Errorf("loadAuthEvents: %w", err)
}
// Check if the event is allowed.
if err = gomatrixserverlib.Allowed(event.Event, &authEvents); err != nil {
// return true, nil
return true, fmt.Errorf("gomatrixserverlib.Allowed: %w", err)
}
return false, nil
}
// CheckAuthEvents checks that the event passes authentication checks // CheckAuthEvents checks that the event passes authentication checks
// Returns the numeric IDs for the auth events. // Returns the numeric IDs for the auth events.
func CheckAuthEvents( func CheckAuthEvents(
@ -36,7 +101,7 @@ func CheckAuthEvents(
if err != nil { if err != nil {
return nil, err return nil, err
} }
// TODO: check for duplicate state keys here. authStateEntries = types.DeduplicateStateEntries(authStateEntries)
// Work out which of the state events we actually need. // Work out which of the state events we actually need.
stateNeeded := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{event.Unwrap()}) stateNeeded := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{event.Unwrap()})

View file

@ -110,7 +110,7 @@ func (r *Inputer) InputRoomEvents(
ctx context.Context, ctx context.Context,
request *api.InputRoomEventsRequest, request *api.InputRoomEventsRequest,
response *api.InputRoomEventsResponse, response *api.InputRoomEventsResponse,
) error { ) {
// Create a wait group. Each task that we dispatch will call Done on // Create a wait group. Each task that we dispatch will call Done on
// this wait group so that we know when all of our events have been // this wait group so that we know when all of our events have been
// processed. // processed.
@ -156,8 +156,10 @@ func (r *Inputer) InputRoomEvents(
// that back to the caller. // that back to the caller.
for _, task := range tasks { for _, task := range tasks {
if task.err != nil { if task.err != nil {
return task.err response.ErrMsg = task.err.Error()
_, rejected := task.err.(*gomatrixserverlib.NotAllowed)
response.NotAllowed = rejected
return
} }
} }
return nil
} }

View file

@ -46,10 +46,25 @@ func (r *Inputer) processRoomEvent(
// Check that the event passes authentication checks and work out // Check that the event passes authentication checks and work out
// the numeric IDs for the auth events. // the numeric IDs for the auth events.
authEventNIDs, err := helpers.CheckAuthEvents(ctx, r.DB, headered, input.AuthEventIDs) isRejected := false
if err != nil { authEventNIDs, rejectionErr := helpers.CheckAuthEvents(ctx, r.DB, headered, input.AuthEventIDs)
logrus.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", input.AuthEventIDs).Error("processRoomEvent.checkAuthEvents failed for event") if rejectionErr != nil {
return logrus.WithError(rejectionErr).WithField("event_id", event.EventID()).WithField("auth_event_ids", input.AuthEventIDs).Error("processRoomEvent.checkAuthEvents failed for event, rejecting event")
isRejected = true
}
var softfail bool
if input.Kind == api.KindBackfill || input.Kind == api.KindNew {
// Check that the event passes authentication checks based on the
// current room state.
softfail, err = helpers.CheckForSoftFail(ctx, r.DB, headered, input.StateEventIDs)
if err != nil {
logrus.WithFields(logrus.Fields{
"event_id": event.EventID(),
"type": event.Type(),
"room": event.RoomID(),
}).WithError(err).Info("Error authing soft-failed event")
}
} }
// If we don't have a transaction ID then get one. // If we don't have a transaction ID then get one.
@ -65,12 +80,13 @@ func (r *Inputer) processRoomEvent(
} }
// Store the event. // Store the event.
_, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, input.TransactionID, authEventNIDs) _, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, input.TransactionID, authEventNIDs, isRejected)
if err != nil { if err != nil {
return "", fmt.Errorf("r.DB.StoreEvent: %w", err) return "", fmt.Errorf("r.DB.StoreEvent: %w", err)
} }
// if storing this event results in it being redacted then do so. // if storing this event results in it being redacted then do so.
if redactedEventID == event.EventID() { if !isRejected && redactedEventID == event.EventID() {
r, rerr := eventutil.RedactEvent(redactionEvent, &event) r, rerr := eventutil.RedactEvent(redactionEvent, &event)
if rerr != nil { if rerr != nil {
return "", fmt.Errorf("eventutil.RedactEvent: %w", rerr) return "", fmt.Errorf("eventutil.RedactEvent: %w", rerr)
@ -86,7 +102,8 @@ func (r *Inputer) processRoomEvent(
"event_id": event.EventID(), "event_id": event.EventID(),
"type": event.Type(), "type": event.Type(),
"room": event.RoomID(), "room": event.RoomID(),
}).Info("Stored outlier") "sender": event.Sender(),
}).Debug("Stored outlier")
return event.EventID(), nil return event.EventID(), nil
} }
@ -101,12 +118,33 @@ func (r *Inputer) processRoomEvent(
if stateAtEvent.BeforeStateSnapshotNID == 0 { if stateAtEvent.BeforeStateSnapshotNID == 0 {
// We haven't calculated a state for this event yet. // We haven't calculated a state for this event yet.
// Lets calculate one. // Lets calculate one.
err = r.calculateAndSetState(ctx, input, *roomInfo, &stateAtEvent, event) err = r.calculateAndSetState(ctx, input, *roomInfo, &stateAtEvent, event, isRejected)
if err != nil { if err != nil {
return "", fmt.Errorf("r.calculateAndSetState: %w", err) return "", fmt.Errorf("r.calculateAndSetState: %w", err)
} }
} }
// We stop here if the event is rejected: We've stored it but won't update forward extremities or notify anyone about it.
if isRejected || softfail {
logrus.WithFields(logrus.Fields{
"event_id": event.EventID(),
"type": event.Type(),
"room": event.RoomID(),
"soft_fail": softfail,
"sender": event.Sender(),
}).Debug("Stored rejected event")
return event.EventID(), rejectionErr
}
if input.Kind == api.KindRewrite {
logrus.WithFields(logrus.Fields{
"event_id": event.EventID(),
"type": event.Type(),
"room": event.RoomID(),
}).Debug("Stored rewrite")
return event.EventID(), nil
}
if err = r.updateLatestEvents( if err = r.updateLatestEvents(
ctx, // context ctx, // context
roomInfo, // room info for the room being updated roomInfo, // room info for the room being updated
@ -114,6 +152,7 @@ func (r *Inputer) processRoomEvent(
event, // event event, // event
input.SendAsServer, // send as server input.SendAsServer, // send as server
input.TransactionID, // transaction ID input.TransactionID, // transaction ID
input.HasState, // rewrites state?
); err != nil { ); err != nil {
return "", fmt.Errorf("r.updateLatestEvents: %w", err) return "", fmt.Errorf("r.updateLatestEvents: %w", err)
} }
@ -147,11 +186,12 @@ func (r *Inputer) calculateAndSetState(
roomInfo types.RoomInfo, roomInfo types.RoomInfo,
stateAtEvent *types.StateAtEvent, stateAtEvent *types.StateAtEvent,
event gomatrixserverlib.Event, event gomatrixserverlib.Event,
isRejected bool,
) error { ) error {
var err error var err error
roomState := state.NewStateResolution(r.DB, roomInfo) roomState := state.NewStateResolution(r.DB, roomInfo)
if input.HasState { if input.HasState && !isRejected {
// Check here if we think we're in the room already. // Check here if we think we're in the room already.
stateAtEvent.Overwrite = true stateAtEvent.Overwrite = true
var joinEventNIDs []types.EventNID var joinEventNIDs []types.EventNID
@ -167,19 +207,25 @@ func (r *Inputer) calculateAndSetState(
// Check that those state events are in the database and store the state. // Check that those state events are in the database and store the state.
var entries []types.StateEntry var entries []types.StateEntry
if entries, err = r.DB.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil { if entries, err = r.DB.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil {
return err return fmt.Errorf("r.DB.StateEntriesForEventIDs: %w", err)
} }
entries = types.DeduplicateStateEntries(entries)
if stateAtEvent.BeforeStateSnapshotNID, err = r.DB.AddState(ctx, roomInfo.RoomNID, nil, entries); err != nil { if stateAtEvent.BeforeStateSnapshotNID, err = r.DB.AddState(ctx, roomInfo.RoomNID, nil, entries); err != nil {
return err return fmt.Errorf("r.DB.AddState: %w", err)
} }
} else { } else {
stateAtEvent.Overwrite = false stateAtEvent.Overwrite = false
// We haven't been told what the state at the event is so we need to calculate it from the prev_events // We haven't been told what the state at the event is so we need to calculate it from the prev_events
if stateAtEvent.BeforeStateSnapshotNID, err = roomState.CalculateAndStoreStateBeforeEvent(ctx, event); err != nil { if stateAtEvent.BeforeStateSnapshotNID, err = roomState.CalculateAndStoreStateBeforeEvent(ctx, event, isRejected); err != nil {
return err return fmt.Errorf("roomState.CalculateAndStoreStateBeforeEvent: %w", err)
} }
} }
return r.DB.SetState(ctx, stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID)
err = r.DB.SetState(ctx, stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID)
if err != nil {
return fmt.Errorf("r.DB.SetState: %w", err)
}
return nil
} }

View file

@ -54,6 +54,7 @@ func (r *Inputer) updateLatestEvents(
event gomatrixserverlib.Event, event gomatrixserverlib.Event,
sendAsServer string, sendAsServer string,
transactionID *api.TransactionID, transactionID *api.TransactionID,
rewritesState bool,
) (err error) { ) (err error) {
updater, err := r.DB.GetLatestEventsForUpdate(ctx, *roomInfo) updater, err := r.DB.GetLatestEventsForUpdate(ctx, *roomInfo)
if err != nil { if err != nil {
@ -71,6 +72,7 @@ func (r *Inputer) updateLatestEvents(
event: event, event: event,
sendAsServer: sendAsServer, sendAsServer: sendAsServer,
transactionID: transactionID, transactionID: transactionID,
rewritesState: rewritesState,
} }
if err = u.doUpdateLatestEvents(); err != nil { if err = u.doUpdateLatestEvents(); err != nil {
@ -93,6 +95,7 @@ type latestEventsUpdater struct {
stateAtEvent types.StateAtEvent stateAtEvent types.StateAtEvent
event gomatrixserverlib.Event event gomatrixserverlib.Event
transactionID *api.TransactionID transactionID *api.TransactionID
rewritesState bool
// Which server to send this event as. // Which server to send this event as.
sendAsServer string sendAsServer string
// The eventID of the event that was processed before this one. // The eventID of the event that was processed before this one.
@ -178,7 +181,8 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
return fmt.Errorf("u.api.updateMemberships: %w", err) return fmt.Errorf("u.api.updateMemberships: %w", err)
} }
update, err := u.makeOutputNewRoomEvent() var update *api.OutputEvent
update, err = u.makeOutputNewRoomEvent()
if err != nil { if err != nil {
return fmt.Errorf("u.makeOutputNewRoomEvent: %w", err) return fmt.Errorf("u.makeOutputNewRoomEvent: %w", err)
} }
@ -305,6 +309,7 @@ func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error)
ore := api.OutputNewRoomEvent{ ore := api.OutputNewRoomEvent{
Event: u.event.Headered(u.roomInfo.RoomVersion), Event: u.event.Headered(u.roomInfo.RoomVersion),
RewritesState: u.rewritesState,
LastSentEventID: u.lastEventIDSent, LastSentEventID: u.lastEventIDSent,
LatestEventIDs: latestEventIDs, LatestEventIDs: latestEventIDs,
TransactionID: u.transactionID, TransactionID: u.transactionID,
@ -337,6 +342,11 @@ func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error)
return nil, fmt.Errorf("failed to load add_state_events from db: %w", err) return nil, fmt.Errorf("failed to load add_state_events from db: %w", err)
} }
} }
// State is rewritten if the input room event HasState and we actually produced a delta on state events.
// Without this check, /get_missing_events which produce events with associated (but not complete) state
// will incorrectly purge the room and set it to no state. TODO: This is likely flakey, as if /gme produced
// a state conflict res which just so happens to include 2+ events we might purge the room state downstream.
ore.RewritesState = len(ore.AddsStateEventIDs) > 1
return &api.OutputEvent{ return &api.OutputEvent{
Type: api.OutputTypeNewRoomEvent, Type: api.OutputTypeNewRoomEvent,

View file

@ -547,7 +547,7 @@ func persistEvents(ctx context.Context, db storage.Database, events []gomatrixse
var stateAtEvent types.StateAtEvent var stateAtEvent types.StateAtEvent
var redactedEventID string var redactedEventID string
var redactionEvent *gomatrixserverlib.Event var redactionEvent *gomatrixserverlib.Event
roomNID, stateAtEvent, redactionEvent, redactedEventID, err = db.StoreEvent(ctx, ev.Unwrap(), nil, authNids) roomNID, stateAtEvent, redactionEvent, redactedEventID, err = db.StoreEvent(ctx, ev.Unwrap(), nil, authNids, false)
if err != nil { if err != nil {
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to persist event") logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to persist event")
continue continue

View file

@ -183,7 +183,8 @@ func (r *Inviter) PerformInvite(
}, },
} }
inputRes := &api.InputRoomEventsResponse{} inputRes := &api.InputRoomEventsResponse{}
if err = r.Inputer.InputRoomEvents(context.Background(), inputReq, inputRes); err != nil { r.Inputer.InputRoomEvents(context.Background(), inputReq, inputRes)
if err = inputRes.Err(); err != nil {
return nil, fmt.Errorf("r.InputRoomEvents: %w", err) return nil, fmt.Errorf("r.InputRoomEvents: %w", err)
} }
} else { } else {

View file

@ -183,33 +183,33 @@ func (r *Joiner) performJoinRoomByID(
return "", fmt.Errorf("eb.SetContent: %w", err) return "", fmt.Errorf("eb.SetContent: %w", err)
} }
// First work out if this is in response to an existing invite // Force a federated join if we aren't in the room and we've been
// from a federated server. If it is then we avoid the situation // given some server names to try joining by.
// where we might think we know about a room in the following
// section but don't know the latest state as all of our users
// have left.
serverInRoom, _ := helpers.IsServerCurrentlyInRoom(ctx, r.DB, r.ServerName, req.RoomIDOrAlias) serverInRoom, _ := helpers.IsServerCurrentlyInRoom(ctx, r.DB, r.ServerName, req.RoomIDOrAlias)
forceFederatedJoin := len(req.ServerNames) > 0 && !serverInRoom
// Force a federated join if we're dealing with a pending invite
// and we aren't in the room.
isInvitePending, inviteSender, _, err := helpers.IsInvitePending(ctx, r.DB, req.RoomIDOrAlias, req.UserID) isInvitePending, inviteSender, _, err := helpers.IsInvitePending(ctx, r.DB, req.RoomIDOrAlias, req.UserID)
if err == nil && isInvitePending && !serverInRoom { if err == nil && isInvitePending {
// Check if there's an invite pending.
_, inviterDomain, ierr := gomatrixserverlib.SplitID('@', inviteSender) _, inviterDomain, ierr := gomatrixserverlib.SplitID('@', inviteSender)
if ierr != nil { if ierr != nil {
return "", fmt.Errorf("gomatrixserverlib.SplitID: %w", err) return "", fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
} }
// Check that the domain isn't ours. If it's local then we don't // If we were invited by someone from another server then we can
// need to do anything as our own copy of the room state will be // assume they are in the room so we can join via them.
// up-to-date.
if inviterDomain != r.Cfg.Matrix.ServerName { if inviterDomain != r.Cfg.Matrix.ServerName {
// Add the server of the person who invited us to the server list,
// as they should be a fairly good bet.
req.ServerNames = append(req.ServerNames, inviterDomain) req.ServerNames = append(req.ServerNames, inviterDomain)
forceFederatedJoin = true
// Perform a federated room join.
return req.RoomIDOrAlias, r.performFederatedJoinRoomByID(ctx, req)
} }
} }
// If we should do a forced federated join then do that.
if forceFederatedJoin {
return req.RoomIDOrAlias, r.performFederatedJoinRoomByID(ctx, req)
}
// Try to construct an actual join event from the template. // Try to construct an actual join event from the template.
// If this succeeds then it is a sign that the room already exists // If this succeeds then it is a sign that the room already exists
// locally on the homeserver. // locally on the homeserver.
@ -247,7 +247,8 @@ func (r *Joiner) performJoinRoomByID(
}, },
} }
inputRes := api.InputRoomEventsResponse{} inputRes := api.InputRoomEventsResponse{}
if err = r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes); err != nil { r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes)
if err = inputRes.Err(); err != nil {
var notAllowed *gomatrixserverlib.NotAllowed var notAllowed *gomatrixserverlib.NotAllowed
if errors.As(err, &notAllowed) { if errors.As(err, &notAllowed) {
return "", &api.PerformError{ return "", &api.PerformError{

View file

@ -139,7 +139,8 @@ func (r *Leaver) performLeaveRoomByID(
}, },
} }
inputRes := api.InputRoomEventsResponse{} inputRes := api.InputRoomEventsResponse{}
if err = r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes); err != nil { r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes)
if err = inputRes.Err(); err != nil {
return nil, fmt.Errorf("r.InputRoomEvents: %w", err) return nil, fmt.Errorf("r.InputRoomEvents: %w", err)
} }

View file

@ -70,6 +70,7 @@ func (r *Queryer) QueryStateAfterEvents(
if err != nil { if err != nil {
switch err.(type) { switch err.(type) {
case types.MissingEventError: case types.MissingEventError:
util.GetLogger(ctx).Errorf("QueryStateAfterEvents: MissingEventError: %s", err)
return nil return nil
default: default:
return err return err

View file

@ -149,12 +149,15 @@ func (h *httpRoomserverInternalAPI) InputRoomEvents(
ctx context.Context, ctx context.Context,
request *api.InputRoomEventsRequest, request *api.InputRoomEventsRequest,
response *api.InputRoomEventsResponse, response *api.InputRoomEventsResponse,
) error { ) {
span, ctx := opentracing.StartSpanFromContext(ctx, "InputRoomEvents") span, ctx := opentracing.StartSpanFromContext(ctx, "InputRoomEvents")
defer span.Finish() defer span.Finish()
apiURL := h.roomserverURL + RoomserverInputRoomEventsPath apiURL := h.roomserverURL + RoomserverInputRoomEventsPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
response.ErrMsg = err.Error()
}
} }
func (h *httpRoomserverInternalAPI) PerformInvite( func (h *httpRoomserverInternalAPI) PerformInvite(

View file

@ -20,9 +20,7 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) {
if err := json.NewDecoder(req.Body).Decode(&request); err != nil { if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error()) return util.MessageResponse(http.StatusBadRequest, err.Error())
} }
if err := r.InputRoomEvents(req.Context(), &request, &response); err != nil { r.InputRoomEvents(req.Context(), &request, &response)
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response} return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}), }),
) )

View file

@ -1,12 +1,15 @@
package roomserver package roomserver
import ( import (
"bytes"
"context" "context"
"crypto/ed25519"
"encoding/json" "encoding/json"
"fmt" "fmt"
"os" "os"
"reflect" "reflect"
"testing" "testing"
"time"
"github.com/Shopify/sarama" "github.com/Shopify/sarama"
"github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/caching"
@ -80,7 +83,65 @@ func deleteDatabase() {
} }
} }
func mustLoadEvents(t *testing.T, ver gomatrixserverlib.RoomVersion, events []json.RawMessage) []gomatrixserverlib.HeaderedEvent { type fledglingEvent struct {
Type string
StateKey *string
Content interface{}
Sender string
RoomID string
}
func mustCreateEvents(t *testing.T, roomVer gomatrixserverlib.RoomVersion, events []fledglingEvent) (result []gomatrixserverlib.HeaderedEvent) {
t.Helper()
depth := int64(1)
seed := make([]byte, ed25519.SeedSize) // zero seed
key := ed25519.NewKeyFromSeed(seed)
var prevs []string
roomState := make(map[gomatrixserverlib.StateKeyTuple]string) // state -> event ID
for _, ev := range events {
eb := gomatrixserverlib.EventBuilder{
Sender: ev.Sender,
Depth: depth,
Type: ev.Type,
StateKey: ev.StateKey,
RoomID: ev.RoomID,
PrevEvents: prevs,
}
err := eb.SetContent(ev.Content)
if err != nil {
t.Fatalf("mustCreateEvent: failed to marshal event content %+v", ev.Content)
}
stateNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(&eb)
if err != nil {
t.Fatalf("mustCreateEvent: failed to work out auth_events : %s", err)
}
var authEvents []string
for _, tuple := range stateNeeded.Tuples() {
eventID := roomState[tuple]
if eventID != "" {
authEvents = append(authEvents, eventID)
}
}
eb.AuthEvents = authEvents
signedEvent, err := eb.Build(time.Now(), testOrigin, "ed25519:test", key, roomVer)
if err != nil {
t.Fatalf("mustCreateEvent: failed to sign event: %s", err)
}
depth++
prevs = []string{signedEvent.EventID()}
if ev.StateKey != nil {
roomState[gomatrixserverlib.StateKeyTuple{
EventType: ev.Type,
StateKey: *ev.StateKey,
}] = signedEvent.EventID()
}
result = append(result, signedEvent.Headered(roomVer))
}
return
}
func mustLoadRawEvents(t *testing.T, ver gomatrixserverlib.RoomVersion, events []json.RawMessage) []gomatrixserverlib.HeaderedEvent {
t.Helper()
hs := make([]gomatrixserverlib.HeaderedEvent, len(events)) hs := make([]gomatrixserverlib.HeaderedEvent, len(events))
for i := range events { for i := range events {
e, err := gomatrixserverlib.NewEventFromTrustedJSON(events[i], false, ver) e, err := gomatrixserverlib.NewEventFromTrustedJSON(events[i], false, ver)
@ -93,7 +154,8 @@ func mustLoadEvents(t *testing.T, ver gomatrixserverlib.RoomVersion, events []js
return hs return hs
} }
func mustSendEvents(t *testing.T, ver gomatrixserverlib.RoomVersion, events []json.RawMessage) (api.RoomserverInternalAPI, *dummyProducer, []gomatrixserverlib.HeaderedEvent) { func mustCreateRoomserverAPI(t *testing.T) (api.RoomserverInternalAPI, *dummyProducer) {
t.Helper()
cfg := &config.Dendrite{} cfg := &config.Dendrite{}
cfg.Defaults() cfg.Defaults()
cfg.Global.ServerName = testOrigin cfg.Global.ServerName = testOrigin
@ -102,7 +164,7 @@ func mustSendEvents(t *testing.T, ver gomatrixserverlib.RoomVersion, events []js
dp := &dummyProducer{ dp := &dummyProducer{
topic: cfg.Global.Kafka.TopicFor(config.TopicOutputRoomEvent), topic: cfg.Global.Kafka.TopicFor(config.TopicOutputRoomEvent),
} }
cache, err := caching.NewInMemoryLRUCache(true) cache, err := caching.NewInMemoryLRUCache(false)
if err != nil { if err != nil {
t.Fatalf("failed to make caches: %s", err) t.Fatalf("failed to make caches: %s", err)
} }
@ -112,9 +174,14 @@ func mustSendEvents(t *testing.T, ver gomatrixserverlib.RoomVersion, events []js
Cfg: cfg, Cfg: cfg,
} }
rsAPI := NewInternalAPI(base, &test.NopJSONVerifier{}) return NewInternalAPI(base, &test.NopJSONVerifier{}), dp
hevents := mustLoadEvents(t, ver, events) }
if err = api.SendEvents(ctx, rsAPI, hevents, testOrigin, nil); err != nil {
func mustSendEvents(t *testing.T, ver gomatrixserverlib.RoomVersion, events []json.RawMessage) (api.RoomserverInternalAPI, *dummyProducer, []gomatrixserverlib.HeaderedEvent) {
t.Helper()
rsAPI, dp := mustCreateRoomserverAPI(t)
hevents := mustLoadRawEvents(t, ver, events)
if err := api.SendEvents(ctx, rsAPI, hevents, testOrigin, nil); err != nil {
t.Errorf("failed to SendEvents: %s", err) t.Errorf("failed to SendEvents: %s", err)
} }
return rsAPI, dp, hevents return rsAPI, dp, hevents
@ -170,3 +237,163 @@ func TestOutputRedactedEvent(t *testing.T) {
} }
} }
} }
// This tests that rewriting state via KindRewrite works correctly.
// This creates a small room with a create/join/name state, then replays it
// with a new room name. We expect the output events to contain the original events,
// followed by a single OutputNewRoomEvent with RewritesState set to true with the
// rewritten state events (with the 2nd room name).
func TestOutputRewritesState(t *testing.T) {
roomID := "!foo:" + string(testOrigin)
alice := "@alice:" + string(testOrigin)
emptyKey := ""
originalEvents := mustCreateEvents(t, gomatrixserverlib.RoomVersionV6, []fledglingEvent{
{
RoomID: roomID,
Sender: alice,
Content: map[string]interface{}{
"creator": alice,
"room_version": "6",
},
StateKey: &emptyKey,
Type: gomatrixserverlib.MRoomCreate,
},
{
RoomID: roomID,
Sender: alice,
Content: map[string]interface{}{
"membership": "join",
},
StateKey: &alice,
Type: gomatrixserverlib.MRoomMember,
},
{
RoomID: roomID,
Sender: alice,
Content: map[string]interface{}{
"body": "hello world",
},
StateKey: nil,
Type: "m.room.message",
},
{
RoomID: roomID,
Sender: alice,
Content: map[string]interface{}{
"name": "Room Name",
},
StateKey: &emptyKey,
Type: "m.room.name",
},
})
rewriteEvents := mustCreateEvents(t, gomatrixserverlib.RoomVersionV6, []fledglingEvent{
{
RoomID: roomID,
Sender: alice,
Content: map[string]interface{}{
"creator": alice,
},
StateKey: &emptyKey,
Type: gomatrixserverlib.MRoomCreate,
},
{
RoomID: roomID,
Sender: alice,
Content: map[string]interface{}{
"membership": "join",
},
StateKey: &alice,
Type: gomatrixserverlib.MRoomMember,
},
{
RoomID: roomID,
Sender: alice,
Content: map[string]interface{}{
"name": "Room Name 2",
},
StateKey: &emptyKey,
Type: "m.room.name",
},
{
RoomID: roomID,
Sender: alice,
Content: map[string]interface{}{
"body": "hello world 2",
},
StateKey: nil,
Type: "m.room.message",
},
})
deleteDatabase()
rsAPI, producer := mustCreateRoomserverAPI(t)
defer deleteDatabase()
err := api.SendEvents(context.Background(), rsAPI, originalEvents, testOrigin, nil)
if err != nil {
t.Fatalf("failed to send original events: %s", err)
}
// assert we got them produced, this is just a sanity check and isn't the intention of this test
if len(producer.producedMessages) != len(originalEvents) {
t.Fatalf("SendEvents didn't result in same number of produced output events: got %d want %d", len(producer.producedMessages), len(originalEvents))
}
producer.producedMessages = nil // we aren't actually interested in these events, just the rewrite ones
var inputEvents []api.InputRoomEvent
// slowly build up the state IDs again, we're basically telling the roomserver what to store as a snapshot
var stateIDs []string
// skip the last event, we'll use this to tie together the rewrite as the KindNew event
for i := 0; i < len(rewriteEvents)-1; i++ {
ev := rewriteEvents[i]
inputEvents = append(inputEvents, api.InputRoomEvent{
Kind: api.KindRewrite,
Event: ev,
AuthEventIDs: ev.AuthEventIDs(),
HasState: true,
StateEventIDs: stateIDs,
})
if ev.StateKey() != nil {
stateIDs = append(stateIDs, ev.EventID())
}
}
lastEv := rewriteEvents[len(rewriteEvents)-1]
inputEvents = append(inputEvents, api.InputRoomEvent{
Kind: api.KindNew,
Event: lastEv,
AuthEventIDs: lastEv.AuthEventIDs(),
HasState: true,
StateEventIDs: stateIDs,
})
if err := api.SendInputRoomEvents(context.Background(), rsAPI, inputEvents); err != nil {
t.Fatalf("SendInputRoomEvents returned error for rewrite events: %s", err)
}
// we should just have one output event with the entire state of the room in it
if len(producer.producedMessages) != 1 {
t.Fatalf("Rewritten events got output, want only 1 got %d", len(producer.producedMessages))
}
outputEvent := producer.producedMessages[0]
if !outputEvent.NewRoomEvent.RewritesState {
t.Errorf("RewritesState flag not set on output event")
}
if !reflect.DeepEqual(stateIDs, outputEvent.NewRoomEvent.AddsStateEventIDs) {
t.Errorf("Output event is missing room state event IDs, got %v want %v", outputEvent.NewRoomEvent.AddsStateEventIDs, stateIDs)
}
if !bytes.Equal(outputEvent.NewRoomEvent.Event.JSON(), lastEv.JSON()) {
t.Errorf(
"Output event isn't the latest KindNew event:\ngot %s\nwant %s",
string(outputEvent.NewRoomEvent.Event.JSON()),
string(lastEv.JSON()),
)
}
if len(outputEvent.NewRoomEvent.AddStateEvents) != len(stateIDs) {
t.Errorf("Output event is missing room state events themselves, got %d want %d", len(outputEvent.NewRoomEvent.AddStateEvents), len(stateIDs))
}
// make sure the state got overwritten, check the room name
hasRoomName := false
for _, ev := range outputEvent.NewRoomEvent.AddStateEvents {
if ev.Type() == "m.room.name" {
hasRoomName = string(ev.Content()) == `{"name":"Room Name 2"}`
}
}
if !hasRoomName {
t.Errorf("Output event did not overwrite room state")
}
}

View file

@ -159,7 +159,7 @@ func (v StateResolution) LoadCombinedStateAfterEvents(
} }
fullState = append(fullState, entries...) fullState = append(fullState, entries...)
} }
if prevState.IsStateEvent() { if prevState.IsStateEvent() && !prevState.IsRejected {
// If the prev event was a state event then add an entry for the event itself // If the prev event was a state event then add an entry for the event itself
// so that we get the state after the event rather than the state before. // so that we get the state after the event rather than the state before.
fullState = append(fullState, prevState.StateEntry) fullState = append(fullState, prevState.StateEntry)
@ -523,6 +523,7 @@ func init() {
func (v StateResolution) CalculateAndStoreStateBeforeEvent( func (v StateResolution) CalculateAndStoreStateBeforeEvent(
ctx context.Context, ctx context.Context,
event gomatrixserverlib.Event, event gomatrixserverlib.Event,
isRejected bool,
) (types.StateSnapshotNID, error) { ) (types.StateSnapshotNID, error) {
// Load the state at the prev events. // Load the state at the prev events.
prevEventRefs := event.PrevEvents() prevEventRefs := event.PrevEvents()
@ -561,7 +562,7 @@ func (v StateResolution) CalculateAndStoreStateAfterEvents(
if len(prevStates) == 1 { if len(prevStates) == 1 {
prevState := prevStates[0] prevState := prevStates[0]
if prevState.EventStateKeyNID == 0 { if prevState.EventStateKeyNID == 0 || prevState.IsRejected {
// 3) None of the previous events were state events and they all // 3) None of the previous events were state events and they all
// have the same state, so this event has exactly the same state // have the same state, so this event has exactly the same state
// as the previous events. // as the previous events.

View file

@ -70,6 +70,7 @@ type Database interface {
// Stores a matrix room event in the database. Returns the room NID, the state snapshot and the redacted event ID if any, or an error. // Stores a matrix room event in the database. Returns the room NID, the state snapshot and the redacted event ID if any, or an error.
StoreEvent( StoreEvent(
ctx context.Context, event gomatrixserverlib.Event, txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID, ctx context.Context, event gomatrixserverlib.Event, txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID,
isRejected bool,
) (types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) ) (types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error)
// Look up the state entries for a list of string event IDs // Look up the state entries for a list of string event IDs
// Returns an error if the there is an error talking to the database // Returns an error if the there is an error talking to the database

View file

@ -65,13 +65,14 @@ CREATE TABLE IF NOT EXISTS roomserver_events (
-- Needed for setting reference hashes when sending new events. -- Needed for setting reference hashes when sending new events.
reference_sha256 BYTEA NOT NULL, reference_sha256 BYTEA NOT NULL,
-- A list of numeric IDs for events that can authenticate this event. -- A list of numeric IDs for events that can authenticate this event.
auth_event_nids BIGINT[] NOT NULL auth_event_nids BIGINT[] NOT NULL,
is_rejected BOOLEAN NOT NULL DEFAULT FALSE
); );
` `
const insertEventSQL = "" + const insertEventSQL = "" +
"INSERT INTO roomserver_events (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids, depth)" + "INSERT INTO roomserver_events (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids, depth, is_rejected)" +
" VALUES ($1, $2, $3, $4, $5, $6, $7)" + " VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" +
" ON CONFLICT ON CONSTRAINT roomserver_event_id_unique" + " ON CONFLICT ON CONSTRAINT roomserver_event_id_unique" +
" DO NOTHING" + " DO NOTHING" +
" RETURNING event_nid, state_snapshot_nid" " RETURNING event_nid, state_snapshot_nid"
@ -88,7 +89,7 @@ const bulkSelectStateEventByIDSQL = "" +
" ORDER BY event_type_nid, event_state_key_nid ASC" " ORDER BY event_type_nid, event_state_key_nid ASC"
const bulkSelectStateAtEventByIDSQL = "" + const bulkSelectStateAtEventByIDSQL = "" +
"SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid FROM roomserver_events" + "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, is_rejected FROM roomserver_events" +
" WHERE event_id = ANY($1)" " WHERE event_id = ANY($1)"
const updateEventStateSQL = "" + const updateEventStateSQL = "" +
@ -174,12 +175,14 @@ func (s *eventStatements) InsertEvent(
referenceSHA256 []byte, referenceSHA256 []byte,
authEventNIDs []types.EventNID, authEventNIDs []types.EventNID,
depth int64, depth int64,
isRejected bool,
) (types.EventNID, types.StateSnapshotNID, error) { ) (types.EventNID, types.StateSnapshotNID, error) {
var eventNID int64 var eventNID int64
var stateNID int64 var stateNID int64
err := s.insertEventStmt.QueryRowContext( err := s.insertEventStmt.QueryRowContext(
ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID),
eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth, eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth,
isRejected,
).Scan(&eventNID, &stateNID) ).Scan(&eventNID, &stateNID)
return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err
} }
@ -255,6 +258,7 @@ func (s *eventStatements) BulkSelectStateAtEventByID(
&result.EventStateKeyNID, &result.EventStateKeyNID,
&result.EventNID, &result.EventNID,
&result.BeforeStateSnapshotNID, &result.BeforeStateSnapshotNID,
&result.IsRejected,
); err != nil { ); err != nil {
return nil, err return nil, err
} }

View file

@ -320,9 +320,14 @@ func (d *Database) Events(
if err != nil { if err != nil {
return nil, err return nil, err
} }
roomVersion, err = d.RoomsTable.SelectRoomVersionForRoomNID(ctx, roomNID) if roomID, ok := d.Cache.GetRoomServerRoomID(roomNID); ok {
if err != nil { roomVersion, _ = d.Cache.GetRoomVersion(roomID)
return nil, err }
if roomVersion == "" {
roomVersion, err = d.RoomsTable.SelectRoomVersionForRoomNID(ctx, roomNID)
if err != nil {
return nil, err
}
} }
result.Event, err = gomatrixserverlib.NewEventFromTrustedJSON( result.Event, err = gomatrixserverlib.NewEventFromTrustedJSON(
eventJSON.EventJSON, false, roomVersion, eventJSON.EventJSON, false, roomVersion,
@ -382,7 +387,7 @@ func (d *Database) GetLatestEventsForUpdate(
// nolint:gocyclo // nolint:gocyclo
func (d *Database) StoreEvent( func (d *Database) StoreEvent(
ctx context.Context, event gomatrixserverlib.Event, ctx context.Context, event gomatrixserverlib.Event,
txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID, txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID, isRejected bool,
) (types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) { ) (types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) {
var ( var (
roomNID types.RoomNID roomNID types.RoomNID
@ -446,6 +451,7 @@ func (d *Database) StoreEvent(
event.EventReference().EventSHA256, event.EventReference().EventSHA256,
authEventNIDs, authEventNIDs,
event.Depth(), event.Depth(),
isRejected,
); err != nil { ); err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
// We've already inserted the event so select the numeric event ID // We've already inserted the event so select the numeric event ID
@ -459,7 +465,9 @@ func (d *Database) StoreEvent(
if err = d.EventJSONTable.InsertEventJSON(ctx, txn, eventNID, event.JSON()); err != nil { if err = d.EventJSONTable.InsertEventJSON(ctx, txn, eventNID, event.JSON()); err != nil {
return fmt.Errorf("d.EventJSONTable.InsertEventJSON: %w", err) return fmt.Errorf("d.EventJSONTable.InsertEventJSON: %w", err)
} }
redactionEvent, redactedEventID, err = d.handleRedactions(ctx, txn, eventNID, event) if !isRejected { // ignore rejected redaction events
redactionEvent, redactedEventID, err = d.handleRedactions(ctx, txn, eventNID, event)
}
return nil return nil
}) })
if err != nil { if err != nil {

View file

@ -41,13 +41,14 @@ const eventsSchema = `
depth INTEGER NOT NULL, depth INTEGER NOT NULL,
event_id TEXT NOT NULL UNIQUE, event_id TEXT NOT NULL UNIQUE,
reference_sha256 BLOB NOT NULL, reference_sha256 BLOB NOT NULL,
auth_event_nids TEXT NOT NULL DEFAULT '[]' auth_event_nids TEXT NOT NULL DEFAULT '[]',
is_rejected BOOLEAN NOT NULL DEFAULT FALSE
); );
` `
const insertEventSQL = ` const insertEventSQL = `
INSERT INTO roomserver_events (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids, depth) INSERT INTO roomserver_events (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids, depth, is_rejected)
VALUES ($1, $2, $3, $4, $5, $6, $7) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
ON CONFLICT DO NOTHING; ON CONFLICT DO NOTHING;
` `
@ -63,7 +64,7 @@ const bulkSelectStateEventByIDSQL = "" +
" ORDER BY event_type_nid, event_state_key_nid ASC" " ORDER BY event_type_nid, event_state_key_nid ASC"
const bulkSelectStateAtEventByIDSQL = "" + const bulkSelectStateAtEventByIDSQL = "" +
"SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid FROM roomserver_events" + "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, is_rejected FROM roomserver_events" +
" WHERE event_id IN ($1)" " WHERE event_id IN ($1)"
const updateEventStateSQL = "" + const updateEventStateSQL = "" +
@ -150,13 +151,14 @@ func (s *eventStatements) InsertEvent(
referenceSHA256 []byte, referenceSHA256 []byte,
authEventNIDs []types.EventNID, authEventNIDs []types.EventNID,
depth int64, depth int64,
isRejected bool,
) (types.EventNID, types.StateSnapshotNID, error) { ) (types.EventNID, types.StateSnapshotNID, error) {
// attempt to insert: the last_row_id is the event NID // attempt to insert: the last_row_id is the event NID
var eventNID int64 var eventNID int64
insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt) insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt)
result, err := insertStmt.ExecContext( result, err := insertStmt.ExecContext(
ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID),
eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth, eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth, isRejected,
) )
if err != nil { if err != nil {
return 0, 0, err return 0, 0, err
@ -261,6 +263,7 @@ func (s *eventStatements) BulkSelectStateAtEventByID(
&result.EventStateKeyNID, &result.EventStateKeyNID,
&result.EventNID, &result.EventNID,
&result.BeforeStateSnapshotNID, &result.BeforeStateSnapshotNID,
&result.IsRejected,
); err != nil { ); err != nil {
return nil, err return nil, err
} }

View file

@ -18,6 +18,8 @@ package sqlite3
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt"
"strings"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/shared"
@ -25,10 +27,15 @@ import (
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
) )
// TODO: previous_reference_sha256 was NOT NULL before but it broke sytest because
// sytest sends no SHA256 sums in the prev_events references in the soft-fail tests.
// In Postgres an empty BYTEA field is not NULL so it's fine there. In SQLite it
// seems to care that it's empty and therefore hits a NOT NULL constraint on insert.
// We should really work out what the right thing to do here is.
const previousEventSchema = ` const previousEventSchema = `
CREATE TABLE IF NOT EXISTS roomserver_previous_events ( CREATE TABLE IF NOT EXISTS roomserver_previous_events (
previous_event_id TEXT NOT NULL, previous_event_id TEXT NOT NULL,
previous_reference_sha256 BLOB NOT NULL, previous_reference_sha256 BLOB,
event_nids TEXT NOT NULL, event_nids TEXT NOT NULL,
UNIQUE (previous_event_id, previous_reference_sha256) UNIQUE (previous_event_id, previous_reference_sha256)
); );
@ -45,6 +52,11 @@ const insertPreviousEventSQL = `
VALUES ($1, $2, $3) VALUES ($1, $2, $3)
` `
const selectPreviousEventNIDsSQL = `
SELECT event_nids FROM roomserver_previous_events
WHERE previous_event_id = $1 AND previous_reference_sha256 = $2
`
// Check if the event is referenced by another event in the table. // Check if the event is referenced by another event in the table.
// This should only be done while holding a "FOR UPDATE" lock on the row in the rooms table for this room. // This should only be done while holding a "FOR UPDATE" lock on the row in the rooms table for this room.
const selectPreviousEventExistsSQL = ` const selectPreviousEventExistsSQL = `
@ -55,6 +67,7 @@ const selectPreviousEventExistsSQL = `
type previousEventStatements struct { type previousEventStatements struct {
db *sql.DB db *sql.DB
insertPreviousEventStmt *sql.Stmt insertPreviousEventStmt *sql.Stmt
selectPreviousEventNIDsStmt *sql.Stmt
selectPreviousEventExistsStmt *sql.Stmt selectPreviousEventExistsStmt *sql.Stmt
} }
@ -69,6 +82,7 @@ func NewSqlitePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) {
return s, shared.StatementList{ return s, shared.StatementList{
{&s.insertPreviousEventStmt, insertPreviousEventSQL}, {&s.insertPreviousEventStmt, insertPreviousEventSQL},
{&s.selectPreviousEventNIDsStmt, selectPreviousEventNIDsSQL},
{&s.selectPreviousEventExistsStmt, selectPreviousEventExistsSQL}, {&s.selectPreviousEventExistsStmt, selectPreviousEventExistsSQL},
}.Prepare(db) }.Prepare(db)
} }
@ -80,9 +94,28 @@ func (s *previousEventStatements) InsertPreviousEvent(
previousEventReferenceSHA256 []byte, previousEventReferenceSHA256 []byte,
eventNID types.EventNID, eventNID types.EventNID,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.insertPreviousEventStmt) var eventNIDs string
_, err := stmt.ExecContext( eventNIDAsString := fmt.Sprintf("%d", eventNID)
ctx, previousEventID, previousEventReferenceSHA256, int64(eventNID), selectStmt := sqlutil.TxStmt(txn, s.selectPreviousEventExistsStmt)
err := selectStmt.QueryRowContext(ctx, previousEventID, previousEventReferenceSHA256).Scan(&eventNIDs)
if err != sql.ErrNoRows {
return fmt.Errorf("selectStmt.QueryRowContext.Scan: %w", err)
}
var nids []string
if eventNIDs != "" {
nids = strings.Split(eventNIDs, ",")
for _, nid := range nids {
if nid == eventNIDAsString {
return nil
}
}
eventNIDs = strings.Join(append(nids, eventNIDAsString), ",")
} else {
eventNIDs = eventNIDAsString
}
insertStmt := sqlutil.TxStmt(txn, s.insertPreviousEventStmt)
_, err = insertStmt.ExecContext(
ctx, previousEventID, previousEventReferenceSHA256, eventNIDs,
) )
return err return err
} }

View file

@ -34,7 +34,10 @@ type EventStateKeys interface {
} }
type Events interface { type Events interface {
InsertEvent(c context.Context, txn *sql.Tx, i types.RoomNID, j types.EventTypeNID, k types.EventStateKeyNID, eventID string, referenceSHA256 []byte, authEventNIDs []types.EventNID, depth int64) (types.EventNID, types.StateSnapshotNID, error) InsertEvent(
ctx context.Context, txn *sql.Tx, i types.RoomNID, j types.EventTypeNID, k types.EventStateKeyNID, eventID string,
referenceSHA256 []byte, authEventNIDs []types.EventNID, depth int64, isRejected bool,
) (types.EventNID, types.StateSnapshotNID, error)
SelectEvent(ctx context.Context, txn *sql.Tx, eventID string) (types.EventNID, types.StateSnapshotNID, error) SelectEvent(ctx context.Context, txn *sql.Tx, eventID string) (types.EventNID, types.StateSnapshotNID, error)
// bulkSelectStateEventByID lookups a list of state events by event ID. // bulkSelectStateEventByID lookups a list of state events by event ID.
// If any of the requested events are missing from the database it returns a types.MissingEventError // If any of the requested events are missing from the database it returns a types.MissingEventError

View file

@ -16,6 +16,8 @@
package types package types
import ( import (
"sort"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -72,6 +74,25 @@ func (a StateEntry) LessThan(b StateEntry) bool {
return a.EventNID < b.EventNID return a.EventNID < b.EventNID
} }
// Deduplicate takes a set of state entries and ensures that there are no
// duplicate (event type, state key) tuples. If there are then we dedupe
// them, making sure that the latest/highest NIDs are always chosen.
func DeduplicateStateEntries(a []StateEntry) []StateEntry {
if len(a) < 2 {
return a
}
sort.SliceStable(a, func(i, j int) bool {
return a[i].LessThan(a[j])
})
for i := 0; i < len(a)-1; i++ {
if a[i].StateKeyTuple == a[i+1].StateKeyTuple {
a = append(a[:i], a[i+1:]...)
i--
}
}
return a
}
// StateAtEvent is the state before and after a matrix event. // StateAtEvent is the state before and after a matrix event.
type StateAtEvent struct { type StateAtEvent struct {
// Should this state overwrite the latest events and memberships of the room? // Should this state overwrite the latest events and memberships of the room?
@ -80,6 +101,9 @@ type StateAtEvent struct {
Overwrite bool Overwrite bool
// The state before the event. // The state before the event.
BeforeStateSnapshotNID StateSnapshotNID BeforeStateSnapshotNID StateSnapshotNID
// True if this StateEntry is rejected. State resolution should then treat this
// StateEntry as being a message event (not a state event).
IsRejected bool
// The state entry for the event itself, allows us to calculate the state after the event. // The state entry for the event itself, allows us to calculate the state after the event.
StateEntry StateEntry
} }

View file

@ -0,0 +1,26 @@
package types
import (
"testing"
)
func TestDeduplicateStateEntries(t *testing.T) {
entries := []StateEntry{
{StateKeyTuple{1, 1}, 1},
{StateKeyTuple{1, 1}, 2},
{StateKeyTuple{1, 1}, 3},
{StateKeyTuple{2, 2}, 4},
{StateKeyTuple{2, 3}, 5},
{StateKeyTuple{3, 3}, 6},
}
expected := []EventNID{3, 4, 5, 6}
entries = DeduplicateStateEntries(entries)
if len(entries) != 4 {
t.Fatalf("Expected 4 entries, got %d entries", len(entries))
}
for i, v := range entries {
if v.EventNID != expected[i] {
t.Fatalf("Expected position %d to be %d but got %d", i, expected[i], v.EventNID)
}
}
}

View file

@ -20,7 +20,7 @@ type ServerKeyAPI struct {
ServerKeyValidity time.Duration ServerKeyValidity time.Duration
OurKeyRing gomatrixserverlib.KeyRing OurKeyRing gomatrixserverlib.KeyRing
FedClient *gomatrixserverlib.FederationClient FedClient gomatrixserverlib.KeyClient
} }
func (s *ServerKeyAPI) KeyRing() *gomatrixserverlib.KeyRing { func (s *ServerKeyAPI) KeyRing() *gomatrixserverlib.KeyRing {

View file

@ -26,7 +26,7 @@ func AddInternalRoutes(router *mux.Router, intAPI api.ServerKeyInternalAPI, cach
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
func NewInternalAPI( func NewInternalAPI(
cfg *config.ServerKeyAPI, cfg *config.ServerKeyAPI,
fedClient *gomatrixserverlib.FederationClient, fedClient gomatrixserverlib.KeyClient,
caches *caching.Caches, caches *caching.Caches,
) api.ServerKeyInternalAPI { ) api.ServerKeyInternalAPI {
innerDB, err := storage.NewDatabase( innerDB, err := storage.NewDatabase(
@ -53,7 +53,7 @@ func NewInternalAPI(
OurKeyRing: gomatrixserverlib.KeyRing{ OurKeyRing: gomatrixserverlib.KeyRing{
KeyFetchers: []gomatrixserverlib.KeyFetcher{ KeyFetchers: []gomatrixserverlib.KeyFetcher{
&gomatrixserverlib.DirectKeyFetcher{ &gomatrixserverlib.DirectKeyFetcher{
Client: fedClient.Client, Client: fedClient,
}, },
}, },
KeyDatabase: serverKeyDB, KeyDatabase: serverKeyDB,
@ -65,7 +65,7 @@ func NewInternalAPI(
perspective := &gomatrixserverlib.PerspectiveKeyFetcher{ perspective := &gomatrixserverlib.PerspectiveKeyFetcher{
PerspectiveServerName: ps.ServerName, PerspectiveServerName: ps.ServerName,
PerspectiveServerKeys: map[gomatrixserverlib.KeyID]ed25519.PublicKey{}, PerspectiveServerKeys: map[gomatrixserverlib.KeyID]ed25519.PublicKey{},
Client: fedClient.Client, Client: fedClient,
} }
for _, key := range ps.Keys { for _, key := range ps.Keys {

View file

@ -18,9 +18,8 @@ package sqlite3
import ( import (
"context" "context"
"database/sql" "database/sql"
"strings" "fmt"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -88,48 +87,50 @@ func (s *serverKeyStatements) bulkSelectServerKeys(
ctx context.Context, ctx context.Context,
requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) { ) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) {
var nameAndKeyIDs []string nameAndKeyIDs := make([]string, 0, len(requests))
for request := range requests { for request := range requests {
nameAndKeyIDs = append(nameAndKeyIDs, nameAndKeyID(request)) nameAndKeyIDs = append(nameAndKeyIDs, nameAndKeyID(request))
} }
results := make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, len(requests))
query := strings.Replace(bulkSelectServerKeysSQL, "($1)", sqlutil.QueryVariadic(len(nameAndKeyIDs)), 1)
iKeyIDs := make([]interface{}, len(nameAndKeyIDs)) iKeyIDs := make([]interface{}, len(nameAndKeyIDs))
for i, v := range nameAndKeyIDs { for i, v := range nameAndKeyIDs {
iKeyIDs[i] = v iKeyIDs[i] = v
} }
rows, err := s.db.QueryContext(ctx, query, iKeyIDs...) err := sqlutil.RunLimitedVariablesQuery(
ctx, bulkSelectServerKeysSQL, s.db, iKeyIDs, sqlutil.SQLite3MaxVariables,
func(rows *sql.Rows) error {
for rows.Next() {
var serverName string
var keyID string
var key string
var validUntilTS int64
var expiredTS int64
if err := rows.Scan(&serverName, &keyID, &validUntilTS, &expiredTS, &key); err != nil {
return fmt.Errorf("bulkSelectServerKeys: %v", err)
}
r := gomatrixserverlib.PublicKeyLookupRequest{
ServerName: gomatrixserverlib.ServerName(serverName),
KeyID: gomatrixserverlib.KeyID(keyID),
}
vk := gomatrixserverlib.VerifyKey{}
err := vk.Key.Decode(key)
if err != nil {
return fmt.Errorf("bulkSelectServerKeys: %v", err)
}
results[r] = gomatrixserverlib.PublicKeyLookupResult{
VerifyKey: vk,
ValidUntilTS: gomatrixserverlib.Timestamp(validUntilTS),
ExpiredTS: gomatrixserverlib.Timestamp(expiredTS),
}
}
return nil
},
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectServerKeys: rows.close() failed")
results := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{}
for rows.Next() {
var serverName string
var keyID string
var key string
var validUntilTS int64
var expiredTS int64
if err = rows.Scan(&serverName, &keyID, &validUntilTS, &expiredTS, &key); err != nil {
return nil, err
}
r := gomatrixserverlib.PublicKeyLookupRequest{
ServerName: gomatrixserverlib.ServerName(serverName),
KeyID: gomatrixserverlib.KeyID(keyID),
}
vk := gomatrixserverlib.VerifyKey{}
err = vk.Key.Decode(key)
if err != nil {
return nil, err
}
results[r] = gomatrixserverlib.PublicKeyLookupResult{
VerifyKey: vk,
ValidUntilTS: gomatrixserverlib.Timestamp(validUntilTS),
ExpiredTS: gomatrixserverlib.Timestamp(expiredTS),
}
}
return results, nil return results, nil
} }

View file

@ -149,6 +149,12 @@ func (s *OutputRoomEventConsumer) onNewRoomEvent(
} }
} }
if msg.RewritesState {
if err = s.db.PurgeRoom(ctx, ev.RoomID()); err != nil {
return fmt.Errorf("s.db.PurgeRoom: %w", err)
}
}
pduPos, err := s.db.WriteEvent( pduPos, err := s.db.WriteEvent(
ctx, ctx,
&ev, &ev,

View file

@ -43,6 +43,9 @@ type Database interface {
// Returns an error if there was a problem inserting this event. // Returns an error if there was a problem inserting this event.
WriteEvent(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent, addStateEvents []gomatrixserverlib.HeaderedEvent, WriteEvent(ctx context.Context, ev *gomatrixserverlib.HeaderedEvent, addStateEvents []gomatrixserverlib.HeaderedEvent,
addStateEventIDs []string, removeStateEventIDs []string, transactionID *api.TransactionID, excludeFromSync bool) (types.StreamPosition, error) addStateEventIDs []string, removeStateEventIDs []string, transactionID *api.TransactionID, excludeFromSync bool) (types.StreamPosition, error)
// PurgeRoom completely purges room state from the sync API. This is done when
// receiving an output event that completely resets the state.
PurgeRoom(ctx context.Context, roomID string) error
// GetStateEvent returns the Matrix state event of a given type for a given room with a given state key // GetStateEvent returns the Matrix state event of a given type for a given room with a given state key
// If no event could be found, returns nil // If no event could be found, returns nil
// If there was an issue during the retrieval, returns an error // If there was an issue during the retrieval, returns an error

View file

@ -19,6 +19,7 @@ import (
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
) )
@ -46,10 +47,14 @@ const selectBackwardExtremitiesForRoomSQL = "" +
const deleteBackwardExtremitySQL = "" + const deleteBackwardExtremitySQL = "" +
"DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2" "DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2"
const deleteBackwardExtremitiesForRoomSQL = "" +
"DELETE FROM syncapi_backward_extremities WHERE room_id = $1"
type backwardExtremitiesStatements struct { type backwardExtremitiesStatements struct {
insertBackwardExtremityStmt *sql.Stmt insertBackwardExtremityStmt *sql.Stmt
selectBackwardExtremitiesForRoomStmt *sql.Stmt selectBackwardExtremitiesForRoomStmt *sql.Stmt
deleteBackwardExtremityStmt *sql.Stmt deleteBackwardExtremityStmt *sql.Stmt
deleteBackwardExtremitiesForRoomStmt *sql.Stmt
} }
func NewPostgresBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) { func NewPostgresBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) {
@ -67,6 +72,9 @@ func NewPostgresBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremiti
if s.deleteBackwardExtremityStmt, err = db.Prepare(deleteBackwardExtremitySQL); err != nil { if s.deleteBackwardExtremityStmt, err = db.Prepare(deleteBackwardExtremitySQL); err != nil {
return nil, err return nil, err
} }
if s.deleteBackwardExtremitiesForRoomStmt, err = db.Prepare(deleteBackwardExtremitiesForRoomSQL); err != nil {
return nil, err
}
return s, nil return s, nil
} }
@ -105,3 +113,10 @@ func (s *backwardExtremitiesStatements) DeleteBackwardExtremity(
_, err = txn.Stmt(s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID) _, err = txn.Stmt(s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID)
return return
} }
func (s *backwardExtremitiesStatements) DeleteBackwardExtremitiesForRoom(
ctx context.Context, txn *sql.Tx, roomID string,
) (err error) {
_, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremitiesForRoomStmt).ExecContext(ctx, roomID)
return err
}

View file

@ -69,6 +69,9 @@ const upsertRoomStateSQL = "" +
const deleteRoomStateByEventIDSQL = "" + const deleteRoomStateByEventIDSQL = "" +
"DELETE FROM syncapi_current_room_state WHERE event_id = $1" "DELETE FROM syncapi_current_room_state WHERE event_id = $1"
const DeleteRoomStateForRoomSQL = "" +
"DELETE FROM syncapi_current_room_state WHERE event_id = $1"
const selectRoomIDsWithMembershipSQL = "" + const selectRoomIDsWithMembershipSQL = "" +
"SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2" "SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2"
@ -98,6 +101,7 @@ const selectEventsWithEventIDsSQL = "" +
type currentRoomStateStatements struct { type currentRoomStateStatements struct {
upsertRoomStateStmt *sql.Stmt upsertRoomStateStmt *sql.Stmt
deleteRoomStateByEventIDStmt *sql.Stmt deleteRoomStateByEventIDStmt *sql.Stmt
DeleteRoomStateForRoomStmt *sql.Stmt
selectRoomIDsWithMembershipStmt *sql.Stmt selectRoomIDsWithMembershipStmt *sql.Stmt
selectCurrentStateStmt *sql.Stmt selectCurrentStateStmt *sql.Stmt
selectJoinedUsersStmt *sql.Stmt selectJoinedUsersStmt *sql.Stmt
@ -117,6 +121,9 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro
if s.deleteRoomStateByEventIDStmt, err = db.Prepare(deleteRoomStateByEventIDSQL); err != nil { if s.deleteRoomStateByEventIDStmt, err = db.Prepare(deleteRoomStateByEventIDSQL); err != nil {
return nil, err return nil, err
} }
if s.DeleteRoomStateForRoomStmt, err = db.Prepare(DeleteRoomStateForRoomSQL); err != nil {
return nil, err
}
if s.selectRoomIDsWithMembershipStmt, err = db.Prepare(selectRoomIDsWithMembershipSQL); err != nil { if s.selectRoomIDsWithMembershipStmt, err = db.Prepare(selectRoomIDsWithMembershipSQL); err != nil {
return nil, err return nil, err
} }
@ -214,6 +221,14 @@ func (s *currentRoomStateStatements) DeleteRoomStateByEventID(
return err return err
} }
func (s *currentRoomStateStatements) DeleteRoomStateForRoom(
ctx context.Context, txn *sql.Tx, roomID string,
) error {
stmt := sqlutil.TxStmt(txn, s.DeleteRoomStateForRoomStmt)
_, err := stmt.ExecContext(ctx, roomID)
return err
}
func (s *currentRoomStateStatements) UpsertRoomState( func (s *currentRoomStateStatements) UpsertRoomState(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
event gomatrixserverlib.HeaderedEvent, membership *string, addedAt types.StreamPosition, event gomatrixserverlib.HeaderedEvent, membership *string, addedAt types.StreamPosition,

View file

@ -115,6 +115,9 @@ const selectStateInRangeSQL = "" +
" ORDER BY id ASC" + " ORDER BY id ASC" +
" LIMIT $8" " LIMIT $8"
const deleteEventsForRoomSQL = "" +
"DELETE FROM syncapi_output_room_events WHERE room_id = $1"
type outputRoomEventsStatements struct { type outputRoomEventsStatements struct {
insertEventStmt *sql.Stmt insertEventStmt *sql.Stmt
selectEventsStmt *sql.Stmt selectEventsStmt *sql.Stmt
@ -124,6 +127,7 @@ type outputRoomEventsStatements struct {
selectEarlyEventsStmt *sql.Stmt selectEarlyEventsStmt *sql.Stmt
selectStateInRangeStmt *sql.Stmt selectStateInRangeStmt *sql.Stmt
updateEventJSONStmt *sql.Stmt updateEventJSONStmt *sql.Stmt
deleteEventsForRoomStmt *sql.Stmt
} }
func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) { func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) {
@ -156,6 +160,9 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) {
if s.updateEventJSONStmt, err = db.Prepare(updateEventJSONSQL); err != nil { if s.updateEventJSONStmt, err = db.Prepare(updateEventJSONSQL); err != nil {
return nil, err return nil, err
} }
if s.deleteEventsForRoomStmt, err = db.Prepare(deleteEventsForRoomSQL); err != nil {
return nil, err
}
return s, nil return s, nil
} }
@ -395,6 +402,13 @@ func (s *outputRoomEventsStatements) SelectEvents(
return rowsToStreamEvents(rows) return rowsToStreamEvents(rows)
} }
func (s *outputRoomEventsStatements) DeleteEventsForRoom(
ctx context.Context, txn *sql.Tx, roomID string,
) (err error) {
_, err = sqlutil.TxStmt(txn, s.deleteEventsForRoomStmt).ExecContext(ctx, roomID)
return err
}
func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) { func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) {
var result []types.StreamEvent var result []types.StreamEvent
for rows.Next() { for rows.Next() {

View file

@ -19,6 +19,7 @@ import (
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -71,12 +72,16 @@ const selectMaxPositionInTopologySQL = "" +
"SELECT MAX(topological_position) FROM syncapi_output_room_events_topology WHERE room_id=$1" + "SELECT MAX(topological_position) FROM syncapi_output_room_events_topology WHERE room_id=$1" +
") ORDER BY stream_position DESC LIMIT 1" ") ORDER BY stream_position DESC LIMIT 1"
const deleteTopologyForRoomSQL = "" +
"DELETE FROM syncapi_output_room_events_topology WHERE room_id = $1"
type outputRoomEventsTopologyStatements struct { type outputRoomEventsTopologyStatements struct {
insertEventInTopologyStmt *sql.Stmt insertEventInTopologyStmt *sql.Stmt
selectEventIDsInRangeASCStmt *sql.Stmt selectEventIDsInRangeASCStmt *sql.Stmt
selectEventIDsInRangeDESCStmt *sql.Stmt selectEventIDsInRangeDESCStmt *sql.Stmt
selectPositionInTopologyStmt *sql.Stmt selectPositionInTopologyStmt *sql.Stmt
selectMaxPositionInTopologyStmt *sql.Stmt selectMaxPositionInTopologyStmt *sql.Stmt
deleteTopologyForRoomStmt *sql.Stmt
} }
func NewPostgresTopologyTable(db *sql.DB) (tables.Topology, error) { func NewPostgresTopologyTable(db *sql.DB) (tables.Topology, error) {
@ -100,6 +105,9 @@ func NewPostgresTopologyTable(db *sql.DB) (tables.Topology, error) {
if s.selectMaxPositionInTopologyStmt, err = db.Prepare(selectMaxPositionInTopologySQL); err != nil { if s.selectMaxPositionInTopologyStmt, err = db.Prepare(selectMaxPositionInTopologySQL); err != nil {
return nil, err return nil, err
} }
if s.deleteTopologyForRoomStmt, err = db.Prepare(deleteTopologyForRoomSQL); err != nil {
return nil, err
}
return s, nil return s, nil
} }
@ -167,3 +175,10 @@ func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology(
err = s.selectMaxPositionInTopologyStmt.QueryRowContext(ctx, roomID).Scan(&pos, &spos) err = s.selectMaxPositionInTopologyStmt.QueryRowContext(ctx, roomID).Scan(&pos, &spos)
return return
} }
func (s *outputRoomEventsTopologyStatements) DeleteTopologyForRoom(
ctx context.Context, txn *sql.Tx, roomID string,
) (err error) {
_, err = sqlutil.TxStmt(txn, s.deleteTopologyForRoomStmt).ExecContext(ctx, roomID)
return err
}

View file

@ -276,6 +276,29 @@ func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, e
return nil return nil
} }
func (d *Database) PurgeRoom(
ctx context.Context, roomID string,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
// If the event is a create event then we'll delete all of the existing
// data for the room. The only reason that a create event would be replayed
// to us in this way is if we're about to receive the entire room state.
if err := d.CurrentRoomState.DeleteRoomStateForRoom(ctx, txn, roomID); err != nil {
return fmt.Errorf("d.CurrentRoomState.DeleteRoomStateForRoom: %w", err)
}
if err := d.OutputEvents.DeleteEventsForRoom(ctx, txn, roomID); err != nil {
return fmt.Errorf("d.Events.DeleteEventsForRoom: %w", err)
}
if err := d.Topology.DeleteTopologyForRoom(ctx, txn, roomID); err != nil {
return fmt.Errorf("d.Topology.DeleteTopologyForRoom: %w", err)
}
if err := d.BackwardExtremities.DeleteBackwardExtremitiesForRoom(ctx, txn, roomID); err != nil {
return fmt.Errorf("d.BackwardExtremities.DeleteBackwardExtremitiesForRoom: %w", err)
}
return nil
})
}
func (d *Database) WriteEvent( func (d *Database) WriteEvent(
ctx context.Context, ctx context.Context,
ev *gomatrixserverlib.HeaderedEvent, ev *gomatrixserverlib.HeaderedEvent,

View file

@ -19,6 +19,7 @@ import (
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
) )
@ -46,11 +47,15 @@ const selectBackwardExtremitiesForRoomSQL = "" +
const deleteBackwardExtremitySQL = "" + const deleteBackwardExtremitySQL = "" +
"DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2" "DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2"
const deleteBackwardExtremitiesForRoomSQL = "" +
"DELETE FROM syncapi_backward_extremities WHERE room_id = $1"
type backwardExtremitiesStatements struct { type backwardExtremitiesStatements struct {
db *sql.DB db *sql.DB
insertBackwardExtremityStmt *sql.Stmt insertBackwardExtremityStmt *sql.Stmt
selectBackwardExtremitiesForRoomStmt *sql.Stmt selectBackwardExtremitiesForRoomStmt *sql.Stmt
deleteBackwardExtremityStmt *sql.Stmt deleteBackwardExtremityStmt *sql.Stmt
deleteBackwardExtremitiesForRoomStmt *sql.Stmt
} }
func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) { func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) {
@ -70,6 +75,9 @@ func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities
if s.deleteBackwardExtremityStmt, err = db.Prepare(deleteBackwardExtremitySQL); err != nil { if s.deleteBackwardExtremityStmt, err = db.Prepare(deleteBackwardExtremitySQL); err != nil {
return nil, err return nil, err
} }
if s.deleteBackwardExtremitiesForRoomStmt, err = db.Prepare(deleteBackwardExtremitiesForRoomSQL); err != nil {
return nil, err
}
return s, nil return s, nil
} }
@ -108,3 +116,10 @@ func (s *backwardExtremitiesStatements) DeleteBackwardExtremity(
_, err = txn.Stmt(s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID) _, err = txn.Stmt(s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID)
return err return err
} }
func (s *backwardExtremitiesStatements) DeleteBackwardExtremitiesForRoom(
ctx context.Context, txn *sql.Tx, roomID string,
) (err error) {
_, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremitiesForRoomStmt).ExecContext(ctx, roomID)
return err
}

View file

@ -51,12 +51,15 @@ CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_id_idx ON syncapi_current_room_s
const upsertRoomStateSQL = "" + const upsertRoomStateSQL = "" +
"INSERT INTO syncapi_current_room_state (room_id, event_id, type, sender, contains_url, state_key, headered_event_json, membership, added_at)" + "INSERT INTO syncapi_current_room_state (room_id, event_id, type, sender, contains_url, state_key, headered_event_json, membership, added_at)" +
" VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" + " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" +
" ON CONFLICT (event_id, room_id, type, sender, contains_url)" + " ON CONFLICT (room_id, type, state_key)" +
" DO UPDATE SET event_id = $2, sender=$4, contains_url=$5, headered_event_json = $7, membership = $8, added_at = $9" " DO UPDATE SET event_id = $2, sender=$4, contains_url=$5, headered_event_json = $7, membership = $8, added_at = $9"
const deleteRoomStateByEventIDSQL = "" + const deleteRoomStateByEventIDSQL = "" +
"DELETE FROM syncapi_current_room_state WHERE event_id = $1" "DELETE FROM syncapi_current_room_state WHERE event_id = $1"
const DeleteRoomStateForRoomSQL = "" +
"DELETE FROM syncapi_current_room_state WHERE event_id = $1"
const selectRoomIDsWithMembershipSQL = "" + const selectRoomIDsWithMembershipSQL = "" +
"SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2" "SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2"
@ -88,6 +91,7 @@ type currentRoomStateStatements struct {
streamIDStatements *streamIDStatements streamIDStatements *streamIDStatements
upsertRoomStateStmt *sql.Stmt upsertRoomStateStmt *sql.Stmt
deleteRoomStateByEventIDStmt *sql.Stmt deleteRoomStateByEventIDStmt *sql.Stmt
DeleteRoomStateForRoomStmt *sql.Stmt
selectRoomIDsWithMembershipStmt *sql.Stmt selectRoomIDsWithMembershipStmt *sql.Stmt
selectCurrentStateStmt *sql.Stmt selectCurrentStateStmt *sql.Stmt
selectJoinedUsersStmt *sql.Stmt selectJoinedUsersStmt *sql.Stmt
@ -109,6 +113,9 @@ func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (t
if s.deleteRoomStateByEventIDStmt, err = db.Prepare(deleteRoomStateByEventIDSQL); err != nil { if s.deleteRoomStateByEventIDStmt, err = db.Prepare(deleteRoomStateByEventIDSQL); err != nil {
return nil, err return nil, err
} }
if s.DeleteRoomStateForRoomStmt, err = db.Prepare(DeleteRoomStateForRoomSQL); err != nil {
return nil, err
}
if s.selectRoomIDsWithMembershipStmt, err = db.Prepare(selectRoomIDsWithMembershipSQL); err != nil { if s.selectRoomIDsWithMembershipStmt, err = db.Prepare(selectRoomIDsWithMembershipSQL); err != nil {
return nil, err return nil, err
} }
@ -203,6 +210,14 @@ func (s *currentRoomStateStatements) DeleteRoomStateByEventID(
return err return err
} }
func (s *currentRoomStateStatements) DeleteRoomStateForRoom(
ctx context.Context, txn *sql.Tx, roomID string,
) error {
stmt := sqlutil.TxStmt(txn, s.DeleteRoomStateForRoomStmt)
_, err := stmt.ExecContext(ctx, roomID)
return err
}
func (s *currentRoomStateStatements) UpsertRoomState( func (s *currentRoomStateStatements) UpsertRoomState(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
event gomatrixserverlib.HeaderedEvent, membership *string, addedAt types.StreamPosition, event gomatrixserverlib.HeaderedEvent, membership *string, addedAt types.StreamPosition,

View file

@ -103,6 +103,9 @@ const selectStateInRangeSQL = "" +
" ORDER BY id ASC" + " ORDER BY id ASC" +
" LIMIT $8" // limit " LIMIT $8" // limit
const deleteEventsForRoomSQL = "" +
"DELETE FROM syncapi_output_room_events WHERE room_id = $1"
type outputRoomEventsStatements struct { type outputRoomEventsStatements struct {
db *sql.DB db *sql.DB
streamIDStatements *streamIDStatements streamIDStatements *streamIDStatements
@ -114,6 +117,7 @@ type outputRoomEventsStatements struct {
selectEarlyEventsStmt *sql.Stmt selectEarlyEventsStmt *sql.Stmt
selectStateInRangeStmt *sql.Stmt selectStateInRangeStmt *sql.Stmt
updateEventJSONStmt *sql.Stmt updateEventJSONStmt *sql.Stmt
deleteEventsForRoomStmt *sql.Stmt
} }
func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Events, error) { func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Events, error) {
@ -149,6 +153,9 @@ func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Even
if s.updateEventJSONStmt, err = db.Prepare(updateEventJSONSQL); err != nil { if s.updateEventJSONStmt, err = db.Prepare(updateEventJSONSQL); err != nil {
return nil, err return nil, err
} }
if s.deleteEventsForRoomStmt, err = db.Prepare(deleteEventsForRoomSQL); err != nil {
return nil, err
}
return s, nil return s, nil
} }
@ -410,6 +417,13 @@ func (s *outputRoomEventsStatements) SelectEvents(
return returnEvents, nil return returnEvents, nil
} }
func (s *outputRoomEventsStatements) DeleteEventsForRoom(
ctx context.Context, txn *sql.Tx, roomID string,
) (err error) {
_, err = sqlutil.TxStmt(txn, s.deleteEventsForRoomStmt).ExecContext(ctx, roomID)
return err
}
func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) { func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) {
var result []types.StreamEvent var result []types.StreamEvent
for rows.Next() { for rows.Next() {

View file

@ -65,6 +65,9 @@ const selectMaxPositionInTopologySQL = "" +
"SELECT MAX(topological_position), stream_position FROM syncapi_output_room_events_topology" + "SELECT MAX(topological_position), stream_position FROM syncapi_output_room_events_topology" +
" WHERE room_id = $1 ORDER BY stream_position DESC" " WHERE room_id = $1 ORDER BY stream_position DESC"
const deleteTopologyForRoomSQL = "" +
"DELETE FROM syncapi_output_room_events_topology WHERE room_id = $1"
type outputRoomEventsTopologyStatements struct { type outputRoomEventsTopologyStatements struct {
db *sql.DB db *sql.DB
insertEventInTopologyStmt *sql.Stmt insertEventInTopologyStmt *sql.Stmt
@ -72,6 +75,7 @@ type outputRoomEventsTopologyStatements struct {
selectEventIDsInRangeDESCStmt *sql.Stmt selectEventIDsInRangeDESCStmt *sql.Stmt
selectPositionInTopologyStmt *sql.Stmt selectPositionInTopologyStmt *sql.Stmt
selectMaxPositionInTopologyStmt *sql.Stmt selectMaxPositionInTopologyStmt *sql.Stmt
deleteTopologyForRoomStmt *sql.Stmt
} }
func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) { func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) {
@ -97,6 +101,9 @@ func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) {
if s.selectMaxPositionInTopologyStmt, err = db.Prepare(selectMaxPositionInTopologySQL); err != nil { if s.selectMaxPositionInTopologyStmt, err = db.Prepare(selectMaxPositionInTopologySQL); err != nil {
return nil, err return nil, err
} }
if s.deleteTopologyForRoomStmt, err = db.Prepare(deleteTopologyForRoomSQL); err != nil {
return nil, err
}
return s, nil return s, nil
} }
@ -164,3 +171,10 @@ func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology(
err = stmt.QueryRowContext(ctx, roomID).Scan(&pos, &spos) err = stmt.QueryRowContext(ctx, roomID).Scan(&pos, &spos)
return return
} }
func (s *outputRoomEventsTopologyStatements) DeleteTopologyForRoom(
ctx context.Context, txn *sql.Tx, roomID string,
) (err error) {
_, err = sqlutil.TxStmt(txn, s.deleteTopologyForRoomStmt).ExecContext(ctx, roomID)
return err
}

View file

@ -60,6 +60,8 @@ type Events interface {
SelectEarlyEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, limit int) ([]types.StreamEvent, error) SelectEarlyEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, limit int) ([]types.StreamEvent, error)
SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StreamEvent, error) SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StreamEvent, error)
UpdateEventJSON(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error UpdateEventJSON(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error
// DeleteEventsForRoom removes all event information for a room. This should only be done when removing the room entirely.
DeleteEventsForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error)
} }
// Topology keeps track of the depths and stream positions for all events. // Topology keeps track of the depths and stream positions for all events.
@ -77,6 +79,8 @@ type Topology interface {
SelectPositionInTopology(ctx context.Context, txn *sql.Tx, eventID string) (depth, spos types.StreamPosition, err error) SelectPositionInTopology(ctx context.Context, txn *sql.Tx, eventID string) (depth, spos types.StreamPosition, err error)
// SelectMaxPositionInTopology returns the event which has the highest depth, and if there are multiple, the event with the highest stream position. // SelectMaxPositionInTopology returns the event which has the highest depth, and if there are multiple, the event with the highest stream position.
SelectMaxPositionInTopology(ctx context.Context, txn *sql.Tx, roomID string) (depth types.StreamPosition, spos types.StreamPosition, err error) SelectMaxPositionInTopology(ctx context.Context, txn *sql.Tx, roomID string) (depth types.StreamPosition, spos types.StreamPosition, err error)
// DeleteTopologyForRoom removes all topological information for a room. This should only be done when removing the room entirely.
DeleteTopologyForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error)
} }
type CurrentRoomState interface { type CurrentRoomState interface {
@ -84,6 +88,7 @@ type CurrentRoomState interface {
SelectEventsWithEventIDs(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StreamEvent, error) SelectEventsWithEventIDs(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StreamEvent, error)
UpsertRoomState(ctx context.Context, txn *sql.Tx, event gomatrixserverlib.HeaderedEvent, membership *string, addedAt types.StreamPosition) error UpsertRoomState(ctx context.Context, txn *sql.Tx, event gomatrixserverlib.HeaderedEvent, membership *string, addedAt types.StreamPosition) error
DeleteRoomStateByEventID(ctx context.Context, txn *sql.Tx, eventID string) error DeleteRoomStateByEventID(ctx context.Context, txn *sql.Tx, eventID string) error
DeleteRoomStateForRoom(ctx context.Context, txn *sql.Tx, roomID string) error
// SelectCurrentState returns all the current state events for the given room. // SelectCurrentState returns all the current state events for the given room.
SelectCurrentState(ctx context.Context, txn *sql.Tx, roomID string, stateFilter *gomatrixserverlib.StateFilter) ([]gomatrixserverlib.HeaderedEvent, error) SelectCurrentState(ctx context.Context, txn *sql.Tx, roomID string, stateFilter *gomatrixserverlib.StateFilter) ([]gomatrixserverlib.HeaderedEvent, error)
// SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state. // SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state.
@ -118,6 +123,8 @@ type BackwardsExtremities interface {
SelectBackwardExtremitiesForRoom(ctx context.Context, roomID string) (bwExtrems map[string][]string, err error) SelectBackwardExtremitiesForRoom(ctx context.Context, roomID string) (bwExtrems map[string][]string, err error)
// DeleteBackwardExtremity removes a backwards extremity for a room, if one existed. // DeleteBackwardExtremity removes a backwards extremity for a room, if one existed.
DeleteBackwardExtremity(ctx context.Context, txn *sql.Tx, roomID, knownEventID string) (err error) DeleteBackwardExtremity(ctx context.Context, txn *sql.Tx, roomID, knownEventID string) (err error)
// DeleteBackwardExtremitiesFoorRoomID removes all backward extremities for a room. This should only be done when removing the room entirely.
DeleteBackwardExtremitiesForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error)
} }
// SendToDevice tracks send-to-device messages which are sent to individual // SendToDevice tracks send-to-device messages which are sent to individual

View file

@ -40,11 +40,6 @@ Ignore invite in incremental sync
New room members see their own join event New room members see their own join event
Existing members see new members' join events Existing members see new members' join events
# Blacklisted because the federation work for these hasn't been finished yet.
Can recv device messages over federation
Device messages over federation wake up /sync
Wildcard device messages over federation wake up /sync
# See https://github.com/matrix-org/sytest/pull/901 # See https://github.com/matrix-org/sytest/pull/901
Remote invited user can see room metadata Remote invited user can see room metadata
@ -56,8 +51,5 @@ Inbound federation accepts a second soft-failed event
# Caused by https://github.com/matrix-org/sytest/pull/911 # Caused by https://github.com/matrix-org/sytest/pull/911
Outbound federation requests missing prev_events and then asks for /state_ids and resolves the state Outbound federation requests missing prev_events and then asks for /state_ids and resolves the state
# We don't implement device lists yet
Device list doesn't change if remote server is down
# We don't implement lazy membership loading yet. # We don't implement lazy membership loading yet.
The only membership state included in a gapped incremental sync is for senders in the timeline The only membership state included in a gapped incremental sync is for senders in the timeline

View file

@ -470,4 +470,10 @@ We can't peek into rooms with shared history_visibility
We can't peek into rooms with invited history_visibility We can't peek into rooms with invited history_visibility
We can't peek into rooms with joined history_visibility We can't peek into rooms with joined history_visibility
Local users can peek by room alias Local users can peek by room alias
Peeked rooms only turn up in the sync for the device who peeked them Peeked rooms only turn up in the sync for the device who peeked them
Room state at a rejected message event is the same as its predecessor
Room state at a rejected state event is the same as its predecessor
Inbound federation correctly soft fails events
Inbound federation accepts a second soft-failed event
Federation key API can act as a notary server via a POST request
Federation key API can act as a notary server via a GET request