290 lines
8.5 KiB
Go
290 lines
8.5 KiB
Go
package user
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
|
|
"git.nutfactory.org/hoernschen/Matrix/config"
|
|
"git.nutfactory.org/hoernschen/Matrix/entities/device"
|
|
"git.nutfactory.org/hoernschen/Matrix/utils"
|
|
)
|
|
|
|
func New(username string, name string, password string) (err error, newUser *User) {
|
|
err, hashedPassword := utils.Hash([]byte(password))
|
|
if err != nil {
|
|
return
|
|
}
|
|
id := generateUserId(username)
|
|
newUser = &User{
|
|
Id: id,
|
|
Name: name,
|
|
Password: hashedPassword,
|
|
Devices: make(map[string]*device.Device),
|
|
}
|
|
return
|
|
}
|
|
|
|
func CheckUsernameAvailabilityHandler(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
|
|
request := availableRequest{}
|
|
errResponse := utils.CheckRequest(r)
|
|
if errResponse != nil {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
if err := json.NewEncoder(w).Encode(errResponse); err != nil {
|
|
panic(err)
|
|
}
|
|
return
|
|
}
|
|
decoder := json.NewDecoder(r.Body)
|
|
err := decoder.Decode(&request)
|
|
if err != nil {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Could not parse JSON: %s", err)}); err != nil {
|
|
panic(err)
|
|
}
|
|
return
|
|
}
|
|
userId := generateUserId(request.Username)
|
|
foundUser, err := ReadUser(userId)
|
|
if err != nil {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Database Error: %s", err)}); err != nil {
|
|
panic(err)
|
|
}
|
|
return
|
|
}
|
|
if foundUser != nil {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorCode: "M_USER_IN_USE", ErrorMessage: "Username already in use"}); err != nil {
|
|
panic(err)
|
|
}
|
|
return
|
|
}
|
|
w.WriteHeader(http.StatusOK)
|
|
if err := json.NewEncoder(w).Encode(availableResponse{Available: true}); err != nil {
|
|
panic(err)
|
|
}
|
|
}
|
|
|
|
func RegisterHandler(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
|
|
request := RegisterRequest{}
|
|
errResponse := utils.CheckRequest(r)
|
|
if errResponse != nil {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
if err := json.NewEncoder(w).Encode(errResponse); err != nil {
|
|
panic(err)
|
|
}
|
|
return
|
|
}
|
|
decoder := json.NewDecoder(r.Body)
|
|
err := decoder.Decode(&request)
|
|
if err != nil {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Could not parse JSON: %s", err)}); err != nil {
|
|
panic(err)
|
|
}
|
|
return
|
|
}
|
|
errResponse = checkLoginType(request.Auth.LoginType)
|
|
if errResponse != nil {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
if err := json.NewEncoder(w).Encode(errResponse); err != nil {
|
|
panic(err)
|
|
}
|
|
return
|
|
}
|
|
err, newUser := New(request.Username, request.Username, request.Password)
|
|
foundUser, err := ReadUser(newUser.Id)
|
|
if foundUser != nil {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorCode: "M_USER_IN_USE", ErrorMessage: "Username already in use"}); err != nil {
|
|
panic(err)
|
|
}
|
|
return
|
|
}
|
|
err = CreateUser(newUser)
|
|
if err != nil {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Database Error: %s", err)}); err != nil {
|
|
panic(err)
|
|
}
|
|
return
|
|
}
|
|
userDevice, errResponse := createUserDevice(request.DeviceId, request.DeviceName, newUser.Id)
|
|
if errResponse != nil {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
if err := json.NewEncoder(w).Encode(errResponse); err != nil {
|
|
panic(err)
|
|
}
|
|
return
|
|
}
|
|
response := RegisterResponse{
|
|
UserId: newUser.Id,
|
|
AccessToken: userDevice.AccessToken,
|
|
DeviceId: userDevice.Id,
|
|
}
|
|
w.WriteHeader(http.StatusOK)
|
|
if err := json.NewEncoder(w).Encode(response); err != nil {
|
|
panic(err)
|
|
}
|
|
}
|
|
|
|
func LoginHandler(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
|
|
request := loginRequest{}
|
|
errResponse := utils.CheckRequest(r)
|
|
if errResponse != nil {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
if err := json.NewEncoder(w).Encode(errResponse); err != nil {
|
|
panic(err)
|
|
}
|
|
return
|
|
}
|
|
decoder := json.NewDecoder(r.Body)
|
|
err := decoder.Decode(&request)
|
|
if err != nil {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Could not parse JSON: %s", err)}); err != nil {
|
|
panic(err)
|
|
}
|
|
return
|
|
}
|
|
errResponse = checkLoginType(request.LoginType)
|
|
if errResponse != nil {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
if err := json.NewEncoder(w).Encode(errResponse); err != nil {
|
|
panic(err)
|
|
}
|
|
return
|
|
}
|
|
if request.Identifier.IdentifierType != "m.id.user" && request.Identifier.User == "" {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: "Username missing"}); err != nil {
|
|
panic(err)
|
|
}
|
|
return
|
|
}
|
|
userId := generateUserId(request.Identifier.User)
|
|
foundUser, err := ReadUser(userId)
|
|
if err != nil || foundUser == nil {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorCode: "M_FORBIDDEN"}); err != nil {
|
|
panic(err)
|
|
}
|
|
return
|
|
}
|
|
err, hashedPassword := utils.Hash([]byte(request.Password))
|
|
if err != nil || foundUser.Password != hashedPassword {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorCode: "M_FORBIDDEN"}); err != nil {
|
|
panic(err)
|
|
}
|
|
return
|
|
}
|
|
userDevice, errResponse := createUserDevice(request.DeviceId, request.DeviceName, request.Identifier.User)
|
|
if errResponse != nil {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
if err := json.NewEncoder(w).Encode(errResponse); err != nil {
|
|
panic(err)
|
|
}
|
|
return
|
|
}
|
|
response := loginResponse{
|
|
UserId: foundUser.Id,
|
|
AccessToken: userDevice.AccessToken,
|
|
DeviceId: userDevice.Id,
|
|
}
|
|
response.DiscoveryInfo.Homeserver.BaseUrl = config.Homeserver
|
|
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
|
|
w.WriteHeader(http.StatusOK)
|
|
if err := json.NewEncoder(w).Encode(response); err != nil {
|
|
panic(err)
|
|
}
|
|
}
|
|
|
|
func LogoutHandler(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
|
|
errResponse := utils.CheckRequest(r)
|
|
if errResponse != nil {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
if err := json.NewEncoder(w).Encode(errResponse); err != nil {
|
|
panic(err)
|
|
}
|
|
return
|
|
}
|
|
accessToken, errResponse := utils.GetAccessToken(r)
|
|
if errResponse != nil {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
if err := json.NewEncoder(w).Encode(errResponse); err != nil {
|
|
panic(err)
|
|
}
|
|
return
|
|
}
|
|
foundDevice, err := device.ReadDeviceFromAccessToken(accessToken)
|
|
if err != nil {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Database Error: %s", err)}); err != nil {
|
|
panic(err)
|
|
}
|
|
return
|
|
}
|
|
foundDevice.AccessToken = ""
|
|
err = device.UpdateDevice(foundDevice)
|
|
if err != nil {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Database Error: %s", err)}); err != nil {
|
|
panic(err)
|
|
}
|
|
return
|
|
}
|
|
w.WriteHeader(http.StatusOK)
|
|
}
|
|
|
|
func checkLoginType(loginType string) (errResponse *utils.ErrorResponse) {
|
|
if loginType != "m.login.password" {
|
|
errResponse = &utils.ErrorResponse{ErrorCode: "M_UNKNOWN", ErrorMessage: "Bad login type."}
|
|
return
|
|
}
|
|
return
|
|
}
|
|
|
|
func generateUserId(username string) string {
|
|
return fmt.Sprintf("@%s:%s", username, config.Homeserver)
|
|
}
|
|
|
|
func createUserDevice(id string, name string, userId string) (userDevice *device.Device, errResponse *utils.ErrorResponse) {
|
|
userDevice, err := device.ReadDevice(id)
|
|
if err != nil {
|
|
errResponse = &utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Database Error: %s", err)}
|
|
return
|
|
}
|
|
if userDevice != nil {
|
|
err = userDevice.RenewAccesToken()
|
|
if name != "" {
|
|
userDevice.Name = name
|
|
}
|
|
if err != nil {
|
|
errResponse = &utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Unable to renew AccesToken: %s", err)}
|
|
return
|
|
}
|
|
err = device.UpdateDevice(userDevice)
|
|
if err != nil {
|
|
errResponse = &utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Database Error: %s", err)}
|
|
return
|
|
}
|
|
} else {
|
|
err, userDevice = device.New(name)
|
|
if err != nil {
|
|
errResponse = &utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Unable to create device: %s", err)}
|
|
return
|
|
}
|
|
err = device.CreateDevice(userDevice, userId)
|
|
if err != nil {
|
|
errResponse = &utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Database Error: %s", err)}
|
|
return
|
|
}
|
|
}
|
|
return
|
|
}
|