Better mapping of stream positions to topological positions in /messages (#2263)

* Convert stream positions into topological positions for both `from` and `to` in `/messages`

* Hopefully it works now

* Remove unnecessary logging

* Return sane values if `StreamToTopologicalPosition` can't work out the right thing to do

* Revert logging change

* tweaks

* Fix `selectEventIDsInRangeASCSQL`

* Test `Getting messages going forward is limited for a departed room (SPEC-216)` was passing incorrectly so un-whitelist it
This commit is contained in:
Neil Alexander 2022-03-18 10:40:01 +00:00 committed by GitHub
parent 191486438c
commit 475d3c1af9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 128 additions and 84 deletions

View file

@ -41,7 +41,6 @@ type messagesReq struct {
roomID string
from *types.TopologyToken
to *types.TopologyToken
fromStream *types.StreamingToken
device *userapi.Device
wasToProvided bool
backwardOrdering bool
@ -50,7 +49,7 @@ type messagesReq struct {
type messagesResp struct {
Start string `json:"start"`
StartStream string `json:"start_stream,omitempty"` // NOTSPEC: so clients can hit /messages then immediately /sync with a latest sync token
StartStream string `json:"start_stream,omitempty"` // NOTSPEC: used by Cerulean, so clients can hit /messages then immediately /sync with a latest sync token
End string `json:"end"`
Chunk []gomatrixserverlib.ClientEvent `json:"chunk"`
State []gomatrixserverlib.ClientEvent `json:"state"`
@ -93,6 +92,7 @@ func OnIncomingMessagesRequest(
// Pagination tokens.
var fromStream *types.StreamingToken
fromQuery := req.URL.Query().Get("from")
toQuery := req.URL.Query().Get("to")
emptyFromSupplied := fromQuery == ""
if emptyFromSupplied {
// NOTSPEC: We will pretend they used the latest sync token if no ?from= was provided.
@ -101,18 +101,6 @@ func OnIncomingMessagesRequest(
fromQuery = currPos.String()
}
from, err := types.NewTopologyTokenFromString(fromQuery)
if err != nil {
fs, err2 := types.NewStreamTokenFromString(fromQuery)
fromStream = &fs
if err2 != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidArgumentValue("Invalid from parameter: " + err2.Error()),
}
}
}
// Direction to return events from.
dir := req.URL.Query().Get("dir")
if dir != "b" && dir != "f" {
@ -125,16 +113,43 @@ func OnIncomingMessagesRequest(
// to have one of the two accepted values (so dir == "f" <=> !backwardOrdering).
backwardOrdering := (dir == "b")
from, err := types.NewTopologyTokenFromString(fromQuery)
if err != nil {
var streamToken types.StreamingToken
if streamToken, err = types.NewStreamTokenFromString(fromQuery); err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidArgumentValue("Invalid from parameter: " + err.Error()),
}
} else {
fromStream = &streamToken
from, err = db.StreamToTopologicalPosition(req.Context(), roomID, streamToken.PDUPosition, backwardOrdering)
if err != nil {
logrus.WithError(err).Errorf("Failed to get topological position for streaming token %v", streamToken)
return jsonerror.InternalServerError()
}
}
}
// Pagination tokens. To is optional, and its default value depends on the
// direction ("b" or "f").
var to types.TopologyToken
wasToProvided := true
if s := req.URL.Query().Get("to"); len(s) > 0 {
to, err = types.NewTopologyTokenFromString(s)
if len(toQuery) > 0 {
to, err = types.NewTopologyTokenFromString(toQuery)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidArgumentValue("Invalid to parameter: " + err.Error()),
var streamToken types.StreamingToken
if streamToken, err = types.NewStreamTokenFromString(toQuery); err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: jsonerror.InvalidArgumentValue("Invalid to parameter: " + err.Error()),
}
} else {
to, err = db.StreamToTopologicalPosition(req.Context(), roomID, streamToken.PDUPosition, !backwardOrdering)
if err != nil {
logrus.WithError(err).Errorf("Failed to get topological position for streaming token %v", streamToken)
return jsonerror.InternalServerError()
}
}
}
} else {
@ -168,7 +183,6 @@ func OnIncomingMessagesRequest(
roomID: roomID,
from: &from,
to: &to,
fromStream: fromStream,
wasToProvided: wasToProvided,
filter: filter,
backwardOrdering: backwardOrdering,
@ -215,7 +229,7 @@ func OnIncomingMessagesRequest(
End: end.String(),
State: state,
}
if emptyFromSupplied {
if fromStream != nil {
res.StartStream = fromStream.String()
}
@ -251,17 +265,9 @@ func (r *messagesReq) retrieveEvents() (
eventFilter := r.filter
// Retrieve the events from the local database.
var streamEvents []types.StreamEvent
if r.fromStream != nil {
toStream := r.to.StreamToken()
streamEvents, err = r.db.GetEventsInStreamingRange(
r.ctx, r.fromStream, &toStream, r.roomID, eventFilter, r.backwardOrdering,
)
} else {
streamEvents, err = r.db.GetEventsInTopologicalRange(
r.ctx, r.from, r.to, r.roomID, eventFilter.Limit, r.backwardOrdering,
)
}
streamEvents, err := r.db.GetEventsInTopologicalRange(
r.ctx, r.from, r.to, r.roomID, eventFilter.Limit, r.backwardOrdering,
)
if err != nil {
err = fmt.Errorf("GetEventsInRange: %w", err)
return