Use returned ID from INSERT in create filter (#297)

This commit is contained in:
Erik Johnston 2017-10-10 15:17:29 +01:00 committed by GitHub
parent f6bda82366
commit c0271c2462
2 changed files with 5 additions and 24 deletions

View file

@ -15,8 +15,8 @@
package accounts package accounts
import ( import (
"database/sql"
"context" "context"
"database/sql"
) )
const filterSchema = ` const filterSchema = `
@ -41,13 +41,9 @@ const selectFilterSQL = "" +
const insertFilterSQL = "" + const insertFilterSQL = "" +
"INSERT INTO account_filter (filter, id, localpart) VALUES ($1, DEFAULT, $2) RETURNING id" "INSERT INTO account_filter (filter, id, localpart) VALUES ($1, DEFAULT, $2) RETURNING id"
const findMaxIDSQL = "" +
"SELECT MAX(id) FROM account_filter WHERE localpart = $1"
type filterStatements struct { type filterStatements struct {
selectFilterStmt *sql.Stmt selectFilterStmt *sql.Stmt
insertFilterStmt *sql.Stmt insertFilterStmt *sql.Stmt
findMaxIDStmt *sql.Stmt
} }
func (s *filterStatements) prepare(db *sql.DB) (err error) { func (s *filterStatements) prepare(db *sql.DB) (err error) {
@ -61,10 +57,6 @@ func (s *filterStatements) prepare(db *sql.DB) (err error) {
if s.insertFilterStmt, err = db.Prepare(insertFilterSQL); err != nil { if s.insertFilterStmt, err = db.Prepare(insertFilterSQL); err != nil {
return return
} }
if s.findMaxIDStmt, err = db.Prepare(findMaxIDSQL); err != nil {
return
}
return return
} }
@ -77,14 +69,7 @@ func (s *filterStatements) selectFilter(
func (s *filterStatements) insertFilter( func (s *filterStatements) insertFilter(
ctx context.Context, filter string, localpart string, ctx context.Context, filter string, localpart string,
) (err error) { ) (pos string, err error) {
_, err = s.insertFilterStmt.ExecContext(ctx, filter, localpart) err = s.insertFilterStmt.QueryRowContext(ctx, filter, localpart).Scan(&pos)
return
}
func (s *filterStatements) findMaxID(
ctx context.Context, localpart string,
) (id string, err error) {
err = s.findMaxIDStmt.QueryRowContext(ctx, localpart).Scan(&id)
return return
} }

View file

@ -29,7 +29,7 @@ import (
// Database represents an account database // Database represents an account database
type Database struct { type Database struct {
db *sql.DB db *sql.DB
common.PartitionOffsetStatements common.PartitionOffsetStatements
accounts accountsStatements accounts accountsStatements
profiles profilesStatements profiles profilesStatements
@ -333,11 +333,7 @@ func (d *Database) GetFilter(
func (d *Database) PutFilter( func (d *Database) PutFilter(
ctx context.Context, localpart, filter string, ctx context.Context, localpart, filter string,
) (string, error) { ) (string, error) {
err := d.filter.insertFilter(ctx, filter, localpart) return d.filter.insertFilter(ctx, filter, localpart)
if err != nil {
return "", err
}
return d.filter.findMaxID(ctx, localpart)
} }
// CheckAccountAvailability checks if the username/localpart is already present in the database. // CheckAccountAvailability checks if the username/localpart is already present in the database.