Set display_name and/or avatar_url for server notices (#2820)

This should fix #2815 by making sure we actually set the `display_name`
and/or `avatar_url` and create the needed membership event.
To avoid creating a new membership event when starting Dendrite,
`SetAvatarURL` and `SetDisplayName` now return a `Changed` value, which
also makes the regular endpoints idempotent.
This commit is contained in:
Till 2022-10-21 10:48:25 +02:00 committed by GitHub
parent 40cfb9a4ea
commit e57b301722
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 191 additions and 132 deletions

View file

@ -96,7 +96,7 @@ type ClientUserAPI interface {
PerformAccountDeactivation(ctx context.Context, req *PerformAccountDeactivationRequest, res *PerformAccountDeactivationResponse) error
PerformOpenIDTokenCreation(ctx context.Context, req *PerformOpenIDTokenCreationRequest, res *PerformOpenIDTokenCreationResponse) error
SetAvatarURL(ctx context.Context, req *PerformSetAvatarURLRequest, res *PerformSetAvatarURLResponse) error
SetDisplayName(ctx context.Context, req *PerformUpdateDisplayNameRequest, res *struct{}) error
SetDisplayName(ctx context.Context, req *PerformUpdateDisplayNameRequest, res *PerformUpdateDisplayNameResponse) error
QueryNotifications(ctx context.Context, req *QueryNotificationsRequest, res *QueryNotificationsResponse) error
InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error
PerformKeyBackup(ctx context.Context, req *PerformKeyBackupRequest, res *PerformKeyBackupResponse) error
@ -579,7 +579,10 @@ type Notification struct {
type PerformSetAvatarURLRequest struct {
Localpart, AvatarURL string
}
type PerformSetAvatarURLResponse struct{}
type PerformSetAvatarURLResponse struct {
Profile *authtypes.Profile `json:"profile"`
Changed bool `json:"changed"`
}
type QueryNumericLocalpartResponse struct {
ID int64
@ -606,6 +609,11 @@ type PerformUpdateDisplayNameRequest struct {
Localpart, DisplayName string
}
type PerformUpdateDisplayNameResponse struct {
Profile *authtypes.Profile `json:"profile"`
Changed bool `json:"changed"`
}
type QueryLocalpartForThreePIDRequest struct {
ThreePID, Medium string
}

View file

@ -168,7 +168,7 @@ func (t *UserInternalAPITrace) QueryAccountAvailability(ctx context.Context, req
return err
}
func (t *UserInternalAPITrace) SetDisplayName(ctx context.Context, req *PerformUpdateDisplayNameRequest, res *struct{}) error {
func (t *UserInternalAPITrace) SetDisplayName(ctx context.Context, req *PerformUpdateDisplayNameRequest, res *PerformUpdateDisplayNameResponse) error {
err := t.Impl.SetDisplayName(ctx, req, res)
util.GetLogger(ctx).Infof("SetDisplayName req=%+v res=%+v", js(req), js(res))
return err

View file

@ -170,7 +170,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
return nil
}
if err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil {
if _, _, err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil {
return err
}
@ -813,7 +813,10 @@ func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPush
}
func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, req *api.PerformSetAvatarURLRequest, res *api.PerformSetAvatarURLResponse) error {
return a.DB.SetAvatarURL(ctx, req.Localpart, req.AvatarURL)
profile, changed, err := a.DB.SetAvatarURL(ctx, req.Localpart, req.AvatarURL)
res.Profile = profile
res.Changed = changed
return err
}
func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.QueryNumericLocalpartResponse) error {
@ -847,8 +850,11 @@ func (a *UserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.Q
}
}
func (a *UserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, _ *struct{}) error {
return a.DB.SetDisplayName(ctx, req.Localpart, req.DisplayName)
func (a *UserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, res *api.PerformUpdateDisplayNameResponse) error {
profile, changed, err := a.DB.SetDisplayName(ctx, req.Localpart, req.DisplayName)
res.Profile = profile
res.Changed = changed
return err
}
func (a *UserInternalAPI) QueryLocalpartForThreePID(ctx context.Context, req *api.QueryLocalpartForThreePIDRequest, res *api.QueryLocalpartForThreePIDResponse) error {

View file

@ -388,7 +388,7 @@ func (h *httpUserInternalAPI) QueryAccountByPassword(
func (h *httpUserInternalAPI) SetDisplayName(
ctx context.Context,
request *api.PerformUpdateDisplayNameRequest,
response *struct{},
response *api.PerformUpdateDisplayNameResponse,
) error {
return httputil.CallInternalRPCAPI(
"SetDisplayName", h.apiURL+PerformSetDisplayNamePath,

View file

@ -29,8 +29,8 @@ import (
type Profile interface {
GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error
SetDisplayName(ctx context.Context, localpart string, displayName string) error
SetAvatarURL(ctx context.Context, localpart string, avatarURL string) (*authtypes.Profile, bool, error)
SetDisplayName(ctx context.Context, localpart string, displayName string) (*authtypes.Profile, bool, error)
}
type Account interface {

View file

@ -44,10 +44,18 @@ const selectProfileByLocalpartSQL = "" +
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1"
const setAvatarURLSQL = "" +
"UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2"
"UPDATE userapi_profiles AS new" +
" SET avatar_url = $1" +
" FROM userapi_profiles AS old" +
" WHERE new.localpart = $2" +
" RETURNING new.display_name, old.avatar_url <> new.avatar_url"
const setDisplayNameSQL = "" +
"UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2"
"UPDATE userapi_profiles AS new" +
" SET display_name = $1" +
" FROM userapi_profiles AS old" +
" WHERE new.localpart = $2" +
" RETURNING new.avatar_url, old.display_name <> new.display_name"
const selectProfilesBySearchSQL = "" +
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
@ -100,16 +108,28 @@ func (s *profilesStatements) SelectProfileByLocalpart(
func (s *profilesStatements) SetAvatarURL(
ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
) (err error) {
_, err = s.setAvatarURLStmt.ExecContext(ctx, avatarURL, localpart)
return
) (*authtypes.Profile, bool, error) {
profile := &authtypes.Profile{
Localpart: localpart,
AvatarURL: avatarURL,
}
var changed bool
stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
err := stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName, &changed)
return profile, changed, err
}
func (s *profilesStatements) SetDisplayName(
ctx context.Context, txn *sql.Tx, localpart string, displayName string,
) (err error) {
_, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart)
return
) (*authtypes.Profile, bool, error) {
profile := &authtypes.Profile{
Localpart: localpart,
DisplayName: displayName,
}
var changed bool
stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
err := stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL, &changed)
return profile, changed, err
}
func (s *profilesStatements) SelectProfilesBySearch(

View file

@ -96,20 +96,24 @@ func (d *Database) GetProfileByLocalpart(
// localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetAvatarURL(
ctx context.Context, localpart string, avatarURL string,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Profiles.SetAvatarURL(ctx, txn, localpart, avatarURL)
) (profile *authtypes.Profile, changed bool, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
profile, changed, err = d.Profiles.SetAvatarURL(ctx, txn, localpart, avatarURL)
return err
})
return
}
// SetDisplayName updates the display name of the profile associated with the given
// localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetDisplayName(
ctx context.Context, localpart string, displayName string,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Profiles.SetDisplayName(ctx, txn, localpart, displayName)
) (profile *authtypes.Profile, changed bool, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
profile, changed, err = d.Profiles.SetDisplayName(ctx, txn, localpart, displayName)
return err
})
return
}
// SetPassword sets the account password to the given hash.

View file

@ -44,10 +44,12 @@ const selectProfileByLocalpartSQL = "" +
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1"
const setAvatarURLSQL = "" +
"UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2"
"UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2" +
" RETURNING display_name"
const setDisplayNameSQL = "" +
"UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2"
"UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2" +
" RETURNING avatar_url"
const selectProfilesBySearchSQL = "" +
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
@ -102,18 +104,40 @@ func (s *profilesStatements) SelectProfileByLocalpart(
func (s *profilesStatements) SetAvatarURL(
ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
) (err error) {
) (*authtypes.Profile, bool, error) {
profile := &authtypes.Profile{
Localpart: localpart,
AvatarURL: avatarURL,
}
old, err := s.SelectProfileByLocalpart(ctx, localpart)
if err != nil {
return old, false, err
}
if old.AvatarURL == avatarURL {
return old, false, nil
}
stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
_, err = stmt.ExecContext(ctx, avatarURL, localpart)
return
err = stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName)
return profile, true, err
}
func (s *profilesStatements) SetDisplayName(
ctx context.Context, txn *sql.Tx, localpart string, displayName string,
) (err error) {
) (*authtypes.Profile, bool, error) {
profile := &authtypes.Profile{
Localpart: localpart,
DisplayName: displayName,
}
old, err := s.SelectProfileByLocalpart(ctx, localpart)
if err != nil {
return old, false, err
}
if old.DisplayName == displayName {
return old, false, nil
}
stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
_, err = stmt.ExecContext(ctx, displayName, localpart)
return
err = stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL)
return profile, true, err
}
func (s *profilesStatements) SelectProfilesBySearch(

View file

@ -382,15 +382,23 @@ func Test_Profile(t *testing.T) {
// set avatar & displayname
wantProfile.DisplayName = "Alice"
wantProfile.AvatarURL = "mxc://aliceAvatar"
err = db.SetDisplayName(ctx, aliceLocalpart, "Alice")
assert.NoError(t, err, "unable to set displayname")
err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar")
assert.NoError(t, err, "unable to set avatar url")
// verify profile
gotProfile, err = db.GetProfileByLocalpart(ctx, aliceLocalpart)
assert.NoError(t, err, "unable to get profile by localpart")
gotProfile, changed, err := db.SetDisplayName(ctx, aliceLocalpart, "Alice")
assert.Equal(t, wantProfile, gotProfile)
assert.NoError(t, err, "unable to set displayname")
assert.True(t, changed)
wantProfile.AvatarURL = "mxc://aliceAvatar"
gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar")
assert.NoError(t, err, "unable to set avatar url")
assert.Equal(t, wantProfile, gotProfile)
assert.True(t, changed)
// Setting the same avatar again doesn't change anything
wantProfile.AvatarURL = "mxc://aliceAvatar"
gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar")
assert.NoError(t, err, "unable to set avatar url")
assert.Equal(t, wantProfile, gotProfile)
assert.False(t, changed)
// search profiles
searchRes, err := db.SearchProfiles(ctx, "Alice", 2)

View file

@ -84,8 +84,8 @@ type OpenIDTable interface {
type ProfileTable interface {
InsertProfile(ctx context.Context, txn *sql.Tx, localpart string) error
SelectProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, avatarURL string) (err error)
SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, displayName string) (err error)
SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, avatarURL string) (*authtypes.Profile, bool, error)
SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, displayName string) (*authtypes.Profile, bool, error)
SelectProfilesBySearch(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
}

View file

@ -23,13 +23,14 @@ import (
"time"
"github.com/gorilla/mux"
"github.com/matrix-org/gomatrixserverlib"
"golang.org/x/crypto/bcrypt"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
"github.com/matrix-org/dendrite/userapi"
"github.com/matrix-org/dendrite/userapi/inthttp"
"github.com/matrix-org/gomatrixserverlib"
"golang.org/x/crypto/bcrypt"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
@ -83,10 +84,10 @@ func TestQueryProfile(t *testing.T) {
if err != nil {
t.Fatalf("failed to make account: %s", err)
}
if err := accountDB.SetAvatarURL(context.TODO(), "alice", aliceAvatarURL); err != nil {
if _, _, err := accountDB.SetAvatarURL(context.TODO(), "alice", aliceAvatarURL); err != nil {
t.Fatalf("failed to set avatar url: %s", err)
}
if err := accountDB.SetDisplayName(context.TODO(), "alice", aliceDisplayName); err != nil {
if _, _, err := accountDB.SetDisplayName(context.TODO(), "alice", aliceDisplayName); err != nil {
t.Fatalf("failed to set display name: %s", err)
}