From c7430ec403b133adea0336424f1ebebd539dce6b Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Mon, 4 Jun 2018 16:00:04 +0100 Subject: [PATCH] Prevent sql scanning into nil value in accounts_table Signed-off-by: Andrew Morgan --- .../auth/storage/accounts/accounts_table.go | 27 ++++++++++++++----- 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/accounts_table.go b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/accounts_table.go index aaf6af39..7b7c997a 100644 --- a/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/accounts_table.go +++ b/src/github.com/matrix-org/dendrite/clientapi/auth/storage/accounts/accounts_table.go @@ -22,6 +22,8 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/gomatrixserverlib" + + log "github.com/sirupsen/logrus" ) const accountsSchema = ` @@ -121,14 +123,27 @@ func (s *accountsStatements) selectPasswordHash( func (s *accountsStatements) selectAccountByLocalpart( ctx context.Context, localpart string, -) (acc *authtypes.Account, err error) { +) (*authtypes.Account, error) { + var localpartPtr, appserviceIDPtr sql.NullString + var acc authtypes.Account + stmt := s.selectAccountByLocalpartStmt - err = stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &acc.AppServiceID) - if err == nil { - acc.UserID = userutil.MakeUserID(localpart, s.serverName) - acc.ServerName = s.serverName + err := stmt.QueryRowContext(ctx, localpart).Scan(&localpartPtr, &appserviceIDPtr) + if err != nil { + log.WithError(err).Error("Unable to retrieve user from the db") + return nil, err } - return + if appserviceIDPtr.Valid { + acc.AppServiceID = appserviceIDPtr.String + } + if localpartPtr.Valid { + acc.Localpart = localpartPtr.String + } + + acc.UserID = userutil.MakeUserID(localpart, s.serverName) + acc.ServerName = s.serverName + + return &acc, nil } func (s *accountsStatements) selectNewNumericLocalpart(