Merge branch 'master' into matthew/peeking

This commit is contained in:
Neil Alexander 2020-09-04 16:08:12 +01:00
commit 64fe2741c0
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
33 changed files with 289 additions and 101 deletions

View file

@ -130,7 +130,7 @@ func (m *DendriteMonolith) Start() {
asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI)
stateAPI := currentstateserver.NewInternalAPI(&base.Cfg.CurrentStateServer, base.KafkaConsumer)
fsAPI := federationsender.NewInternalAPI(
base, federation, rsAPI, stateAPI, keyRing,
base, federation, rsAPI, keyRing,
)
ygg.SetSessionFunc(func(address string) {

View file

@ -5,6 +5,7 @@ type LoginType string
// The relevant login types implemented in Dendrite
const (
LoginTypePassword = "m.login.password"
LoginTypeDummy = "m.login.dummy"
LoginTypeSharedSecret = "org.matrix.login.shared_secret"
LoginTypeRecaptcha = "m.login.recaptcha"

View file

@ -19,7 +19,6 @@ import (
"net/http"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
currentstateAPI "github.com/matrix-org/dendrite/currentstateserver/api"
"github.com/matrix-org/dendrite/internal/config"
"github.com/matrix-org/dendrite/roomserver/api"
userapi "github.com/matrix-org/dendrite/userapi/api"
@ -94,10 +93,10 @@ func GetMemberships(
func GetJoinedRooms(
req *http.Request,
device *userapi.Device,
stateAPI currentstateAPI.CurrentStateInternalAPI,
rsAPI api.RoomserverInternalAPI,
) util.JSONResponse {
var res currentstateAPI.QueryRoomsForUserResponse
err := stateAPI.QueryRoomsForUser(req.Context(), &currentstateAPI.QueryRoomsForUserRequest{
var res api.QueryRoomsForUserResponse
err := rsAPI.QueryRoomsForUser(req.Context(), &api.QueryRoomsForUserRequest{
UserID: device.UserID,
WantMembership: "join",
}, &res)

View file

@ -0,0 +1,127 @@
package routing
import (
"net/http"
"github.com/matrix-org/dendrite/clientapi/auth"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
"github.com/matrix-org/dendrite/internal/config"
"github.com/matrix-org/dendrite/userapi/api"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
type newPasswordRequest struct {
NewPassword string `json:"new_password"`
LogoutDevices bool `json:"logout_devices"`
Auth newPasswordAuth `json:"auth"`
}
type newPasswordAuth struct {
Type string `json:"type"`
Session string `json:"session"`
auth.PasswordRequest
}
func Password(
req *http.Request,
userAPI userapi.UserInternalAPI,
accountDB accounts.Database,
device *api.Device,
cfg *config.ClientAPI,
) util.JSONResponse {
// Check that the existing password is right.
var r newPasswordRequest
r.LogoutDevices = true
// Unmarshal the request.
resErr := httputil.UnmarshalJSONRequest(req, &r)
if resErr != nil {
return *resErr
}
// Retrieve or generate the sessionID
sessionID := r.Auth.Session
if sessionID == "" {
// Generate a new, random session ID
sessionID = util.RandomString(sessionIDLength)
}
// Require password auth to change the password.
if r.Auth.Type != authtypes.LoginTypePassword {
return util.JSONResponse{
Code: http.StatusUnauthorized,
JSON: newUserInteractiveResponse(
sessionID,
[]authtypes.Flow{
{
Stages: []authtypes.LoginType{authtypes.LoginTypePassword},
},
},
nil,
),
}
}
// Check if the existing password is correct.
typePassword := auth.LoginTypePassword{
GetAccountByPassword: accountDB.GetAccountByPassword,
Config: cfg,
}
if _, authErr := typePassword.Login(req.Context(), &r.Auth.PasswordRequest); authErr != nil {
return *authErr
}
AddCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
// Check the new password strength.
if resErr = validatePassword(r.NewPassword); resErr != nil {
return *resErr
}
// Get the local part.
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
return jsonerror.InternalServerError()
}
// Ask the user API to perform the password change.
passwordReq := &userapi.PerformPasswordUpdateRequest{
Localpart: localpart,
Password: r.NewPassword,
}
passwordRes := &userapi.PerformPasswordUpdateResponse{}
if err := userAPI.PerformPasswordUpdate(req.Context(), passwordReq, passwordRes); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("PerformPasswordUpdate failed")
return jsonerror.InternalServerError()
}
if !passwordRes.PasswordUpdated {
util.GetLogger(req.Context()).Error("Expected password to have been updated but wasn't")
return jsonerror.InternalServerError()
}
// If the request asks us to log out all other devices then
// ask the user API to do that.
if r.LogoutDevices {
logoutReq := &userapi.PerformDeviceDeletionRequest{
UserID: device.UserID,
DeviceIDs: nil,
ExceptDeviceID: device.ID,
}
logoutRes := &userapi.PerformDeviceDeletionResponse{}
if err := userAPI.PerformDeviceDeletion(req.Context(), logoutReq, logoutRes); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("PerformDeviceDeletion failed")
return jsonerror.InternalServerError()
}
}
// Return a success code.
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct{}{},
}
}

View file

@ -140,8 +140,8 @@ func SetAvatarURL(
return jsonerror.InternalServerError()
}
var res currentstateAPI.QueryRoomsForUserResponse
err = stateAPI.QueryRoomsForUser(req.Context(), &currentstateAPI.QueryRoomsForUserRequest{
var res api.QueryRoomsForUserResponse
err = rsAPI.QueryRoomsForUser(req.Context(), &api.QueryRoomsForUserRequest{
UserID: device.UserID,
WantMembership: "join",
}, &res)
@ -258,8 +258,8 @@ func SetDisplayName(
return jsonerror.InternalServerError()
}
var res currentstateAPI.QueryRoomsForUserResponse
err = stateAPI.QueryRoomsForUser(req.Context(), &currentstateAPI.QueryRoomsForUserRequest{
var res api.QueryRoomsForUserResponse
err = rsAPI.QueryRoomsForUser(req.Context(), &api.QueryRoomsForUserRequest{
UserID: device.UserID,
WantMembership: "join",
}, &res)

View file

@ -118,7 +118,7 @@ func Setup(
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/joined_rooms",
httputil.MakeAuthAPI("joined_rooms", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
return GetJoinedRooms(req, device, stateAPI)
return GetJoinedRooms(req, device, rsAPI)
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/join",
@ -428,6 +428,15 @@ func Setup(
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/account/password",
httputil.MakeAuthAPI("password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.rateLimit(req); r != nil {
return *r
}
return Password(req, userAPI, accountDB, device, cfg)
}),
).Methods(http.MethodPost, http.MethodOptions)
// Stub endpoints required by Riot
r0mux.Handle("/login",

View file

@ -162,7 +162,7 @@ func main() {
)
asAPI := appservice.NewInternalAPI(&base.Base, userAPI, rsAPI)
fsAPI := federationsender.NewInternalAPI(
&base.Base, federation, rsAPI, stateAPI, keyRing,
&base.Base, federation, rsAPI, keyRing,
)
rsAPI.SetFederationSenderAPI(fsAPI)
provider := newPublicRoomsProvider(base.LibP2PPubsub, rsAPI, stateAPI)

View file

@ -115,7 +115,7 @@ func main() {
asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI)
stateAPI := currentstateserver.NewInternalAPI(&base.Cfg.CurrentStateServer, base.KafkaConsumer)
fsAPI := federationsender.NewInternalAPI(
base, federation, rsAPI, stateAPI, keyRing,
base, federation, rsAPI, keyRing,
)
ygg.SetSessionFunc(func(address string) {

View file

@ -31,7 +31,7 @@ func main() {
rsAPI := base.RoomserverHTTPClient()
fsAPI := federationsender.NewInternalAPI(
base, federation, rsAPI, base.CurrentStateAPIClient(), keyRing,
base, federation, rsAPI, keyRing,
)
federationsender.AddInternalRoutes(base.InternalAPIMux, fsAPI)

View file

@ -98,7 +98,7 @@ func main() {
stateAPI := currentstateserver.NewInternalAPI(&base.Cfg.CurrentStateServer, base.KafkaConsumer)
fsAPI := federationsender.NewInternalAPI(
base, federation, rsAPI, stateAPI, keyRing,
base, federation, rsAPI, keyRing,
)
if base.UseHTTPAPIs {
federationsender.AddInternalRoutes(base.InternalAPIMux, fsAPI)

View file

@ -210,7 +210,7 @@ func main() {
asQuery := appservice.NewInternalAPI(
base, userAPI, rsAPI,
)
fedSenderAPI := federationsender.NewInternalAPI(base, federation, rsAPI, stateAPI, &keyRing)
fedSenderAPI := federationsender.NewInternalAPI(base, federation, rsAPI, &keyRing)
rsAPI.SetFederationSenderAPI(fedSenderAPI)
p2pPublicRoomProvider := NewLibP2PPublicRoomsProvider(node, fedSenderAPI, federation)

View file

@ -21,22 +21,10 @@ import (
)
type CurrentStateInternalAPI interface {
// QueryRoomsForUser retrieves a list of room IDs matching the given query.
QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error
// QueryBulkStateContent does a bulk query for state event content in the given rooms.
QueryBulkStateContent(ctx context.Context, req *QueryBulkStateContentRequest, res *QueryBulkStateContentResponse) error
}
type QueryRoomsForUserRequest struct {
UserID string
// The desired membership of the user. If this is the empty string then no rooms are returned.
WantMembership string
}
type QueryRoomsForUserResponse struct {
RoomIDs []string
}
type QueryBulkStateContentRequest struct {
// Returns state events in these rooms
RoomIDs []string

View file

@ -26,15 +26,6 @@ type CurrentStateInternalAPI struct {
DB storage.Database
}
func (a *CurrentStateInternalAPI) QueryRoomsForUser(ctx context.Context, req *api.QueryRoomsForUserRequest, res *api.QueryRoomsForUserResponse) error {
roomIDs, err := a.DB.GetRoomsByMembership(ctx, req.UserID, req.WantMembership)
if err != nil {
return err
}
res.RoomIDs = roomIDs
return nil
}
func (a *CurrentStateInternalAPI) QueryBulkStateContent(ctx context.Context, req *api.QueryBulkStateContentRequest, res *api.QueryBulkStateContentResponse) error {
events, err := a.DB.GetBulkStateContent(ctx, req.RoomIDs, req.StateTuples, req.AllowWildcards)
if err != nil {

View file

@ -50,18 +50,6 @@ type httpCurrentStateInternalAPI struct {
httpClient *http.Client
}
func (h *httpCurrentStateInternalAPI) QueryRoomsForUser(
ctx context.Context,
request *api.QueryRoomsForUserRequest,
response *api.QueryRoomsForUserResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryRoomsForUser")
defer span.Finish()
apiURL := h.apiURL + QueryRoomsForUserPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
}
func (h *httpCurrentStateInternalAPI) QueryBulkStateContent(
ctx context.Context,
request *api.QueryBulkStateContentRequest,

View file

@ -25,19 +25,6 @@ import (
)
func AddRoutes(internalAPIMux *mux.Router, intAPI api.CurrentStateInternalAPI) {
internalAPIMux.Handle(QueryRoomsForUserPath,
httputil.MakeInternalAPI("queryRoomsForUser", func(req *http.Request) util.JSONResponse {
request := api.QueryRoomsForUserRequest{}
response := api.QueryRoomsForUserResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := intAPI.QueryRoomsForUser(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(QueryBulkStateContentPath,
httputil.MakeInternalAPI("queryBulkStateContent", func(req *http.Request) util.JSONResponse {
request := api.QueryBulkStateContentRequest{}

View file

@ -20,12 +20,12 @@ import (
"fmt"
"github.com/Shopify/sarama"
stateapi "github.com/matrix-org/dendrite/currentstateserver/api"
"github.com/matrix-org/dendrite/federationsender/queue"
"github.com/matrix-org/dendrite/federationsender/storage"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/config"
"github.com/matrix-org/dendrite/keyserver/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus"
)
@ -36,7 +36,7 @@ type KeyChangeConsumer struct {
db storage.Database
queues *queue.OutgoingQueues
serverName gomatrixserverlib.ServerName
stateAPI stateapi.CurrentStateInternalAPI
rsAPI roomserverAPI.RoomserverInternalAPI
}
// NewKeyChangeConsumer creates a new KeyChangeConsumer. Call Start() to begin consuming from key servers.
@ -45,7 +45,7 @@ func NewKeyChangeConsumer(
kafkaConsumer sarama.Consumer,
queues *queue.OutgoingQueues,
store storage.Database,
stateAPI stateapi.CurrentStateInternalAPI,
rsAPI roomserverAPI.RoomserverInternalAPI,
) *KeyChangeConsumer {
c := &KeyChangeConsumer{
consumer: &internal.ContinualConsumer{
@ -57,7 +57,7 @@ func NewKeyChangeConsumer(
queues: queues,
db: store,
serverName: cfg.Matrix.ServerName,
stateAPI: stateAPI,
rsAPI: rsAPI,
}
c.consumer.ProcessMessage = c.onMessage
@ -92,8 +92,8 @@ func (t *KeyChangeConsumer) onMessage(msg *sarama.ConsumerMessage) error {
return nil
}
var queryRes stateapi.QueryRoomsForUserResponse
err = t.stateAPI.QueryRoomsForUser(context.Background(), &stateapi.QueryRoomsForUserRequest{
var queryRes roomserverAPI.QueryRoomsForUserResponse
err = t.rsAPI.QueryRoomsForUser(context.Background(), &roomserverAPI.QueryRoomsForUserRequest{
UserID: m.UserID,
WantMembership: "join",
}, &queryRes)

View file

@ -16,7 +16,6 @@ package federationsender
import (
"github.com/gorilla/mux"
stateapi "github.com/matrix-org/dendrite/currentstateserver/api"
"github.com/matrix-org/dendrite/federationsender/api"
"github.com/matrix-org/dendrite/federationsender/consumers"
"github.com/matrix-org/dendrite/federationsender/internal"
@ -42,7 +41,6 @@ func NewInternalAPI(
base *setup.BaseDendrite,
federation *gomatrixserverlib.FederationClient,
rsAPI roomserverAPI.RoomserverInternalAPI,
stateAPI stateapi.CurrentStateInternalAPI,
keyRing *gomatrixserverlib.KeyRing,
) api.FederationSenderInternalAPI {
cfg := &base.Cfg.FederationSender
@ -82,7 +80,7 @@ func NewInternalAPI(
logrus.WithError(err).Panic("failed to start typing server consumer")
}
keyConsumer := consumers.NewKeyChangeConsumer(
&base.Cfg.KeyServer, base.KafkaConsumer, queues, federationSenderDB, stateAPI,
&base.Cfg.KeyServer, base.KafkaConsumer, queues, federationSenderDB, rsAPI,
)
if err := keyConsumer.Start(); err != nil {
logrus.WithError(err).Panic("failed to start key server consumer")

View file

@ -112,11 +112,6 @@ type mockStateAPI struct {
rsAPI *mockRoomserverAPI
}
// QueryRoomsForUser retrieves a list of room IDs matching the given query.
func (s *mockStateAPI) QueryRoomsForUser(ctx context.Context, req *stateapi.QueryRoomsForUserRequest, res *stateapi.QueryRoomsForUserResponse) error {
return nil
}
// QueryBulkStateContent does a bulk query for state event content in the given rooms.
func (s *mockStateAPI) QueryBulkStateContent(ctx context.Context, req *stateapi.QueryBulkStateContentRequest, res *stateapi.QueryBulkStateContentResponse) error {
var res2 api.QueryBulkStateContentResponse

View file

@ -460,3 +460,8 @@ If user leaves room, remote user changes device and rejoins we see update in /sy
Can search public room list
Can get remote public room list
Asking for a remote rooms list, but supplying the local server's name, returns the local rooms list
After changing password, can't log in with old password
After changing password, can log in with new password
After changing password, existing session still works
After changing password, different sessions can optionally be kept
After changing password, a different session no longer works by default

View file

@ -26,6 +26,7 @@ import (
type UserInternalAPI interface {
InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error
PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error
PerformPasswordUpdate(ctx context.Context, req *PerformPasswordUpdateRequest, res *PerformPasswordUpdateResponse) error
PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error
PerformDeviceDeletion(ctx context.Context, req *PerformDeviceDeletionRequest, res *PerformDeviceDeletionResponse) error
PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error
@ -63,6 +64,10 @@ type PerformDeviceDeletionRequest struct {
UserID string
// The devices to delete. An empty slice means delete all devices.
DeviceIDs []string
// The requesting device ID to exclude from deletion. This is needed
// so that a password change doesn't cause that client to be logged
// out. Only specify when DeviceIDs is empty.
ExceptDeviceID string
}
type PerformDeviceDeletionResponse struct {
@ -165,6 +170,18 @@ type PerformAccountCreationResponse struct {
Account *Account
}
// PerformAccountCreationRequest is the request for PerformAccountCreation
type PerformPasswordUpdateRequest struct {
Localpart string // Required: The localpart for this account.
Password string // Required: The new password to set.
}
// PerformAccountCreationResponse is the response for PerformAccountCreation
type PerformPasswordUpdateResponse struct {
PasswordUpdated bool
Account *Account
}
// PerformDeviceCreationRequest is the request for PerformDeviceCreation
type PerformDeviceCreationRequest struct {
Localpart string

View file

@ -98,6 +98,15 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
res.Account = acc
return nil
}
func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error {
if err := a.AccountDB.SetPassword(ctx, req.Localpart, req.Password); err != nil {
return err
}
res.PasswordUpdated = true
return nil
}
func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.PerformDeviceCreationRequest, res *api.PerformDeviceCreationResponse) error {
util.GetLogger(ctx).WithFields(logrus.Fields{
"localpart": req.Localpart,
@ -126,7 +135,7 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe
deletedDeviceIDs := req.DeviceIDs
if len(req.DeviceIDs) == 0 {
var devices []api.Device
devices, err = a.DeviceDB.RemoveAllDevices(ctx, local)
devices, err = a.DeviceDB.RemoveAllDevices(ctx, local, req.ExceptDeviceID)
for _, d := range devices {
deletedDeviceIDs = append(deletedDeviceIDs, d.ID)
}

View file

@ -30,6 +30,7 @@ const (
PerformDeviceCreationPath = "/userapi/performDeviceCreation"
PerformAccountCreationPath = "/userapi/performAccountCreation"
PerformPasswordUpdatePath = "/userapi/performPasswordUpdate"
PerformDeviceDeletionPath = "/userapi/performDeviceDeletion"
PerformDeviceUpdatePath = "/userapi/performDeviceUpdate"
@ -81,6 +82,18 @@ func (h *httpUserInternalAPI) PerformAccountCreation(
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
}
func (h *httpUserInternalAPI) PerformPasswordUpdate(
ctx context.Context,
request *api.PerformPasswordUpdateRequest,
response *api.PerformPasswordUpdateResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPasswordUpdate")
defer span.Finish()
apiURL := h.apiURL + PerformPasswordUpdatePath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
}
func (h *httpUserInternalAPI) PerformDeviceCreation(
ctx context.Context,
request *api.PerformDeviceCreationRequest,

View file

@ -39,6 +39,19 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(PerformAccountCreationPath,
httputil.MakeInternalAPI("performPasswordUpdate", func(req *http.Request) util.JSONResponse {
request := api.PerformPasswordUpdateRequest{}
response := api.PerformPasswordUpdateResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := s.PerformPasswordUpdate(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(PerformDeviceCreationPath,
httputil.MakeInternalAPI("performDeviceCreation", func(req *http.Request) util.JSONResponse {
request := api.PerformDeviceCreationRequest{}

View file

@ -28,6 +28,7 @@ type Database interface {
internal.PartitionStorer
GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error)
GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
SetPassword(ctx context.Context, localpart string, plaintextPassword string) error
SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error
SetDisplayName(ctx context.Context, localpart string, displayName string) error
// CreateAccount makes a new account with the given login name and password, and creates an empty profile

View file

@ -47,6 +47,9 @@ CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1;
const insertAccountSQL = "" +
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)"
const updatePasswordSQL = "" +
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
const selectAccountByLocalpartSQL = "" +
"SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1"
@ -56,10 +59,9 @@ const selectPasswordHashSQL = "" +
const selectNewNumericLocalpartSQL = "" +
"SELECT nextval('numeric_username_seq')"
// TODO: Update password
type accountsStatements struct {
insertAccountStmt *sql.Stmt
updatePasswordStmt *sql.Stmt
selectAccountByLocalpartStmt *sql.Stmt
selectPasswordHashStmt *sql.Stmt
selectNewNumericLocalpartStmt *sql.Stmt
@ -74,6 +76,9 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil {
return
}
if s.updatePasswordStmt, err = db.Prepare(updatePasswordSQL); err != nil {
return
}
if s.selectAccountByLocalpartStmt, err = db.Prepare(selectAccountByLocalpartSQL); err != nil {
return
}
@ -114,6 +119,13 @@ func (s *accountsStatements) insertAccount(
}, nil
}
func (s *accountsStatements) updatePassword(
ctx context.Context, localpart, passwordHash string,
) (err error) {
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart)
return
}
func (s *accountsStatements) selectPasswordHash(
ctx context.Context, localpart string,
) (hash string, err error) {

View file

@ -112,6 +112,17 @@ func (d *Database) SetDisplayName(
return d.profiles.setDisplayName(ctx, localpart, displayName)
}
// SetPassword sets the account password to the given hash.
func (d *Database) SetPassword(
ctx context.Context, localpart, plaintextPassword string,
) error {
hash, err := hashPassword(plaintextPassword)
if err != nil {
return err
}
return d.accounts.updatePassword(ctx, localpart, hash)
}
// CreateGuestAccount makes a new guest account and creates an empty profile
// for this account.
func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, err error) {

View file

@ -45,6 +45,9 @@ CREATE TABLE IF NOT EXISTS account_accounts (
const insertAccountSQL = "" +
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)"
const updatePasswordSQL = "" +
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
const selectAccountByLocalpartSQL = "" +
"SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1"
@ -54,11 +57,10 @@ const selectPasswordHashSQL = "" +
const selectNewNumericLocalpartSQL = "" +
"SELECT COUNT(localpart) FROM account_accounts"
// TODO: Update password
type accountsStatements struct {
db *sql.DB
insertAccountStmt *sql.Stmt
updatePasswordStmt *sql.Stmt
selectAccountByLocalpartStmt *sql.Stmt
selectPasswordHashStmt *sql.Stmt
selectNewNumericLocalpartStmt *sql.Stmt
@ -75,6 +77,9 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil {
return
}
if s.updatePasswordStmt, err = db.Prepare(updatePasswordSQL); err != nil {
return
}
if s.selectAccountByLocalpartStmt, err = db.Prepare(selectAccountByLocalpartSQL); err != nil {
return
}
@ -115,6 +120,13 @@ func (s *accountsStatements) insertAccount(
}, nil
}
func (s *accountsStatements) updatePassword(
ctx context.Context, localpart, passwordHash string,
) (err error) {
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart)
return
}
func (s *accountsStatements) selectPasswordHash(
ctx context.Context, localpart string,
) (hash string, err error) {

View file

@ -126,6 +126,18 @@ func (d *Database) SetDisplayName(
})
}
// SetPassword sets the account password to the given hash.
func (d *Database) SetPassword(
ctx context.Context, localpart, plaintextPassword string,
) error {
hash, err := hashPassword(plaintextPassword)
if err != nil {
return err
}
err = d.accounts.updatePassword(ctx, localpart, hash)
return err
}
// CreateGuestAccount makes a new guest account and creates an empty profile
// for this account.
func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, err error) {

View file

@ -36,5 +36,5 @@ type Database interface {
RemoveDevice(ctx context.Context, deviceID, localpart string) error
RemoveDevices(ctx context.Context, localpart string, devices []string) error
// RemoveAllDevices deleted all devices for this user. Returns the devices deleted.
RemoveAllDevices(ctx context.Context, localpart string) (devices []api.Device, err error)
RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error)
}

View file

@ -70,7 +70,7 @@ const selectDeviceByIDSQL = "" +
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
const selectDevicesByLocalpartSQL = "" +
"SELECT device_id, display_name FROM device_devices WHERE localpart = $1"
"SELECT device_id, display_name FROM device_devices WHERE localpart = $1 AND device_id != $2"
const updateDeviceNameSQL = "" +
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
@ -79,7 +79,7 @@ const deleteDeviceSQL = "" +
"DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2"
const deleteDevicesByLocalpartSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1"
"DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2"
const deleteDevicesSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)"
@ -179,10 +179,10 @@ func (s *devicesStatements) deleteDevices(
// deleteDevicesByLocalpart removes all devices for the
// given user localpart.
func (s *devicesStatements) deleteDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart string,
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
_, err := stmt.ExecContext(ctx, localpart)
_, err := stmt.ExecContext(ctx, localpart, exceptDeviceID)
return err
}
@ -251,10 +251,10 @@ func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []s
}
func (s *devicesStatements) selectDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart string,
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
) ([]api.Device, error) {
devices := []api.Device{}
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart)
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID)
if err != nil {
return devices, err

View file

@ -68,7 +68,7 @@ func (d *Database) GetDeviceByID(
func (d *Database) GetDevicesByLocalpart(
ctx context.Context, localpart string,
) ([]api.Device, error) {
return d.devices.selectDevicesByLocalpart(ctx, nil, localpart)
return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "")
}
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
@ -175,14 +175,14 @@ func (d *Database) RemoveDevices(
// database matching the given user ID localpart.
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveAllDevices(
ctx context.Context, localpart string,
ctx context.Context, localpart, exceptDeviceID string,
) (devices []api.Device, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart)
devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
if err != nil {
return err
}
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows {
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
return err
}
return nil

View file

@ -59,7 +59,7 @@ const selectDeviceByIDSQL = "" +
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
const selectDevicesByLocalpartSQL = "" +
"SELECT device_id, display_name FROM device_devices WHERE localpart = $1"
"SELECT device_id, display_name FROM device_devices WHERE localpart = $1 AND device_id != $2"
const updateDeviceNameSQL = "" +
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
@ -68,7 +68,7 @@ const deleteDeviceSQL = "" +
"DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2"
const deleteDevicesByLocalpartSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1"
"DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2"
const deleteDevicesSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)"
@ -182,10 +182,10 @@ func (s *devicesStatements) deleteDevices(
}
func (s *devicesStatements) deleteDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart string,
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
_, err := stmt.ExecContext(ctx, localpart)
_, err := stmt.ExecContext(ctx, localpart, exceptDeviceID)
return err
}
@ -231,10 +231,10 @@ func (s *devicesStatements) selectDeviceByID(
}
func (s *devicesStatements) selectDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart string,
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
) ([]api.Device, error) {
devices := []api.Device{}
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart)
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID)
if err != nil {
return devices, err

View file

@ -72,7 +72,7 @@ func (d *Database) GetDeviceByID(
func (d *Database) GetDevicesByLocalpart(
ctx context.Context, localpart string,
) ([]api.Device, error) {
return d.devices.selectDevicesByLocalpart(ctx, nil, localpart)
return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "")
}
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
@ -179,14 +179,14 @@ func (d *Database) RemoveDevices(
// database matching the given user ID localpart.
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveAllDevices(
ctx context.Context, localpart string,
ctx context.Context, localpart, exceptDeviceID string,
) (devices []api.Device, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart)
devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
if err != nil {
return err
}
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows {
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
return err
}
return nil