Use returned ID from INSERT in create filter (#297)

This commit is contained in:
Erik Johnston 2017-10-10 15:17:29 +01:00 committed by GitHub
parent f6bda82366
commit c0271c2462
2 changed files with 5 additions and 24 deletions

View file

@ -15,8 +15,8 @@
package accounts
import (
"database/sql"
"context"
"database/sql"
)
const filterSchema = `
@ -41,13 +41,9 @@ const selectFilterSQL = "" +
const insertFilterSQL = "" +
"INSERT INTO account_filter (filter, id, localpart) VALUES ($1, DEFAULT, $2) RETURNING id"
const findMaxIDSQL = "" +
"SELECT MAX(id) FROM account_filter WHERE localpart = $1"
type filterStatements struct {
selectFilterStmt *sql.Stmt
insertFilterStmt *sql.Stmt
findMaxIDStmt *sql.Stmt
}
func (s *filterStatements) prepare(db *sql.DB) (err error) {
@ -61,10 +57,6 @@ func (s *filterStatements) prepare(db *sql.DB) (err error) {
if s.insertFilterStmt, err = db.Prepare(insertFilterSQL); err != nil {
return
}
if s.findMaxIDStmt, err = db.Prepare(findMaxIDSQL); err != nil {
return
}
return
}
@ -77,14 +69,7 @@ func (s *filterStatements) selectFilter(
func (s *filterStatements) insertFilter(
ctx context.Context, filter string, localpart string,
) (err error) {
_, err = s.insertFilterStmt.ExecContext(ctx, filter, localpart)
return
}
func (s *filterStatements) findMaxID(
ctx context.Context, localpart string,
) (id string, err error) {
err = s.findMaxIDStmt.QueryRowContext(ctx, localpart).Scan(&id)
) (pos string, err error) {
err = s.insertFilterStmt.QueryRowContext(ctx, filter, localpart).Scan(&pos)
return
}

View file

@ -29,7 +29,7 @@ import (
// Database represents an account database
type Database struct {
db *sql.DB
db *sql.DB
common.PartitionOffsetStatements
accounts accountsStatements
profiles profilesStatements
@ -333,11 +333,7 @@ func (d *Database) GetFilter(
func (d *Database) PutFilter(
ctx context.Context, localpart, filter string,
) (string, error) {
err := d.filter.insertFilter(ctx, filter, localpart)
if err != nil {
return "", err
}
return d.filter.findMaxID(ctx, localpart)
return d.filter.insertFilter(ctx, filter, localpart)
}
// CheckAccountAvailability checks if the username/localpart is already present in the database.