From f5cf2418776f0117455eb4156eabdc6dbcf9c1ee Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 2 Mar 2021 10:43:25 +0000 Subject: [PATCH] Fix user registration bug (#1777) --- userapi/internal/api.go | 2 +- userapi/storage/accounts/postgres/storage.go | 12 +++++---- .../storage/accounts/sqlite3/constraint.go | 27 ------------------- userapi/storage/accounts/sqlite3/storage.go | 14 +++++----- 4 files changed, 15 insertions(+), 40 deletions(-) delete mode 100644 userapi/storage/accounts/sqlite3/constraint.go diff --git a/userapi/internal/api.go b/userapi/internal/api.go index cf588a40..d8af5433 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -87,7 +87,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P ServerName: a.ServerName, UserID: fmt.Sprintf("@%s:%s", req.Localpart, a.ServerName), } - return nil + return err } if err = a.AccountDB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil { diff --git a/userapi/storage/accounts/postgres/storage.go b/userapi/storage/accounts/postgres/storage.go index 870756d8..e6adbfd8 100644 --- a/userapi/storage/accounts/postgres/storage.go +++ b/userapi/storage/accounts/postgres/storage.go @@ -170,8 +170,8 @@ func (d *Database) CreateAccount( func (d *Database) createAccount( ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, ) (*api.Account, error) { + var account *api.Account var err error - // Generate a password hash if this is not a password-less user hash := "" if plaintextPassword != "" { @@ -180,14 +180,16 @@ func (d *Database) createAccount( return nil, err } } - if err := d.profiles.insertProfile(ctx, txn, localpart); err != nil { + if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID); err != nil { if sqlutil.IsUniqueConstraintViolationErr(err) { return nil, sqlutil.ErrUserExists } return nil, err } - - if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{ + if err = d.profiles.insertProfile(ctx, txn, localpart); err != nil { + return nil, err + } + if err = d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{ "global": { "content": [], "override": [], @@ -198,7 +200,7 @@ func (d *Database) createAccount( }`)); err != nil { return nil, err } - return d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID) + return account, nil } // SaveAccountData saves new account data for a given user and a given room. diff --git a/userapi/storage/accounts/sqlite3/constraint.go b/userapi/storage/accounts/sqlite3/constraint.go deleted file mode 100644 index 32f96c8e..00000000 --- a/userapi/storage/accounts/sqlite3/constraint.go +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright 2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build !wasm - -package sqlite3 - -import ( - "errors" - - "github.com/mattn/go-sqlite3" -) - -func isConstraintError(err error) bool { - return errors.Is(err, sqlite3.ErrConstraint) -} diff --git a/userapi/storage/accounts/sqlite3/storage.go b/userapi/storage/accounts/sqlite3/storage.go index 92c1c669..747be34f 100644 --- a/userapi/storage/accounts/sqlite3/storage.go +++ b/userapi/storage/accounts/sqlite3/storage.go @@ -204,6 +204,7 @@ func (d *Database) createAccount( ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, ) (*api.Account, error) { var err error + var account *api.Account // Generate a password hash if this is not a password-less user hash := "" if plaintextPassword != "" { @@ -212,14 +213,13 @@ func (d *Database) createAccount( return nil, err } } - if err := d.profiles.insertProfile(ctx, txn, localpart); err != nil { - if isConstraintError(err) { - return nil, sqlutil.ErrUserExists - } + if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID); err != nil { + return nil, sqlutil.ErrUserExists + } + if err = d.profiles.insertProfile(ctx, txn, localpart); err != nil { return nil, err } - - if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{ + if err = d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{ "global": { "content": [], "override": [], @@ -230,7 +230,7 @@ func (d *Database) createAccount( }`)); err != nil { return nil, err } - return d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID) + return account, nil } // SaveAccountData saves new account data for a given user and a given room.