Check event origin after transaction origin if possible

This commit is contained in:
Neil Alexander 2021-06-22 10:32:20 +01:00
parent 5357df36c9
commit fee5074f15
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944

View file

@ -236,9 +236,6 @@ func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.Res
} }
if err = gomatrixserverlib.VerifyAllEventSignatures(ctx, []*gomatrixserverlib.Event{event}, t.keys); err != nil { if err = gomatrixserverlib.VerifyAllEventSignatures(ctx, []*gomatrixserverlib.Event{event}, t.keys); err != nil {
util.GetLogger(ctx).WithError(err).Warnf("Transaction: Couldn't validate signature of event %q", event.EventID()) util.GetLogger(ctx).WithError(err).Warnf("Transaction: Couldn't validate signature of event %q", event.EventID())
results[event.EventID()] = gomatrixserverlib.PDUResult{
Error: err.Error(),
}
continue continue
} }
pdus = append(pdus, event.Headered(verRes.RoomVersion)) pdus = append(pdus, event.Headered(verRes.RoomVersion))
@ -479,13 +476,16 @@ func (t *txnReq) processDeviceListUpdate(ctx context.Context, e gomatrixserverli
} }
} }
func (t *txnReq) getServers(ctx context.Context, roomID string) []gomatrixserverlib.ServerName { func (t *txnReq) getServers(ctx context.Context, roomID string, eventOrigin gomatrixserverlib.ServerName) []gomatrixserverlib.ServerName {
t.serversMutex.Lock() t.serversMutex.Lock()
defer t.serversMutex.Unlock() defer t.serversMutex.Unlock()
if t.servers != nil { if t.servers != nil {
return t.servers return t.servers
} }
t.servers = []gomatrixserverlib.ServerName{t.Origin} t.servers = []gomatrixserverlib.ServerName{t.Origin} // transaction origin
if eventOrigin != "" {
t.servers = append(t.servers, eventOrigin) // event origin, if specified
}
serverReq := &api.QueryServerJoinedToRoomRequest{ serverReq := &api.QueryServerJoinedToRoomRequest{
RoomID: roomID, RoomID: roomID,
} }
@ -570,7 +570,7 @@ func (t *txnReq) retrieveMissingAuthEvents(
withNextEvent: withNextEvent:
for missingAuthEventID := range missingAuthEvents { for missingAuthEventID := range missingAuthEvents {
withNextServer: withNextServer:
for _, server := range t.getServers(ctx, e.RoomID()) { for _, server := range t.getServers(ctx, e.RoomID(), e.Origin()) {
logger.Infof("Retrieving missing auth event %q from %q", missingAuthEventID, server) logger.Infof("Retrieving missing auth event %q from %q", missingAuthEventID, server)
tx, err := t.federation.GetEvent(ctx, server, missingAuthEventID) tx, err := t.federation.GetEvent(ctx, server, missingAuthEventID)
if err != nil { if err != nil {
@ -582,6 +582,20 @@ withNextEvent:
logger.WithError(err).Warnf("Failed to unmarshal auth event %q", missingAuthEventID) logger.WithError(err).Warnf("Failed to unmarshal auth event %q", missingAuthEventID)
continue withNextServer continue withNextServer
} }
if err = gomatrixserverlib.VerifyAllEventSignatures(ctx, []*gomatrixserverlib.Event{ev}, t.keys); err != nil {
logger.WithError(err).Warnf("Failed to verify signature of auth event %q, dropping...", missingAuthEventID)
// If the response came right from the origin of the event and the
// signature is still wrong then there's no reason to believe that any
// other server will have a better one, so don't bother doing anything
// else - just skip and move onto the next event
if e.Origin() == server {
delete(missingAuthEvents, missingAuthEventID)
continue withNextEvent
}
// Otherwise, it might be that the server we asked has interfered with
// the event so we can just try asking some other server instead
continue withNextServer
}
if err = api.SendInputRoomEvents( if err = api.SendInputRoomEvents(
context.Background(), context.Background(),
t.rsAPI, t.rsAPI,
@ -939,7 +953,7 @@ func (t *txnReq) getMissingEvents(ctx context.Context, e *gomatrixserverlib.Even
} }
var missingResp *gomatrixserverlib.RespMissingEvents var missingResp *gomatrixserverlib.RespMissingEvents
servers := t.getServers(ctx, e.RoomID()) servers := t.getServers(ctx, e.RoomID(), e.Origin())
for _, server := range servers { for _, server := range servers {
var m gomatrixserverlib.RespMissingEvents var m gomatrixserverlib.RespMissingEvents
if m, err = t.federation.LookupMissingEvents(ctx, server, e.RoomID(), gomatrixserverlib.MissingEvents{ if m, err = t.federation.LookupMissingEvents(ctx, server, e.RoomID(), gomatrixserverlib.MissingEvents{
@ -1193,7 +1207,7 @@ func (t *txnReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib.
} }
var event *gomatrixserverlib.Event var event *gomatrixserverlib.Event
found := false found := false
servers := t.getServers(ctx, roomID) servers := t.getServers(ctx, roomID, "")
for _, serverName := range servers { for _, serverName := range servers {
txn, err := t.federation.GetEvent(ctx, serverName, missingEventID) txn, err := t.federation.GetEvent(ctx, serverName, missingEventID)
if err != nil || len(txn.PDUs) == 0 { if err != nil || len(txn.PDUs) == 0 {