Various bug fixes and tweaks around invites and membership

This commit is contained in:
Neil Alexander 2022-03-17 17:05:21 +00:00
parent 0fb94fc781
commit 4e64c270db
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
10 changed files with 36 additions and 47 deletions

View file

@ -52,7 +52,7 @@ const insertInviteEventSQL = "" +
") VALUES ($1, $2, $3, $4, FALSE) RETURNING id"
const deleteInviteEventSQL = "" +
"UPDATE syncapi_invite_events SET deleted=TRUE, id=nextval('syncapi_stream_id') WHERE event_id = $1 RETURNING id"
"UPDATE syncapi_invite_events SET deleted=TRUE, id=nextval('syncapi_stream_id') WHERE event_id = $1 AND deleted=FALSE RETURNING id"
const selectInviteEventsInRangeSQL = "" +
"SELECT room_id, headered_event_json, deleted FROM syncapi_invite_events" +

View file

@ -47,7 +47,7 @@ const insertInviteEventSQL = "" +
" VALUES ($1, $2, $3, $4, $5, false)"
const deleteInviteEventSQL = "" +
"UPDATE syncapi_invite_events SET deleted=true, id=$1 WHERE event_id = $2"
"UPDATE syncapi_invite_events SET deleted=true, id=$1 WHERE event_id = $2 AND deleted=false"
const selectInviteEventsInRangeSQL = "" +
"SELECT room_id, headered_event_json, deleted FROM syncapi_invite_events" +

View file

@ -27,15 +27,12 @@ INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("invite", 0)
`
const increaseStreamIDStmt = "" +
"UPDATE syncapi_stream_id SET stream_id = stream_id + 1 WHERE stream_name = $1"
const selectStreamIDStmt = "" +
"SELECT stream_id FROM syncapi_stream_id WHERE stream_name = $1"
"UPDATE syncapi_stream_id SET stream_id = stream_id + 1 WHERE stream_name = $1" +
" RETURNING stream_id"
type streamIDStatements struct {
db *sql.DB
increaseStreamIDStmt *sql.Stmt
selectStreamIDStmt *sql.Stmt
}
func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
@ -47,48 +44,29 @@ func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
if s.increaseStreamIDStmt, err = db.Prepare(increaseStreamIDStmt); err != nil {
return
}
if s.selectStreamIDStmt, err = db.Prepare(selectStreamIDStmt); err != nil {
return
}
return
}
func (s *streamIDStatements) nextPDUID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt)
if _, err = increaseStmt.ExecContext(ctx, "global"); err != nil {
return
}
err = selectStmt.QueryRowContext(ctx, "global").Scan(&pos)
err = increaseStmt.QueryRowContext(ctx, "global").Scan(&pos)
return
}
func (s *streamIDStatements) nextReceiptID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt)
if _, err = increaseStmt.ExecContext(ctx, "receipt"); err != nil {
return
}
err = selectStmt.QueryRowContext(ctx, "receipt").Scan(&pos)
err = increaseStmt.QueryRowContext(ctx, "receipt").Scan(&pos)
return
}
func (s *streamIDStatements) nextInviteID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt)
if _, err = increaseStmt.ExecContext(ctx, "invite"); err != nil {
return
}
err = selectStmt.QueryRowContext(ctx, "invite").Scan(&pos)
err = increaseStmt.QueryRowContext(ctx, "invite").Scan(&pos)
return
}
func (s *streamIDStatements) nextAccountDataID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt)
if _, err = increaseStmt.ExecContext(ctx, "accountdata"); err != nil {
return
}
err = selectStmt.QueryRowContext(ctx, "accountdata").Scan(&pos)
err = increaseStmt.QueryRowContext(ctx, "accountdata").Scan(&pos)
return
}