mirror of
https://github.com/hoernschen/dendrite.git
synced 2024-12-28 16:08:27 +00:00
Add QueryStateAndAuthChainIDs
This commit is contained in:
parent
1827dd7c09
commit
90dd5e6544
8 changed files with 175 additions and 0 deletions
|
@ -75,6 +75,12 @@ type RoomserverInternalAPI interface {
|
||||||
response *QueryLatestEventsAndStateResponse,
|
response *QueryLatestEventsAndStateResponse,
|
||||||
) error
|
) error
|
||||||
|
|
||||||
|
QueryStateAndAuthChainIDs(
|
||||||
|
ctx context.Context,
|
||||||
|
request *QueryStateAndAuthChainIDsRequest,
|
||||||
|
response *QueryStateAndAuthChainIDsResponse,
|
||||||
|
) error
|
||||||
|
|
||||||
// Query the state after a list of events in a room from the room server.
|
// Query the state after a list of events in a room from the room server.
|
||||||
QueryStateAfterEvents(
|
QueryStateAfterEvents(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
|
|
@ -250,6 +250,28 @@ type QueryStateAndAuthChainResponse struct {
|
||||||
AuthChainEvents []*gomatrixserverlib.HeaderedEvent `json:"auth_chain_events"`
|
AuthChainEvents []*gomatrixserverlib.HeaderedEvent `json:"auth_chain_events"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// QueryStateAndAuthChainIDsRequest is a request to QueryStateAndAuthChainIDs
|
||||||
|
type QueryStateAndAuthChainIDsRequest struct {
|
||||||
|
// The room ID to query the state in.
|
||||||
|
RoomID string `json:"room_id"`
|
||||||
|
// The list of prev events for the event. Used to calculate the state at
|
||||||
|
// the event.
|
||||||
|
PrevEventIDs []string `json:"prev_event_ids"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryStateAndAuthChainIDsResponse is a response to QueryStateAndAuthChainIDs
|
||||||
|
type QueryStateAndAuthChainIDsResponse struct {
|
||||||
|
// Does the room exist on this roomserver?
|
||||||
|
// If the room doesn't exist this will be false and StateEvents will be empty.
|
||||||
|
RoomExists bool `json:"room_exists"`
|
||||||
|
// The room version of the room.
|
||||||
|
RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"`
|
||||||
|
// The state and auth chain event IDs that were requested.
|
||||||
|
// The lists will be in an arbitrary order.
|
||||||
|
StateEvents []string `json:"state_event_ids"`
|
||||||
|
AuthChainEvents []string `json:"auth_chain_event_ids"`
|
||||||
|
}
|
||||||
|
|
||||||
// QueryRoomVersionCapabilitiesRequest asks for the default room version
|
// QueryRoomVersionCapabilitiesRequest asks for the default room version
|
||||||
type QueryRoomVersionCapabilitiesRequest struct{}
|
type QueryRoomVersionCapabilitiesRequest struct{}
|
||||||
|
|
||||||
|
|
|
@ -526,6 +526,70 @@ func (r *Queryer) QueryStateAndAuthChain(
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// QueryStateAndAuthChain implements api.RoomserverInternalAPI
|
||||||
|
func (r *Queryer) QueryStateAndAuthChainIDs(
|
||||||
|
ctx context.Context,
|
||||||
|
request *api.QueryStateAndAuthChainIDsRequest,
|
||||||
|
response *api.QueryStateAndAuthChainIDsResponse,
|
||||||
|
) error {
|
||||||
|
info, err := r.DB.RoomInfo(ctx, request.RoomID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if info == nil || info.IsStub {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
response.RoomExists = true
|
||||||
|
response.RoomVersion = info.RoomVersion
|
||||||
|
|
||||||
|
roomState := state.NewStateResolution(r.DB, *info)
|
||||||
|
prevStates, err := r.DB.StateAtEventIDs(ctx, request.PrevEventIDs)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("r.DB.StateAtEventIDs: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
eventNIDs := map[types.EventNID]struct{}{}
|
||||||
|
for _, prevState := range prevStates {
|
||||||
|
var entries []types.StateEntry
|
||||||
|
entries, err = roomState.LoadStateAtSnapshot(ctx, prevState.BeforeStateSnapshotNID)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, entry := range entries {
|
||||||
|
eventNIDs[entry.EventNID] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var eventNIDsArray types.EventNIDs
|
||||||
|
for nid := range eventNIDs {
|
||||||
|
eventNIDsArray = append(eventNIDsArray, nid)
|
||||||
|
}
|
||||||
|
|
||||||
|
authEventNIDsArray, err := r.DB.AuthEventNIDs(ctx, eventNIDsArray)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("r.DB.AuthEventNIDs: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
stateEventIDs, err := r.DB.EventIDs(ctx, eventNIDsArray)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("r.DB.EventIDs: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
authEventIDs, err := r.DB.EventIDs(ctx, authEventNIDsArray)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("r.DB.EventIDs: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, eventID := range stateEventIDs {
|
||||||
|
response.StateEvents = append(response.StateEvents, eventID)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, eventID := range authEventIDs {
|
||||||
|
response.AuthChainEvents = append(response.AuthChainEvents, eventID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo types.RoomInfo, eventIDs []string) ([]*gomatrixserverlib.Event, error) {
|
func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo types.RoomInfo, eventIDs []string) ([]*gomatrixserverlib.Event, error) {
|
||||||
roomState := state.NewStateResolution(r.DB, roomInfo)
|
roomState := state.NewStateResolution(r.DB, roomInfo)
|
||||||
prevStates, err := r.DB.StateAtEventIDs(ctx, eventIDs)
|
prevStates, err := r.DB.StateAtEventIDs(ctx, eventIDs)
|
||||||
|
|
|
@ -87,6 +87,8 @@ type Database interface {
|
||||||
// Lookup the event IDs for a batch of event numeric IDs.
|
// Lookup the event IDs for a batch of event numeric IDs.
|
||||||
// Returns an error if the retrieval went wrong.
|
// Returns an error if the retrieval went wrong.
|
||||||
EventIDs(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error)
|
EventIDs(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error)
|
||||||
|
// AuthEventNIDs returns the auth event NIDs for the given events.
|
||||||
|
AuthEventNIDs(ctx context.Context, events []types.EventNID) (types.EventNIDs, error)
|
||||||
// Look up the latest events in a room in preparation for an update.
|
// Look up the latest events in a room in preparation for an update.
|
||||||
// The RoomRecentEventsUpdater must have Commit or Rollback called on it if this doesn't return an error.
|
// The RoomRecentEventsUpdater must have Commit or Rollback called on it if this doesn't return an error.
|
||||||
// Returns the latest events in the room and the last eventID sent to the log along with an updater.
|
// Returns the latest events in the room and the last eventID sent to the log along with an updater.
|
||||||
|
|
|
@ -134,6 +134,9 @@ const selectMaxEventDepthSQL = "" +
|
||||||
const selectRoomNIDsForEventNIDsSQL = "" +
|
const selectRoomNIDsForEventNIDsSQL = "" +
|
||||||
"SELECT event_nid, room_nid FROM roomserver_events WHERE event_nid = ANY($1)"
|
"SELECT event_nid, room_nid FROM roomserver_events WHERE event_nid = ANY($1)"
|
||||||
|
|
||||||
|
const bulkSelectEventAuthEventNIDsSQL = "" +
|
||||||
|
"SELECT auth_event_nids FROM roomserver_events WHERE event_nid = ANY($1)"
|
||||||
|
|
||||||
type eventStatements struct {
|
type eventStatements struct {
|
||||||
insertEventStmt *sql.Stmt
|
insertEventStmt *sql.Stmt
|
||||||
selectEventStmt *sql.Stmt
|
selectEventStmt *sql.Stmt
|
||||||
|
@ -150,6 +153,7 @@ type eventStatements struct {
|
||||||
bulkSelectEventNIDStmt *sql.Stmt
|
bulkSelectEventNIDStmt *sql.Stmt
|
||||||
selectMaxEventDepthStmt *sql.Stmt
|
selectMaxEventDepthStmt *sql.Stmt
|
||||||
selectRoomNIDsForEventNIDsStmt *sql.Stmt
|
selectRoomNIDsForEventNIDsStmt *sql.Stmt
|
||||||
|
bulkSelectEventAuthEventNIDsStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func createEventsTable(db *sql.DB) error {
|
func createEventsTable(db *sql.DB) error {
|
||||||
|
@ -176,6 +180,7 @@ func prepareEventsTable(db *sql.DB) (tables.Events, error) {
|
||||||
{&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL},
|
{&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL},
|
||||||
{&s.selectMaxEventDepthStmt, selectMaxEventDepthSQL},
|
{&s.selectMaxEventDepthStmt, selectMaxEventDepthSQL},
|
||||||
{&s.selectRoomNIDsForEventNIDsStmt, selectRoomNIDsForEventNIDsSQL},
|
{&s.selectRoomNIDsForEventNIDsStmt, selectRoomNIDsForEventNIDsSQL},
|
||||||
|
{&s.bulkSelectEventAuthEventNIDsStmt, bulkSelectEventAuthEventNIDsSQL},
|
||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -502,6 +507,28 @@ func (s *eventStatements) SelectRoomNIDsForEventNIDs(
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *eventStatements) SelectEventAuthEventNIDs(
|
||||||
|
ctx context.Context, eventNIDs []types.EventNID,
|
||||||
|
) (map[types.EventNID][]types.EventNID, error) {
|
||||||
|
rows, err := s.bulkSelectEventAuthEventNIDsStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomNIDsForEventNIDsStmt: rows.close() failed")
|
||||||
|
result := make(map[types.EventNID][]types.EventNID)
|
||||||
|
for rows.Next() {
|
||||||
|
var eventNID types.EventNID
|
||||||
|
var authEventNIDs pq.Int64Array
|
||||||
|
if err = rows.Scan(&authEventNIDs); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for _, a := range authEventNIDs {
|
||||||
|
result[eventNID] = append(result[eventNID], types.EventNID(a))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
func eventNIDsAsArray(eventNIDs []types.EventNID) pq.Int64Array {
|
func eventNIDsAsArray(eventNIDs []types.EventNID) pq.Int64Array {
|
||||||
nids := make([]int64, len(eventNIDs))
|
nids := make([]int64, len(eventNIDs))
|
||||||
for i := range eventNIDs {
|
for i := range eventNIDs {
|
||||||
|
|
|
@ -292,6 +292,22 @@ func (d *Database) StateEntries(
|
||||||
return lists, nil
|
return lists, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *Database) AuthEventNIDs(
|
||||||
|
ctx context.Context, events []types.EventNID,
|
||||||
|
) (types.EventNIDs, error) {
|
||||||
|
entries, err := d.EventsTable.SelectEventAuthEventNIDs(
|
||||||
|
ctx, events,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("d.EventsTable.SelectEventAuthEventNIDs: %w", err)
|
||||||
|
}
|
||||||
|
var lists types.EventNIDs
|
||||||
|
for _, nids := range entries {
|
||||||
|
lists = append(lists, nids...)
|
||||||
|
}
|
||||||
|
return lists[:util.SortAndUnique(lists)], nil
|
||||||
|
}
|
||||||
|
|
||||||
func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error {
|
func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error {
|
||||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
return d.RoomAliasesTable.InsertRoomAlias(ctx, txn, alias, roomID, creatorUserID)
|
return d.RoomAliasesTable.InsertRoomAlias(ctx, txn, alias, roomID, creatorUserID)
|
||||||
|
|
|
@ -23,6 +23,7 @@ import (
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/lib/pq"
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||||
|
@ -104,6 +105,9 @@ const selectMaxEventDepthSQL = "" +
|
||||||
const selectRoomNIDsForEventNIDsSQL = "" +
|
const selectRoomNIDsForEventNIDsSQL = "" +
|
||||||
"SELECT event_nid, room_nid FROM roomserver_events WHERE event_nid IN ($1)"
|
"SELECT event_nid, room_nid FROM roomserver_events WHERE event_nid IN ($1)"
|
||||||
|
|
||||||
|
const bulkSelectEventAuthEventNIDsSQL = "" +
|
||||||
|
"SELECT auth_event_nids FROM roomserver_events WHERE event_nid IN ($1)"
|
||||||
|
|
||||||
type eventStatements struct {
|
type eventStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
insertEventStmt *sql.Stmt
|
insertEventStmt *sql.Stmt
|
||||||
|
@ -119,6 +123,7 @@ type eventStatements struct {
|
||||||
bulkSelectEventIDStmt *sql.Stmt
|
bulkSelectEventIDStmt *sql.Stmt
|
||||||
bulkSelectEventNIDStmt *sql.Stmt
|
bulkSelectEventNIDStmt *sql.Stmt
|
||||||
//selectRoomNIDsForEventNIDsStmt *sql.Stmt
|
//selectRoomNIDsForEventNIDsStmt *sql.Stmt
|
||||||
|
//bulkSelectEventAuthEventNIDsStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func createEventsTable(db *sql.DB) error {
|
func createEventsTable(db *sql.DB) error {
|
||||||
|
@ -145,6 +150,7 @@ func prepareEventsTable(db *sql.DB) (tables.Events, error) {
|
||||||
{&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL},
|
{&s.bulkSelectEventIDStmt, bulkSelectEventIDSQL},
|
||||||
{&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL},
|
{&s.bulkSelectEventNIDStmt, bulkSelectEventNIDSQL},
|
||||||
//{&s.selectRoomNIDForEventNIDStmt, selectRoomNIDForEventNIDSQL},
|
//{&s.selectRoomNIDForEventNIDStmt, selectRoomNIDForEventNIDSQL},
|
||||||
|
//{&s.bulkSelectEventAuthEventNIDsStmt, bulkSelectEventAuthEventNIDsSQL},
|
||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -571,6 +577,37 @@ func (s *eventStatements) SelectRoomNIDsForEventNIDs(
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *eventStatements) SelectEventAuthEventNIDs(
|
||||||
|
ctx context.Context, eventNIDs []types.EventNID,
|
||||||
|
) (map[types.EventNID][]types.EventNID, error) {
|
||||||
|
sqlStr := strings.Replace(bulkSelectEventAuthEventNIDsSQL, "($1)", sqlutil.QueryVariadic(len(eventNIDs)), 1)
|
||||||
|
sqlPrep, err := s.db.Prepare(sqlStr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
iEventNIDs := make([]interface{}, len(eventNIDs))
|
||||||
|
for i, v := range eventNIDs {
|
||||||
|
iEventNIDs[i] = v
|
||||||
|
}
|
||||||
|
rows, err := sqlPrep.QueryContext(ctx, iEventNIDs...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventAuthEventNIDsStmt: rows.close() failed")
|
||||||
|
result := make(map[types.EventNID][]types.EventNID)
|
||||||
|
for rows.Next() {
|
||||||
|
var eventNID types.EventNID
|
||||||
|
var authEventNIDs pq.Int64Array
|
||||||
|
if err = rows.Scan(&authEventNIDs); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for _, a := range authEventNIDs {
|
||||||
|
result[eventNID] = append(result[eventNID], types.EventNID(a))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
func eventNIDsAsArray(eventNIDs []types.EventNID) string {
|
func eventNIDsAsArray(eventNIDs []types.EventNID) string {
|
||||||
b, _ := json.Marshal(eventNIDs)
|
b, _ := json.Marshal(eventNIDs)
|
||||||
return string(b)
|
return string(b)
|
||||||
|
|
|
@ -61,6 +61,7 @@ type Events interface {
|
||||||
BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error)
|
BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error)
|
||||||
SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error)
|
SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error)
|
||||||
SelectRoomNIDsForEventNIDs(ctx context.Context, eventNIDs []types.EventNID) (roomNIDs map[types.EventNID]types.RoomNID, err error)
|
SelectRoomNIDsForEventNIDs(ctx context.Context, eventNIDs []types.EventNID) (roomNIDs map[types.EventNID]types.RoomNID, err error)
|
||||||
|
SelectEventAuthEventNIDs(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID][]types.EventNID, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Rooms interface {
|
type Rooms interface {
|
||||||
|
|
Loading…
Reference in a new issue