Refactor user API storage (#2202)

* Refactor User API database

* Fix migration bugs
This commit is contained in:
Neil Alexander 2022-02-18 13:51:59 +00:00 committed by GitHub
parent 9bd5e414c9
commit 9f4a39e8e0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 1165 additions and 1671 deletions

View file

@ -19,6 +19,7 @@ import (
"database/sql"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
)
@ -58,12 +59,13 @@ type threepidStatements struct {
deleteThreePIDStmt *sql.Stmt
}
func (s *threepidStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(threepidSchema)
func NewPostgresThreePIDTable(db *sql.DB) (tables.ThreePIDTable, error) {
s := &threepidStatements{}
_, err := db.Exec(threepidSchema)
if err != nil {
return
return nil, err
}
return sqlutil.StatementList{
return s, sqlutil.StatementList{
{&s.selectLocalpartForThreePIDStmt, selectLocalpartForThreePIDSQL},
{&s.selectThreePIDsForLocalpartStmt, selectThreePIDsForLocalpartSQL},
{&s.insertThreePIDStmt, insertThreePIDSQL},
@ -71,7 +73,7 @@ func (s *threepidStatements) prepare(db *sql.DB) (err error) {
}.Prepare(db)
}
func (s *threepidStatements) selectLocalpartForThreePID(
func (s *threepidStatements) SelectLocalpartForThreePID(
ctx context.Context, txn *sql.Tx, threepid string, medium string,
) (localpart string, err error) {
stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt)
@ -82,7 +84,7 @@ func (s *threepidStatements) selectLocalpartForThreePID(
return
}
func (s *threepidStatements) selectThreePIDsForLocalpart(
func (s *threepidStatements) SelectThreePIDsForLocalpart(
ctx context.Context, localpart string,
) (threepids []authtypes.ThreePID, err error) {
rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart)
@ -106,7 +108,7 @@ func (s *threepidStatements) selectThreePIDsForLocalpart(
return
}
func (s *threepidStatements) insertThreePID(
func (s *threepidStatements) InsertThreePID(
ctx context.Context, txn *sql.Tx, threepid, medium, localpart string,
) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt)
@ -114,8 +116,9 @@ func (s *threepidStatements) insertThreePID(
return
}
func (s *threepidStatements) deleteThreePID(
ctx context.Context, threepid string, medium string) (err error) {
_, err = s.deleteThreePIDStmt.ExecContext(ctx, threepid, medium)
func (s *threepidStatements) DeleteThreePID(
ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error) {
stmt := sqlutil.TxStmt(txn, s.deleteThreePIDStmt)
_, err = stmt.ExecContext(ctx, threepid, medium)
return
}