diff --git a/config/config.go b/config/config.go new file mode 100644 index 0000000..6f747c1 --- /dev/null +++ b/config/config.go @@ -0,0 +1,21 @@ +package config + +var ServerName string = "Hoernschen's Matrix Server" +var Version string = "0.1" + +var Homeserver string +var Port string + +var PrivateKey []byte +var PublicKey []byte +var KeyId string +var VerifyKeys map[string]map[string][]byte + +// Parameters for Mesurements +// TODO: Implement correctly +var Packetloss float32 +var UnavailableTill int +var AuthentificationCheck bool +var Signing bool +var Encryption bool +var HttpString string diff --git a/entities/device/deviceDatabaseConnector.go b/entities/device/deviceDatabaseConnector.go index 896bddb..5ce0fb1 100644 --- a/entities/device/deviceDatabaseConnector.go +++ b/entities/device/deviceDatabaseConnector.go @@ -2,15 +2,16 @@ package device import ( "fmt" + "log" "nutfactory.org/Matrix/utils/database" ) func CreateDevice(device *Device, userId string) (err error) { sqlStmt := fmt.Sprintf(`INSERT INTO device - (id, name, userId) + (id, name, accessToken, userId) VALUES - (?, ?, ?)`) + (?, ?, ?, ?)`) tx, err := database.DB.Begin() if err != nil { @@ -23,7 +24,7 @@ func CreateDevice(device *Device, userId string) (err error) { } defer stmt.Close() - _, err = stmt.Exec(device.Id, device.Name, userId) + _, err = stmt.Exec(device.Id, device.Name, device.AccessToken, userId) if err != nil { return } @@ -32,7 +33,7 @@ func CreateDevice(device *Device, userId string) (err error) { } func ReadDevice(id string) (foundDevice *Device, err error) { - queryStmt := fmt.Sprintf(`SELECT id, name + queryStmt := fmt.Sprintf(`SELECT id, name, accessToken FROM device WHERE id = '%s'`, id) @@ -45,7 +46,31 @@ func ReadDevice(id string) (foundDevice *Device, err error) { if rows.Next() { foundDevice = &Device{} - err = rows.Scan(&foundDevice.Id, &foundDevice.Name) + err = rows.Scan(&foundDevice.Id, &foundDevice.Name, &foundDevice.AccessToken) + if err != nil { + return + } + foundDevice.Keys, err = ReadKeysForDevice(foundDevice.Id) + } + + return +} + +func ReadDeviceFromAccessToken(accessToken string) (foundDevice *Device, err error) { + queryStmt := fmt.Sprintf(`SELECT id, name, accessToken + FROM device + WHERE accessToken = '%s'`, accessToken) + log.Printf(queryStmt) + rows, err := database.DB.Query(queryStmt) + if err != nil { + return + } + + defer rows.Close() + + if rows.Next() { + foundDevice = &Device{} + err = rows.Scan(&foundDevice.Id, &foundDevice.Name, &foundDevice.AccessToken) if err != nil { return } @@ -56,7 +81,7 @@ func ReadDevice(id string) (foundDevice *Device, err error) { } func ReadDevicesForUser(userId string) (devices map[string]*Device, err error) { - queryStmt := fmt.Sprintf(`SELECT id, name + queryStmt := fmt.Sprintf(`SELECT id, name, accessToken FROM device WHERE userId = '%s'`, userId) @@ -71,7 +96,7 @@ func ReadDevicesForUser(userId string) (devices map[string]*Device, err error) { for rows.Next() { foundDevice := &Device{} - err = rows.Scan(&foundDevice.Id, &foundDevice.Name) + err = rows.Scan(&foundDevice.Id, &foundDevice.Name, &foundDevice.AccessToken) if err != nil { return } @@ -82,10 +107,10 @@ func ReadDevicesForUser(userId string) (devices map[string]*Device, err error) { return } -func UpdateDevice(device *Device, userId string) (err error) { +func UpdateDevice(device *Device) (err error) { sqlStmt := fmt.Sprintf(`UPDATE device SET name = ?, - userId = ? + accessToken = ? WHERE id = ?`) tx, err := database.DB.Begin() @@ -99,7 +124,7 @@ func UpdateDevice(device *Device, userId string) (err error) { } defer stmt.Close() - _, err = stmt.Exec(device.Name, userId, device.Id) + _, err = stmt.Exec(device.Name, device.AccessToken, device.Id) if err != nil { return } diff --git a/entities/device/key.go b/entities/device/key.go index 78153e0..b737f03 100644 --- a/entities/device/key.go +++ b/entities/device/key.go @@ -3,5 +3,18 @@ package device type Key struct { Id string `json:"id,omitempty"` Type string `json:"type,omitempty"` - Key string `json:"key,omitempty"` + Key []byte `json:"key,omitempty"` +} + +type serverKeys struct { + ServerName string `json:"server_name,omitempty"` + VerifyKeys map[string]verifyKey `json:"verify_keys,omitempty"` + OldVerifyKeys map[string]verifyKey `json:"old_verify_keys,omitempty"` + Signatures map[string]map[string]string `json:"signatures,omitempty"` + ValidUntil int64 `json:"valid_until_ts,omitempty"` +} + +type verifyKey struct { + Key string `json:"key,omitempty"` + Expired int64 `json:"expired_ts,omitempty"` } diff --git a/entities/device/keyController.go b/entities/device/keyController.go new file mode 100644 index 0000000..a51d4f5 --- /dev/null +++ b/entities/device/keyController.go @@ -0,0 +1,76 @@ +package device + +import ( + "encoding/json" + "fmt" + "net/http" + + "nutfactory.org/Matrix/config" + "nutfactory.org/Matrix/utils" +) + +func InitServerSigningKey() (err error) { + publicKey, privateKey, err := utils.GenerateKeyPair() + if err != nil { + return + } + config.PublicKey = publicKey + config.PrivateKey = privateKey + config.KeyId = "ed25519:1" + return +} + +func GetServerSigningKeyHandler(w http.ResponseWriter, r *http.Request) { + if config.PublicKey == nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: "Server Signing Key Missing"}); err != nil { + panic(err) + } + return + } + response := serverKeys{ + ServerName: config.Homeserver, + VerifyKeys: make(map[string]verifyKey), + } + response.VerifyKeys[config.KeyId] = verifyKey{Key: string(config.PublicKey)} + content, err := json.Marshal(response) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Error creating Signatures: %s", err)}); err != nil { + panic(err) + } + return + } + + response.Signatures = utils.SignContent(content) + + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(response); err != nil { + panic(err) + } +} + +func getVerifyKey(server string, id string) (key []byte, err error) { + if val, ok := config.VerifyKeys[server][id]; ok { + key = val + } else { + httpString := "https" + requestUrl := fmt.Sprintf("%s://%s/_matrix/key/v2/server", httpString, server) + var res *http.Response + res, err = http.Get(requestUrl) + if err != nil { + return + } + serverKeyRes := serverKeys{} + decoder := json.NewDecoder(res.Body) + err = decoder.Decode(&serverKeyRes) + config.VerifyKeys[server] = make(map[string][]byte) + for keyId, verifyKey := range serverKeyRes.VerifyKeys { + config.VerifyKeys[server][keyId] = []byte(verifyKey.Key) + if id == keyId { + key = []byte(verifyKey.Key) + } + } + } + return +} diff --git a/entities/event/edu.go b/entities/event/edu.go deleted file mode 100644 index 1ffed69..0000000 --- a/entities/event/edu.go +++ /dev/null @@ -1,8 +0,0 @@ -package event - -// TODO: Check if it can be deleted - -type EDU struct { - Type string `json:"type,omitempty"` - Content string `json:"content,omitempty"` -} diff --git a/entities/event/event.go b/entities/event/event.go index d6047b5..8d72227 100644 --- a/entities/event/event.go +++ b/entities/event/event.go @@ -1,10 +1,123 @@ package event type Event struct { - Id string `json:"id,omitempty"` - RoomId string `json:"roomId,omitempty"` - EventType string `json:"eventType,omitempty"` - Content string `json:"content,omitempty"` - ParentId string `json:"parent,omitempty"` - Depth int `json:"depth,omitempty"` + Id string `json:"event_id,omitempty"` + RoomId string `json:"room_id,omitempty"` + Sender string `json:"sender,omitempty"` + Origin string `json:"origin,omitempty"` + Timestamp int64 `json:"origin_server_ts,omitempty"` + EventType string `json:"type,omitempty"` + StateKey string `json:"state_key,omitempty"` + Content string `json:"content,omitempty"` + PrevEventHashes map[string]EventHash `json:"prev_events,omitempty"` + Depth int `json:"depth,omitempty"` + AuthEventHashes map[string]EventHash `json:"auth_events,omitempty"` + Unsigned UnsignedData `json:"unsigned,omitempty"` + Hashes EventHash `json:"hashes,omitempty"` + Signatures map[string]map[string]string `json:"signatures,omitempty"` +} + +type StateEvent struct { + EventType string `json:"type,omitempty"` + StateKey string `json:"state_key,omitempty"` + Content string `json:"content,omitempty"` +} + +type UnsignedData struct { + Age int `json:"age,omitempty"` + TransactionId string `json:"transaction_id,omitempty"` + ReplaceState string `json:"replaces_state,omitempty"` + PrevSender string `json:"prev_sender,omitempty"` + PrevContent string `json:"prev_content,omitempty"` + RedactedBecause string `json:"redacted_because,omitempty"` +} + +type EventHash struct { + SHA256 string `json:"sha256,omitempty"` +} + +type CreateEventContent struct { + Creator string `json:"creator,omitempty"` + Federated bool `json:"m.federate,omitempty"` + RoomVersion string `json:"room_version,omitempty"` +} + +type JoinRuleEventContent struct { + JoinRule string `json:"join_rule,omitempty"` +} + +type HistoryVisibilityEventContent struct { + HistoryVisibility string `json:"history_visibility,omitempty"` +} + +type GuestAccessEventContent struct { + GuestAccess string `json:"guest_access,omitempty"` +} + +type NameEventContent struct { + Name string `json:"name,omitempty"` +} + +type TopicEventContent struct { + Topic string `json:"topic,omitempty"` +} + +type MemberEventContent struct { + AvatarUrl string `json:"avatar_url,omitempty"` + DisplayName string `json:"displayname,omitempty"` + Membership string `json:"membership,omitempty"` + IsDirect bool `json:"is_direct,omitempty"` +} + +type PowerLevelsEventContent struct { + Ban int `json:"ban,omitempty"` + Events map[string]int `json:"events,omitempty"` + EventsDefault int `json:"events_default,omitempty"` + Invite int `json:"invite,omitempty"` + Kick int `json:"kick,omitempty"` + Redact int `json:"redact,omitempty"` + StateDefault int `json:"state_default,omitempty"` + Users map[string]int `json:"users,omitempty"` + UsersDefault int `json:"users_default,omitempty"` + Notifications Notifications `json:"notifications,omitempty"` +} + +type Notifications struct { + Room int `json:"room,omitempty"` +} + +type sendMessageRequest struct { + MessageType string `json:"msgtype,omitempty"` + Body string `json:"body,omitempty"` +} + +type createEventResponse struct { + EventId string `json:"event_id,omitempty"` +} + +type getEventRequest struct{} + +type getEventResponse struct { + Content string `json:"content,omitempty"` + EventType string `json:"type,omitempty"` +} + +type syncEventsServerRequest struct { + Origin string `json:"origin,omitempty"` + Timestamp int64 `json:"origin_server_ts,omitempty"` + PDUs []*Event `json:"pdus,omitempty"` +} + +type syncEventsServerResponse struct { + PDUs map[string]pduProcessingResult `json:"pdus,omitempty"` +} + +type backfillResponse struct { + Origin string `json:"origin,omitempty"` + Timestamp int64 `json:"origin_server_ts,omitempty"` + PDUs []*Event `json:"pdus,omitempty"` +} + +type pduProcessingResult struct { + ProcessingError string `json:"error,omitempty"` } diff --git a/entities/event/eventController.go b/entities/event/eventController.go index 9d03846..a64aeaa 100644 --- a/entities/event/eventController.go +++ b/entities/event/eventController.go @@ -1,5 +1,803 @@ package event -func New() (event *Event) { +import ( + "bytes" + "encoding/json" + "fmt" + "log" + "net/http" + "strconv" + "time" + + "github.com/cenkalti/backoff/v4" + "github.com/gorilla/mux" + "nutfactory.org/Matrix/config" + "nutfactory.org/Matrix/entities/device" + "nutfactory.org/Matrix/entities/user" + "nutfactory.org/Matrix/utils" +) + +func New( + roomId string, + sender string, + origin string, + timestamp int64, + eventType string, + stateKey string, + content string, + txnId string, +) (err error, newEvent *Event) { + err, eventId := utils.CreateUUID() + if err != nil { + return + } + id := generateEventId(eventId) + newEvent = &Event{ + Id: id, + RoomId: roomId, + Sender: sender, + Origin: origin, + Timestamp: timestamp, + EventType: eventType, + StateKey: stateKey, + Content: content, + Unsigned: UnsignedData{ + TransactionId: txnId, + }, + } + newEvent.AuthEventHashes, err = GetAuthEvents(newEvent) + if err != nil { + return + } + if eventType != "m.room.create" { + var depth int + newEvent.PrevEventHashes, depth, err = ReadEventsWithoutChild(roomId) + if err != nil { + return + } + newEvent.Depth = depth + 1 + } + + newEvent.AuthEventHashes, err = GetAuthEvents(newEvent) + if err != nil { + return + } + + newEventBytesForHash, err := json.Marshal(newEvent) + if err != nil { + return + } + err, newEvent.Hashes.SHA256 = utils.Hash(newEventBytesForHash) + if err != nil { + return + } + newEvent.Unsigned = UnsignedData{} + newEventBytesForSign, err := json.Marshal(newEvent) + if err != nil { + return + } + newEvent.Signatures = utils.SignContent(newEventBytesForSign) + newEvent.Unsigned = UnsignedData{ + TransactionId: txnId, + } + return +} + +func SendMessageHandler(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json; charset=UTF-8") + request := sendMessageRequest{} + errResponse := utils.CheckRequest(r) + if errResponse != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(errResponse); err != nil { + panic(err) + } + return + } + token, errResponse := utils.GetAccessToken(r) + if errResponse != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(errResponse); err != nil { + panic(err) + } + return + } + foundUser, err := user.ReadUserFromAccessToken(token) + if err != nil || foundUser == nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorCode: "M_UNKNOWN_TOKEN", ErrorMessage: fmt.Sprintf("%s", err)}); err != nil { + panic(err) + } + return + } + decoder := json.NewDecoder(r.Body) + err = decoder.Decode(&request) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Could not parse JSON: %s", err)}); err != nil { + panic(err) + } + return + } + vars := mux.Vars(r) + roomId := vars["roomId"] + eventType := vars["eventType"] + txnId := vars["txnId"] + if roomId == "" || eventType == "" || txnId == "" { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: "Missing Parameter"}); err != nil { + panic(err) + } + return + } + buf := new(bytes.Buffer) + buf.ReadFrom(r.Body) + content := buf.String() + err, newEvent := New( + roomId, + foundUser.Id, + config.Homeserver, + time.Now().Unix(), + eventType, + foundUser.Id, + content, + txnId, + ) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Error Creating Event: %s", err)}); err != nil { + panic(err) + } + return + } + err = CreateEvent(newEvent, txnId) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Database Error: %s", err)}); err != nil { + panic(err) + } + return + } + transaction := &Transaction{ + Id: txnId, + Origin: config.Homeserver, + Timestamp: time.Now().Unix(), + PDUS: []*Event{newEvent}, + } + servers, err := ReadServers(roomId) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Database Error: %s", err)}); err != nil { + panic(err) + } + return + } + for _, server := range servers { + operation := func() error { + return SendTransaction(transaction, server) + } + notify := func(err error, duration time.Duration) { + log.Printf("Error Sending Transaction, retrying in %ss: %s", duration/1000000000, err) + } + go backoff.RetryNotify(operation, backoff.NewExponentialBackOff(), notify) + } + response := createEventResponse{ + EventId: newEvent.Id, + } + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(response); err != nil { + panic(err) + } +} + +func CreateStateEventHandler(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json; charset=UTF-8") + errResponse := utils.CheckRequest(r) + if errResponse != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(errResponse); err != nil { + panic(err) + } + return + } + token, errResponse := utils.GetAccessToken(r) + if errResponse != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(errResponse); err != nil { + panic(err) + } + return + } + foundUser, err := user.ReadUserFromAccessToken(token) + if err != nil || foundUser == nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorCode: "M_UNKNOWN_TOKEN", ErrorMessage: fmt.Sprintf("%s", err)}); err != nil { + panic(err) + } + return + } + vars := mux.Vars(r) + roomId := vars["roomId"] + eventType := vars["eventType"] + stateKey := vars["stateKey"] + if roomId == "" || eventType == "" { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: "Missing Parameter"}); err != nil { + panic(err) + } + return + } + buf := new(bytes.Buffer) + buf.ReadFrom(r.Body) + content := buf.String() + err, newEvent := New( + roomId, + foundUser.Id, + config.Homeserver, + time.Now().Unix(), + eventType, + stateKey, + content, + "", + ) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Error Creating Event: %s", err)}); err != nil { + panic(err) + } + return + } + err = CreateEvent(newEvent, "") + if err != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Database Error: %s", err)}); err != nil { + panic(err) + } + return + } + err, txnId := utils.CreateUUID() + transaction := &Transaction{ + Id: txnId, + Origin: config.Homeserver, + Timestamp: time.Now().Unix(), + PDUS: []*Event{newEvent}, + } + servers, err := ReadServers(roomId) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Database Error: %s", err)}); err != nil { + panic(err) + } + return + } + for _, server := range servers { + operation := func() error { + return SendTransaction(transaction, server) + } + notify := func(err error, duration time.Duration) { + log.Printf("Error Sending Transaction, retrying in %ss: %s", duration/1000000000, err) + } + go backoff.RetryNotify(operation, backoff.NewExponentialBackOff(), notify) + } + + response := createEventResponse{ + EventId: newEvent.Id, + } + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(response); err != nil { + panic(err) + } +} + +func GetEventUserHandler(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json; charset=UTF-8") + request := getEventRequest{} + errResponse := utils.CheckRequest(r) + if errResponse != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(errResponse); err != nil { + panic(err) + } + return + } + token, errResponse := utils.GetAccessToken(r) + if errResponse != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(errResponse); err != nil { + panic(err) + } + return + } + foundUser, err := user.ReadUserFromAccessToken(token) + if err != nil || foundUser == nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorCode: "M_UNKNOWN_TOKEN", ErrorMessage: fmt.Sprintf("%s", err)}); err != nil { + panic(err) + } + return + } + decoder := json.NewDecoder(r.Body) + err = decoder.Decode(&request) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Could not parse JSON: %s", err)}); err != nil { + panic(err) + } + return + } + vars := mux.Vars(r) + roomId := vars["roomId"] + eventId := vars["eventId"] + if roomId == "" || eventId == "" { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: "Missing Parameter"}); err != nil { + panic(err) + } + return + } + foundEvent, err := ReadEvent(eventId) + if err != nil || foundEvent == nil { + w.WriteHeader(http.StatusNotFound) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorCode: "M_NOT_FOUND", ErrorMessage: fmt.Sprintf("Event not found. %s", err)}); err != nil { + panic(err) + } + return + } + response := getEventResponse{ + Content: foundEvent.Content, + EventType: foundEvent.EventType, + } + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(response); err != nil { + panic(err) + } +} + +func GetStateEventHandler(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json; charset=UTF-8") + errResponse := utils.CheckRequest(r) + if errResponse != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(errResponse); err != nil { + panic(err) + } + return + } + token, errResponse := utils.GetAccessToken(r) + if errResponse != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(errResponse); err != nil { + panic(err) + } + return + } + foundUser, err := user.ReadUserFromAccessToken(token) + if err != nil || foundUser == nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorCode: "M_UNKNOWN_TOKEN", ErrorMessage: fmt.Sprintf("%s", err)}); err != nil { + panic(err) + } + return + } + vars := mux.Vars(r) + roomId := vars["roomId"] + eventType := vars["eventType"] + stateKey := vars["stateKey"] + if roomId == "" || eventType == "" { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: "Missing Parameter"}); err != nil { + panic(err) + } + return + } + foundEvent, err := ReadStateEvent(roomId, eventType, stateKey) + if err != nil || foundEvent == nil { + w.WriteHeader(http.StatusNotFound) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorCode: "M_NOT_FOUND", ErrorMessage: fmt.Sprintf("Event not found. %s", err)}); err != nil { + panic(err) + } + return + } + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(foundEvent.Content); err != nil { + panic(err) + } +} + +func SyncEventsServerHandler(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json; charset=UTF-8") + request := syncEventsServerRequest{} + errResponse := utils.CheckRequest(r) + if errResponse != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(errResponse); err != nil { + panic(err) + } + return + } + errResponse = utils.CheckAuthHeader(r) + if errResponse != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(errResponse); err != nil { + panic(err) + } + return + } + decoder := json.NewDecoder(r.Body) + err := decoder.Decode(&request) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Could not parse JSON: %s", err)}); err != nil { + panic(err) + } + return + } + vars := mux.Vars(r) + txnId := vars["txnId"] + if txnId == "" { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: "Missing Parameter"}); err != nil { + panic(err) + } + return + } + response := syncEventsServerResponse{ + PDUs: make(map[string]pduProcessingResult), + } + missingEventIds := make(map[string][]string) + for _, pdu := range request.PDUs { + signatureValid := CheckSignature(*pdu) + if !signatureValid { + log.Printf("Wrong Signature for Event %s", pdu.Id) + response.PDUs[pdu.Id] = pduProcessingResult{ProcessingError: "Signature not valid"} + continue + } + authEventsValid, err := CheckAuthEvents(pdu) + if !authEventsValid || err != nil { + log.Printf("Wrong Auth Events for Event %s", pdu.Id) + response.PDUs[pdu.Id] = pduProcessingResult{ProcessingError: fmt.Sprintf("Error in Auth Check: %s", err)} + //continue + } + missingParentIds, err := CheckParents(pdu) + if len(missingParentIds) > 0 || err != nil { + response.PDUs[pdu.Id] = pduProcessingResult{ProcessingError: fmt.Sprintf("Error in Parents Check: %s", err)} + for _, parentId := range missingParentIds { + missingEventIds[pdu.RoomId] = append(missingEventIds[pdu.RoomId], parentId) + } + } + foundEvent, err := ReadEvent(pdu.Id) + if foundEvent == nil && err == nil { + err = CreateEvent(pdu, txnId) + } + if err != nil { + response.PDUs[pdu.Id] = pduProcessingResult{ProcessingError: fmt.Sprintf("Database Error: %s", err)} + continue + } + + err = HandleEvent(pdu) + if err != nil { + response.PDUs[pdu.Id] = pduProcessingResult{ProcessingError: fmt.Sprintf("Error in Event-Handling: %s", err)} + continue + } + + response.PDUs[pdu.Id] = pduProcessingResult{} + } + + if len(missingEventIds) > 0 { + for roomId, eventIds := range missingEventIds { + operation := func() error { + return Backfill(eventIds, roomId, request.Origin) + } + notify := func(err error, duration time.Duration) { + log.Printf("Error Backfill, retrying in %s: %s", duration/1000000000, err) + } + go backoff.RetryNotify(operation, backoff.NewExponentialBackOff(), notify) + } + } + + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(response); err != nil { + panic(err) + } +} + +func BackfillHandler(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json; charset=UTF-8") + + errResponse := utils.CheckAuthHeader(r) + + if errResponse != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(errResponse); err != nil { + panic(err) + } + return + } + + vars := mux.Vars(r) + roomId := vars["roomId"] + if roomId == "" { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: "Missing Parameter"}); err != nil { + panic(err) + } + return + } + + limit := 50 + eventIds, ok := r.URL.Query()["v"] + log.Printf("%s", eventIds) + if !ok || eventIds[0] == "" { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: "Missing Parameter"}); err != nil { + panic(err) + } + return + } + limitString := r.URL.Query().Get("limit") + if limitString != "" { + limit, _ = strconv.Atoi(limitString) + } + pdus := []*Event{} + for len(pdus) < limit { + newEventIds := []string{} + for _, eventId := range eventIds { + foundEvent, err := ReadEvent(eventId) + if err != nil || foundEvent == nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: "Event not found"}); err != nil { + panic(err) + } + return + } + + for newEventId, _ := range foundEvent.PrevEventHashes { + newEventIds = append(newEventIds, newEventId) + } + pdus = append(pdus, foundEvent) + } + eventIds = newEventIds + } + response := backfillResponse{ + Origin: config.Homeserver, + Timestamp: time.Now().Unix(), + PDUs: pdus, + } + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(response); err != nil { + panic(err) + } +} + +func SendTransaction(transaction *Transaction, homeserver string) (err error) { + requestUrl := fmt.Sprintf("https://%s/_matrix/federation/v1/send/%s?", homeserver, transaction.Id) + request := syncEventsServerRequest{ + Origin: transaction.Origin, + Timestamp: transaction.Timestamp, + PDUs: transaction.PDUS, + } + reqBody, err := json.Marshal(request) + if err != nil { + return + } + client := &http.Client{} + req, err := http.NewRequest(http.MethodPut, requestUrl, bytes.NewBuffer(reqBody)) + if err != nil { + return + } + _, err = client.Do(req) + return +} + +func Backfill(eventIds []string, roomId string, homeserver string) (err error) { + requestUrl := fmt.Sprintf("https://%s/_matrix/federation/v1/backfill/%s?", homeserver, roomId) + for _, eventId := range eventIds { + requestUrl = fmt.Sprintf("%sv=%s&", requestUrl, eventId) + } + r, err := http.Get(requestUrl) + if err != nil { + return + } + response := backfillResponse{} + defer r.Body.Close() + decoder := json.NewDecoder(r.Body) + err = decoder.Decode(&response) + if err != nil { + return + } + var missingEventIds []string + for _, pdu := range response.PDUs { + for i, eventId := range missingEventIds { + if pdu.Id == eventId { + missingEventIds = append(missingEventIds[:i], missingEventIds[i+1:]...) + } + } + signatureValid := CheckSignature(*pdu) + if !signatureValid { + log.Printf("Wrong Signature for Event %s", pdu.Id) + missingEventIds = append(missingEventIds, pdu.Id) + continue + } + authEventsValid, err := CheckAuthEvents(pdu) + if !authEventsValid || err != nil { + log.Printf("Wrong Auth Events for Event %s", pdu.Id) + missingEventIds = append(missingEventIds, pdu.Id) + continue + } + missingParentIds, err := CheckParents(pdu) + if len(missingParentIds) > 0 || err != nil { + for _, parentId := range missingParentIds { + missingEventIds = append(missingEventIds, parentId) + } + } + foundEvent, err := ReadEvent(pdu.Id) + if foundEvent == nil && err == nil { + err = CreateEvent(pdu, "") + } + if err != nil { + continue + } + err = HandleEvent(pdu) + if err != nil { + log.Printf("Error in Event-Handling: %s", err) + continue + } + } + + if len(missingEventIds) > 0 { + Backfill(missingEventIds, roomId, homeserver) + } + + return +} + +func generateEventId(id string) string { + return fmt.Sprintf("$%s:%s", id, config.Homeserver) +} + +func GetAuthChain(newEvent *Event) (authChain []*Event, err error) { + createEvent, err := ReadStateEvent(newEvent.RoomId, "m.room.create", "") + if err != nil { + return + } + if createEvent != nil { + authChain = append(authChain, createEvent) + } + + powerLevelEvent, err := ReadStateEvent(newEvent.RoomId, "m.room.power_levels", "") + if err != nil { + return + } + if powerLevelEvent != nil { + authChain = append(authChain, powerLevelEvent) + } + + stateKey := newEvent.Sender + if newEvent.EventType == "m.room.member" { + stateKey = newEvent.StateKey + } + + memberEvent, err := ReadStateEvent(newEvent.RoomId, "m.room.member", stateKey) + if err != nil { + return + } + if memberEvent != nil { + authChain = append(authChain, memberEvent) + } + + joinRuleEvent, err := ReadStateEvent(newEvent.RoomId, "m.room.join_rules", "") + if err != nil { + return + } + if joinRuleEvent != nil && newEvent.EventType == "m.room.member" { + authChain = append(authChain, joinRuleEvent) + } + + return +} + +func GetAuthEvents(newEvent *Event) (authEventHashes map[string]EventHash, err error) { + authEventHashes = make(map[string]EventHash) + + authChain, err := GetAuthChain(newEvent) + if err != nil { + return + } + + for _, authEvent := range authChain { + authEventHashes[authEvent.Id] = authEvent.Hashes + } + + return +} + +func CheckEventHash(id string, hash EventHash) (correct bool, eventFound bool, err error) { + foundEvent, err := ReadEvent(id) + correct = true + eventFound = true + if err != nil { + return + } + if foundEvent == nil { + eventFound = false + return + } + if hash.SHA256 != foundEvent.Hashes.SHA256 { + correct = false + return + } + return +} + +func CheckParents(eventToCheck *Event) (missingParentIds []string, err error) { + for key, hash := range eventToCheck.PrevEventHashes { + correctHash, foundEvent, err := CheckEventHash(key, hash) + if !correctHash || !foundEvent || err != nil { + missingParentIds = append(missingParentIds, key) + } + } + return +} + +func CheckAuthEvents(eventToCheck *Event) (correct bool, err error) { + correct = true + authEvents, err := GetAuthEvents(eventToCheck) + if err != nil { + return + } + for key, hash := range authEvents { + if eventToCheck.AuthEventHashes[key].SHA256 != hash.SHA256 { + correct = false + return + } + } + return +} + +func CheckSignature(eventToCheck Event) (correct bool) { + correct = false + signatures := eventToCheck.Signatures + eventToCheck.Unsigned = UnsignedData{} + eventToCheck.Signatures = nil + jsonString, err := json.Marshal(eventToCheck) + if err != nil { + return + } + for id, signature := range signatures[eventToCheck.Sender] { + key, err := device.ReadKey(id) + if err == nil { + correct = utils.VerifySignature([]byte(key.Key), jsonString, []byte(signature)) + } + } + return +} + +func HandleEvents(events []*Event) (err error) { + for _, eventToHandle := range events { + err = HandleEvent(eventToHandle) + } + return +} + +func HandleEvent(eventToHandle *Event) (err error) { + if eventToHandle.EventType == "m.room.message" { + message := sendMessageRequest{} + err = json.Unmarshal([]byte(eventToHandle.Content), &message) + if err != nil { + return + } + if message.MessageType != "" && message.MessageType == "m.text" { + log.Printf("%s: %s", eventToHandle.Sender, message.Body) + } + } else if eventToHandle.EventType == "m.room.member" { + message := MemberEventContent{} + err = json.Unmarshal([]byte(eventToHandle.Content), &message) + if err != nil { + return + } + if message.Membership == "join" { + CreateRoomMember(eventToHandle.RoomId, eventToHandle.StateKey) + } + } return } diff --git a/entities/event/eventDatabaseConnector.go b/entities/event/eventDatabaseConnector.go index ecc7cde..8b37745 100644 --- a/entities/event/eventDatabaseConnector.go +++ b/entities/event/eventDatabaseConnector.go @@ -2,15 +2,16 @@ package event import ( "fmt" + "strings" "nutfactory.org/Matrix/utils/database" ) -func CreateEvent(event *Event, txnId string) (err error) { - sqlStmt := fmt.Sprintf(`INSERT INTO event - (id, roomId, txnId, eventType, content, parentId, depth) +func CreateRoomMember(roomId string, userId string) (err error) { + sqlStmt := fmt.Sprintf(`INSERT INTO roomMember + (roomId, userId, server) VALUES - (?, ?, ?, ?, ?, ?, ?)`) + (?, ?, ?)`) tx, err := database.DB.Begin() if err != nil { @@ -23,7 +24,7 @@ func CreateEvent(event *Event, txnId string) (err error) { } defer stmt.Close() - _, err = stmt.Exec(event.Id, event.RoomId, txnId, event.EventType, event.Content, event.ParentId, event.Depth) + _, err = stmt.Exec(roomId, userId, strings.Split(userId, ":")[1]) if err != nil { return } @@ -31,11 +32,129 @@ func CreateEvent(event *Event, txnId string) (err error) { return } +func CreateParents(eventId string, parentIds map[string]EventHash) (err error) { + sqlStmt := fmt.Sprintf(`INSERT INTO parent + (eventId, parentId) + VALUES + (?, ?)`) + + tx, err := database.DB.Begin() + if err != nil { + return + } + + stmt, err := tx.Prepare(sqlStmt) + if err != nil { + return + } + defer stmt.Close() + + for parentId, _ := range parentIds { + _, err = stmt.Exec( + eventId, + parentId, + ) + if err != nil { + return + } + } + + tx.Commit() + + return +} + +func CreateAuthEvents(eventId string, authEventIds map[string]EventHash) (err error) { + sqlStmt := fmt.Sprintf(`INSERT INTO authEvent + (eventId, authEventId) + VALUES + (?, ?)`) + + tx, err := database.DB.Begin() + if err != nil { + return + } + + stmt, err := tx.Prepare(sqlStmt) + if err != nil { + return + } + defer stmt.Close() + + for authEventId, _ := range authEventIds { + _, err = stmt.Exec( + eventId, + authEventId, + ) + if err != nil { + return + } + } + + tx.Commit() + + return +} + +func CreateEvent(event *Event, txnId string) (err error) { + sqlStmt := fmt.Sprintf(`INSERT INTO event + (id, roomId, txnId, sender, origin, timestamp, eventType, stateKey, content, depth, hash, signature) + VALUES + (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`) + + tx, err := database.DB.Begin() + if err != nil { + return + } + + stmt, err := tx.Prepare(sqlStmt) + if err != nil { + return + } + defer stmt.Close() + + signatures := "" + for _, signature := range event.Signatures[event.Origin] { + signatures = signature + } + + _, err = stmt.Exec( + event.Id, + event.RoomId, + txnId, + event.Sender, + event.Origin, + event.Timestamp, + event.EventType, + event.StateKey, + event.Content, + event.Depth, + event.Hashes.SHA256, + signatures, + ) + if err != nil { + return + } + tx.Commit() + + err = CreateParents(event.Id, event.PrevEventHashes) + if err != nil { + return + } + + err = CreateAuthEvents(event.Id, event.AuthEventHashes) + if err != nil { + return + } + + return +} + func CreateEventsFromTransaction(txnId string, pdus map[string]*Event) (err error) { sqlStmt := fmt.Sprintf(`INSERT INTO event - (id, roomId, txnId, eventType, content, parentId, depth) + (id, roomId, txnId, sender, origin, timestamp, eventType, stateKey, content, depth, hash, signature) VALUES - (?, ?, ?, ?, ?, ?, ?)`) + (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`) tx, err := database.DB.Begin() if err != nil { @@ -49,7 +168,35 @@ func CreateEventsFromTransaction(txnId string, pdus map[string]*Event) (err erro defer stmt.Close() for _, pdu := range pdus { - _, err = stmt.Exec(pdu.Id, pdu.RoomId, txnId, pdu.EventType, pdu.Content, pdu.ParentId, pdu.Depth) + signatures := "" + for _, signature := range pdu.Signatures[pdu.Origin] { + signatures = signature + } + + _, err = stmt.Exec( + pdu.Id, + pdu.RoomId, + txnId, + pdu.Sender, + pdu.Origin, + pdu.Timestamp, + pdu.EventType, + pdu.StateKey, + pdu.Content, + pdu.Depth, + pdu.Hashes.SHA256, + signatures, + ) + if err != nil { + return + } + + err = CreateParents(pdu.Id, pdu.PrevEventHashes) + if err != nil { + return + } + + err = CreateAuthEvents(pdu.Id, pdu.AuthEventHashes) if err != nil { return } @@ -59,8 +206,186 @@ func CreateEventsFromTransaction(txnId string, pdus map[string]*Event) (err erro return } +func ReadEventHash(id string) (hash string, err error) { + queryStmt := fmt.Sprintf(`SELECT hash + FROM event + WHERE id = '%s'`, id) + + rows, err := database.DB.Query(queryStmt) + if err != nil { + return + } + + defer rows.Close() + + if rows.Next() { + err = rows.Scan(hash) + if err != nil { + return + } + } + + return +} + +func ReadRoomMembers(roomId string) (roomMembers []string, err error) { + queryStmt := fmt.Sprintf(`SELECT userId + FROM roomMember + WHERE roomId = '%s'`, roomId) + + rows, err := database.DB.Query(queryStmt) + if err != nil { + return + } + + defer rows.Close() + + roomMembers = []string{} + + for rows.Next() { + var foundUser string + err = rows.Scan(&foundUser) + if err != nil { + return + } + roomMembers = append(roomMembers, foundUser) + } + + return +} + +func ReadServers(roomId string) (servers []string, err error) { + queryStmt := fmt.Sprintf(`SELECT DISTINCT server + FROM roomMember + WHERE roomId = '%s'`, roomId) + + rows, err := database.DB.Query(queryStmt) + if err != nil { + return + } + + defer rows.Close() + + servers = []string{} + + for rows.Next() { + var foundUser string + err = rows.Scan(&foundUser) + if err != nil { + return + } + servers = append(servers, foundUser) + } + + return +} + +func ReadParents(eventId string) (parents map[string]EventHash, err error) { + queryStmt := fmt.Sprintf(`SELECT e.id, e.hash + FROM event as e + join parent as p on e.id = p.parentId + WHERE p.eventId = '%s'`, eventId) + + rows, err := database.DB.Query(queryStmt) + if err != nil { + return + } + + defer rows.Close() + + parents = make(map[string]EventHash) + + for rows.Next() { + var eventId string + var foundEvent EventHash + + err = rows.Scan(&eventId, + &foundEvent.SHA256, + ) + + if err != nil { + return + } + + parents[eventId] = foundEvent + } + + return +} + +func ReadEventsWithoutChild(roomId string) (events map[string]EventHash, depth int, err error) { + queryStmt := fmt.Sprintf(`SELECT e.id, e.hash, e.depth + FROM event as e + LEFT JOIN parent as p on e.id = p.parentId + WHERE p.eventId IS NULL AND e.roomId = '%s'`, roomId) + + rows, err := database.DB.Query(queryStmt) + if err != nil { + return + } + + defer rows.Close() + + events = make(map[string]EventHash) + + for rows.Next() { + var eventId string + var foundEvent EventHash + var foundDepth int + err = rows.Scan(&eventId, + &foundEvent.SHA256, + &foundDepth, + ) + + if foundDepth > depth { + depth = foundDepth + } + + if err != nil { + return + } + + events[eventId] = foundEvent + } + + return +} + +func ReadAuthEvents(eventId string) (authEvents map[string]EventHash, err error) { + queryStmt := fmt.Sprintf(`SELECT e.id, e.hash + FROM event as e + join authEvent as a on e.id = a.authEventId + WHERE a.eventId = '%s'`, eventId) + + rows, err := database.DB.Query(queryStmt) + if err != nil { + return + } + + defer rows.Close() + + authEvents = make(map[string]EventHash) + + for rows.Next() { + var eventId string + var foundEvent EventHash + + err = rows.Scan(&eventId, + &foundEvent.SHA256, + ) + + if err != nil { + return + } + + authEvents[eventId] = foundEvent + } + + return +} + func ReadEvent(id string) (foundEvent *Event, err error) { - queryStmt := fmt.Sprintf(`SELECT id, roomId, eventType, content, parentId, depth + queryStmt := fmt.Sprintf(`SELECT id, roomId, txnId, sender, origin, timestamp, eventType, content, depth, hash, signature FROM event WHERE id = '%s'`, id) @@ -73,13 +398,34 @@ func ReadEvent(id string) (foundEvent *Event, err error) { if rows.Next() { foundEvent = &Event{} + var signature string err = rows.Scan(&foundEvent.Id, &foundEvent.RoomId, + &foundEvent.Unsigned.TransactionId, + &foundEvent.Sender, + &foundEvent.Origin, + &foundEvent.Timestamp, &foundEvent.EventType, &foundEvent.Content, - &foundEvent.ParentId, &foundEvent.Depth, + &foundEvent.Hashes.SHA256, + &signature, ) + + foundEvent.Signatures = make(map[string]map[string]string) + foundEvent.Signatures[foundEvent.Origin] = make(map[string]string) + foundEvent.Signatures[foundEvent.Origin]["ed25519:1"] = signature + + if err != nil { + return + } + + foundEvent.PrevEventHashes, err = ReadParents(foundEvent.Id) + if err != nil { + return + } + + foundEvent.AuthEventHashes, err = ReadAuthEvents(foundEvent.Id) if err != nil { return } @@ -88,8 +434,120 @@ func ReadEvent(id string) (foundEvent *Event, err error) { return } +func ReadStateEvent(roomId string, eventType string, stateKey string) (foundEvent *Event, err error) { + queryStmt := fmt.Sprintf(`SELECT id, roomId, txnId, sender, origin, timestamp, eventType, content, depth, hash, signature + FROM event + WHERE roomId = '%s' + AND eventType = '%s' + AND stateKey = '%s'`, roomId, eventType, stateKey) + + if stateKey == "" { + queryStmt = fmt.Sprintf(`SELECT id, roomId, txnId, sender, origin, timestamp, eventType, content, depth, hash, signature + FROM event + WHERE roomId = '%s' + AND eventType = '%s'`, roomId, eventType) + } + + rows, err := database.DB.Query(queryStmt) + if err != nil { + return + } + + defer rows.Close() + + if rows.Next() { + foundEvent = &Event{} + var signature string + err = rows.Scan(&foundEvent.Id, + &foundEvent.RoomId, + &foundEvent.Unsigned.TransactionId, + &foundEvent.Sender, + &foundEvent.Origin, + &foundEvent.Timestamp, + &foundEvent.EventType, + &foundEvent.Content, + &foundEvent.Depth, + &foundEvent.Hashes.SHA256, + &signature, + ) + + foundEvent.Signatures = make(map[string]map[string]string) + foundEvent.Signatures[foundEvent.Origin] = make(map[string]string) + foundEvent.Signatures[foundEvent.Origin]["ed25519:1"] = signature + + if err != nil { + return + } + + foundEvent.PrevEventHashes, err = ReadParents(foundEvent.Id) + if err != nil { + return + } + + foundEvent.AuthEventHashes, err = ReadAuthEvents(foundEvent.Id) + if err != nil { + return + } + } + + return +} + +func ReadStateEvents(roomId string, eventType string) (foundEvents []*Event, err error) { + queryStmt := fmt.Sprintf(`SELECT id, roomId, txnId, sender, origin, timestamp, eventType, content, depth, hash, signature + FROM event + WHERE roomId = '%s' + AND eventType = '%s'`, roomId, eventType) + + rows, err := database.DB.Query(queryStmt) + if err != nil { + return + } + + defer rows.Close() + + for rows.Next() { + foundEvent := &Event{} + var signature string + err = rows.Scan(&foundEvent.Id, + &foundEvent.RoomId, + &foundEvent.Unsigned.TransactionId, + &foundEvent.Sender, + &foundEvent.Origin, + &foundEvent.Timestamp, + &foundEvent.EventType, + &foundEvent.Content, + &foundEvent.Depth, + &foundEvent.Hashes.SHA256, + &signature, + ) + + foundEvent.Signatures = make(map[string]map[string]string) + foundEvent.Signatures[foundEvent.Origin] = make(map[string]string) + foundEvent.Signatures[foundEvent.Origin]["ed25519:1"] = signature + + if err != nil { + return + } + + foundEvent.PrevEventHashes, err = ReadParents(foundEvent.Id) + if err != nil { + return + } + + foundEvent.AuthEventHashes, err = ReadAuthEvents(foundEvent.Id) + if err != nil { + return + } + + foundEvents = append(foundEvents, foundEvent) + } + + return +} + func ReadEventsFromRoom(roomId string) (events map[string]*Event, err error) { - queryStmt := fmt.Sprintf(`SELECT id, roomId, eventType, content, parentId, depth + queryStmt := fmt.Sprintf(`SELECT id, roomId, txnId, sender, origin, timestamp, eventType, content, depth, hash, signature FROM event WHERE roomId = '%s'`, roomId) @@ -104,24 +562,98 @@ func ReadEventsFromRoom(roomId string) (events map[string]*Event, err error) { for rows.Next() { foundEvent := &Event{} + var signature string err = rows.Scan(&foundEvent.Id, &foundEvent.RoomId, + &foundEvent.Unsigned.TransactionId, + &foundEvent.Sender, + &foundEvent.Origin, + &foundEvent.Timestamp, &foundEvent.EventType, &foundEvent.Content, - &foundEvent.ParentId, &foundEvent.Depth, + &foundEvent.Hashes.SHA256, + &signature, ) + + foundEvent.Signatures = make(map[string]map[string]string) + foundEvent.Signatures[foundEvent.Origin] = make(map[string]string) + foundEvent.Signatures[foundEvent.Origin]["ed25519:1"] = signature + if err != nil { return } + + foundEvent.PrevEventHashes, err = ReadParents(foundEvent.Id) + if err != nil { + return + } + + foundEvent.AuthEventHashes, err = ReadAuthEvents(foundEvent.Id) + if err != nil { + return + } + events[foundEvent.Id] = foundEvent } return } -func ReadEventsFromTransaction(txnId string) (events map[string]*Event, err error) { - queryStmt := fmt.Sprintf(`SELECT id, roomId, eventType, content, parentId, depth +func ReadStateEventsFromRoom(roomId string) (events []*Event, err error) { + queryStmt := fmt.Sprintf(`SELECT id, roomId, txnId, sender, origin, timestamp, eventType, content, depth, hash, signature + FROM event + WHERE eventType <> 'm.room.message' AND roomId = '%s'`, roomId) + + rows, err := database.DB.Query(queryStmt) + if err != nil { + return + } + + defer rows.Close() + + for rows.Next() { + foundEvent := &Event{} + var signature string + err = rows.Scan(&foundEvent.Id, + &foundEvent.RoomId, + &foundEvent.Unsigned.TransactionId, + &foundEvent.Sender, + &foundEvent.Origin, + &foundEvent.Timestamp, + &foundEvent.EventType, + &foundEvent.Content, + &foundEvent.Depth, + &foundEvent.Hashes.SHA256, + &signature, + ) + + foundEvent.Signatures = make(map[string]map[string]string) + foundEvent.Signatures[foundEvent.Origin] = make(map[string]string) + foundEvent.Signatures[foundEvent.Origin]["ed25519:1"] = signature + + if err != nil { + return + } + + foundEvent.PrevEventHashes, err = ReadParents(foundEvent.Id) + if err != nil { + return + } + + foundEvent.AuthEventHashes, err = ReadAuthEvents(foundEvent.Id) + if err != nil { + return + } + + events = append(events, foundEvent) + } + + return +} + +func ReadEventsFromTransaction(txnId string) (events []*Event, err error) { + queryStmt := fmt.Sprintf(`SELECT id, roomId, txnId, sender, origin, timestamp, eventType, content, depth, hash, signature FROM event WHERE txnId = '%s'`, txnId) @@ -132,22 +664,41 @@ func ReadEventsFromTransaction(txnId string) (events map[string]*Event, err erro defer rows.Close() - events = make(map[string]*Event) - for rows.Next() { foundEvent := &Event{} - err = rows.Scan( - &foundEvent.Id, + var signature string + err = rows.Scan(&foundEvent.Id, &foundEvent.RoomId, + &foundEvent.Unsigned.TransactionId, + &foundEvent.Sender, + &foundEvent.Origin, + &foundEvent.Timestamp, &foundEvent.EventType, &foundEvent.Content, - &foundEvent.ParentId, &foundEvent.Depth, + &foundEvent.Hashes.SHA256, + &signature, ) + + foundEvent.Signatures = make(map[string]map[string]string) + foundEvent.Signatures[foundEvent.Origin] = make(map[string]string) + foundEvent.Signatures[foundEvent.Origin]["ed25519:1"] = signature + if err != nil { return } - events[foundEvent.Id] = foundEvent + + foundEvent.PrevEventHashes, err = ReadParents(foundEvent.Id) + if err != nil { + return + } + + foundEvent.AuthEventHashes, err = ReadAuthEvents(foundEvent.Id) + if err != nil { + return + } + + events = append(events, foundEvent) } return @@ -155,11 +706,8 @@ func ReadEventsFromTransaction(txnId string) (events map[string]*Event, err erro func UpdateEvent(event *Event) (err error) { sqlStmt := fmt.Sprintf(`UPDATE event SET - roomId = ?, eventType = ?, content = ?, - parentId = ?, - depth = ? WHERE id = ?`) tx, err := database.DB.Begin() @@ -174,11 +722,8 @@ func UpdateEvent(event *Event) (err error) { defer stmt.Close() _, err = stmt.Exec( - event.RoomId, event.EventType, event.Content, - event.ParentId, - event.Depth, event.Id, ) if err != nil { @@ -203,6 +748,52 @@ func DeleteEvent(id string) (err error) { return } + err = DeleteParents(id) + if err != nil { + return + } + + err = DeleteAuthEvents(id) + if err != nil { + return + } + + tx.Commit() + return +} + +func DeleteParents(eventId string) (err error) { + queryStmt := fmt.Sprintf(`DELETE FROM parent + WHERE eventId = '%s'`, eventId) + + tx, err := database.DB.Begin() + if err != nil { + return + } + + _, err = database.DB.Exec(queryStmt) + if err != nil { + return + } + + tx.Commit() + return +} + +func DeleteAuthEvents(eventId string) (err error) { + queryStmt := fmt.Sprintf(`DELETE FROM authEvent + WHERE eventId = '%s'`, eventId) + + tx, err := database.DB.Begin() + if err != nil { + return + } + + _, err = database.DB.Exec(queryStmt) + if err != nil { + return + } + tx.Commit() return } diff --git a/entities/event/transaction.go b/entities/event/transaction.go new file mode 100644 index 0000000..a1e251e --- /dev/null +++ b/entities/event/transaction.go @@ -0,0 +1,8 @@ +package event + +type Transaction struct { + Id string `json:"id,omitempty"` + Origin string `json:"origin,omitempty"` + Timestamp int64 `json:"timestamp,omitempty"` + PDUS []*Event `json:"pdus,omitempty"` +} diff --git a/entities/transaction/transactionDatabaseConnector.go b/entities/event/transactionDatabaseConnector.go similarity index 91% rename from entities/transaction/transactionDatabaseConnector.go rename to entities/event/transactionDatabaseConnector.go index d5e4c90..fb63d01 100644 --- a/entities/transaction/transactionDatabaseConnector.go +++ b/entities/event/transactionDatabaseConnector.go @@ -1,9 +1,8 @@ -package transaction +package event import ( "fmt" - "nutfactory.org/Matrix/entities/event" "nutfactory.org/Matrix/utils/database" ) @@ -50,7 +49,7 @@ func ReadTransaction(id string) (foundTransaction *Transaction, err error) { if err != nil { return } - foundTransaction.PDUS, err = event.ReadEventsFromTransaction(foundTransaction.Id) + foundTransaction.PDUS, err = ReadEventsFromTransaction(foundTransaction.Id) } return diff --git a/entities/general/generalController.go b/entities/general/generalController.go new file mode 100644 index 0000000..c28a26a --- /dev/null +++ b/entities/general/generalController.go @@ -0,0 +1,94 @@ +package general + +import ( + "encoding/json" + "fmt" + "log" + "net/http" + "os" + + "nutfactory.org/Matrix/config" + "nutfactory.org/Matrix/entities/device" + "nutfactory.org/Matrix/utils" + "nutfactory.org/Matrix/utils/database" +) + +type resolveServerNameResponse struct { + Server string `json:"m.server,omitempty"` +} + +type getServerImplementationResponse struct { + Server serverImplementation `json:"server,omitempty"` +} + +type serverImplementation struct { + Name string `json:"name,omitempty"` + Version string `json:"version,omitempty"` +} + +type resetBody struct { + Packetloss float32 `json:"packetloss,omitempty"` + UnavailableTill int `json:"unavailableTill,omitempty"` + AuthentificationCheck bool `json:"authentificationCheck,omitempty"` + Signing bool `json:"signing,omitempty"` + Encryption bool `json:"encryption,omitempty"` +} + +func ResolveServerName(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json; charset=UTF-8") + response := resolveServerNameResponse{Server: fmt.Sprintf("%s:%s", config.Homeserver, config.Port)} + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(response); err != nil { + panic(err) + } +} + +func GetServerImplementation(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json; charset=UTF-8") + response := getServerImplementationResponse{Server: serverImplementation{Name: config.ServerName, Version: config.Version}} + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(response); err != nil { + panic(err) + } +} + +func Reset(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json; charset=UTF-8") + request := resetBody{} + errResponse := utils.CheckRequest(r) + if errResponse != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(errResponse); err != nil { + panic(err) + } + return + } + decoder := json.NewDecoder(r.Body) + err := decoder.Decode(&request) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Could not parse JSON: %s", err)}); err != nil { + panic(err) + } + return + } + + if err := device.InitServerSigningKey(); err != nil { + log.Fatal(err) + } + config.VerifyKeys = make(map[string]map[string][]byte) + + os.Remove("sqlite.db") + + if err := database.InitDB("sqlite.db"); err != nil { + log.Fatal(err) + } + + config.Packetloss = request.Packetloss + config.UnavailableTill = request.UnavailableTill + config.AuthentificationCheck = request.AuthentificationCheck + config.Signing = request.Signing + config.Encryption = request.Signing + + w.WriteHeader(http.StatusOK) +} diff --git a/entities/room/room.go b/entities/room/room.go index 712810c..8f5b389 100644 --- a/entities/room/room.go +++ b/entities/room/room.go @@ -1,13 +1,114 @@ package room -import ( - "nutfactory.org/Matrix/entities/event" -) +import "nutfactory.org/Matrix/entities/event" type Room struct { - Id string `json:"id,omitempty"` - Messages map[string]*event.Event `json:"messages,omitempty"` - //State map[string]event.Event `json:"state,omitempty"` - Members []string `json:"members,omitempty"` - Version string `json:"version,omitempty"` + Id string `json:"id,omitempty"` + Version string `json:"version,omitempty"` + Name string `json:"name,omitempty"` + Topic string `json:"topic,omitempty"` + Members []string `json:"members,omitempty"` + Servers []string + Events map[string]*event.Event `json:"events,omitempty"` + Visibility string `json:"visibility,omitempty"` + IsDirect bool `json:"is_direct,omitempty"` + Federated bool `json:"federated,omitempty"` } + +type createRoomRequest struct { + Visibility string `json:"visibility,omitempty"` + RoomAliasName string `json:"room_alias_name,omitempty"` + Name string `json:"name,omitempty"` + Topic string `json:"topic,omitempty"` + Invite string `json:"invite,omitempty"` + Invite3pid invite3pid `json:"invite_3pid,omitempty"` + RoomVersion string `json:"room_version,omitempty"` + CreationContent creationContent `json:"creation_content,omitempty"` + InitialState []event.StateEvent `json:"initial_state,omitempty"` + Preset string `json:"preset,omitempty"` + IsDirect bool `json:"is_direct,omitempty"` + PowerLevelContentOverride string `json:"power_level_content_override,omitempty"` +} + +type createRoomResponse struct { + RoomId string `json:"room_id,omitempty"` +} + +type getRoomMemberRequest struct{} + +type getRoomMemberResponse struct { + Chunk []*event.Event `json:"chunk,omitempty"` +} + +type joinRoomUserRequest struct { + ThirdPartySigned thirdPartySigned `json:"third_party_signed,omitempty"` +} + +type joinRoomUserResponse struct { + RoomId string `json:"room_id,omitempty"` +} + +type leaveRoomUserRequest struct{} + +type leaveRoomUserResponse struct{} + +type makeJoinRequest struct{} + +type makeJoinResponse struct { + RoomVersion string `json:"room_version,omitempty"` + Event event.Event `json:"event,omitempty"` +} + +type joinRoomServerRequest struct { + Sender string `json:"sender,omitempty"` + Origin string `json:"origin,omitempty"` + Timestamp int64 `json:"origin_server_ts,omitempty"` + EventType string `json:"type,omitempty"` + StateKey string `json:"state_key,omitempty"` + Content event.MemberEventContent `json:"content,omitempty"` +} + +type joinRoomServerResponse struct { + Origin string `json:"origin,omitempty"` + AuthChain []*event.Event `json:"auth_chain,omitempty"` + State []*event.Event `json:"state,omitempty"` +} + +type makeLeaveRequest struct{} + +type makeLeaveResponse struct{} + +type leaveRoomServerRequest struct{} + +type leaveRoomServerResponse struct{} + +type invite3pid struct { + IdServer string `json:"id_server,omitempty"` + IdAccessToken string `json:"id_access_token,omitempty"` + Medium string `json:"medium,omitempty"` + Address string `json:"address,omitempty"` +} + +type creationContent struct { + Federated bool `json:"m.federate,omitempty"` +} + +type unsignedData struct { + Age int `json:"age,omitempty"` + RedactedBecause *event.Event `json:"redacted_because,omitempty"` + TransactionId string `json:"transaction_id,omitempty"` +} + +type invite struct { + DisplayName string `json:"display_name,omitempty"` + Signed thirdPartySigned `json:"signed,omitempty"` +} + +type thirdPartySigned struct { + Sender string `json:"sender,omitempty"` + MXID string `json:"mxid,omitempty"` + Signatures signatures `json:"signatures,omitempty"` + Token string `json:"token,omitempty"` +} + +type signatures struct{} diff --git a/entities/room/roomController.go b/entities/room/roomController.go index 859931a..4804cea 100644 --- a/entities/room/roomController.go +++ b/entities/room/roomController.go @@ -1,5 +1,770 @@ package room -func New() (room *Room) { +import ( + "bytes" + "encoding/json" + "fmt" + "log" + "net/http" + "strings" + "time" + + "github.com/cenkalti/backoff/v4" + "github.com/gorilla/mux" + "nutfactory.org/Matrix/config" + "nutfactory.org/Matrix/entities/event" + "nutfactory.org/Matrix/entities/user" + "nutfactory.org/Matrix/utils" +) + +func New( + version string, + name string, + topic string, + visibility string, + isDirect bool, + federated bool, + creatorId string, +) (err error, newRoom *Room) { + err, roomId := utils.CreateUUID() + if err != nil { + return + } + id := generateRoomId(roomId) + newRoom = &Room{ + Id: id, + Version: version, + Name: name, + Topic: topic, + Members: []string{creatorId}, + Events: make(map[string]*event.Event), + Visibility: visibility, + IsDirect: isDirect, + Federated: federated, + } return } + +func CreateRoomHandler(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json; charset=UTF-8") + request := createRoomRequest{} + errResponse := utils.CheckRequest(r) + if errResponse != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(errResponse); err != nil { + panic(err) + } + return + } + token, errResponse := utils.GetAccessToken(r) + if errResponse != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(errResponse); err != nil { + panic(err) + } + return + } + foundUser, err := user.ReadUserFromAccessToken(token) + if err != nil || foundUser == nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorCode: "M_UNKNOWN_TOKEN", ErrorMessage: fmt.Sprintf("%s", err)}); err != nil { + panic(err) + } + return + } + decoder := json.NewDecoder(r.Body) + err = decoder.Decode(&request) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Could not parse JSON: %s", err)}); err != nil { + panic(err) + } + return + } + + err, newRoom := New( + request.RoomVersion, + request.Name, + request.Topic, + request.Visibility, + request.IsDirect, + request.CreationContent.Federated, + foundUser.Id, + ) + err = CreateRoom(newRoom) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Database Error: %s", err)}); err != nil { + panic(err) + } + return + } + createEventContent := event.CreateEventContent{ + Creator: foundUser.Id, + Federated: request.CreationContent.Federated, + RoomVersion: newRoom.Version, + } + createEventContentBytes, _ := json.Marshal(createEventContent) + err, createEvent := event.New( + newRoom.Id, + foundUser.Id, + config.Homeserver, + time.Now().Unix(), + "m.room.create", + "", + string(createEventContentBytes), + "", + ) + if err == nil { + err = event.CreateEvent(createEvent, "") + } + if err != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Error Event-Creation: %s", err)}); err != nil { + panic(err) + } + return + } + memberEventContent := event.MemberEventContent{ + DisplayName: foundUser.Name, + IsDirect: request.IsDirect, + Membership: "join", + } + memberEventContentBytes, _ := json.Marshal(memberEventContent) + err, memberEvent := event.New( + newRoom.Id, + foundUser.Id, + config.Homeserver, + time.Now().Unix(), + "m.room.member", + foundUser.Id, + string(memberEventContentBytes), + "", + ) + if err == nil { + err = event.CreateEvent(memberEvent, "") + } + if err != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Error Event-Creation: %s", err)}); err != nil { + panic(err) + } + return + } + userPowerLevel := make(map[string]int) + userPowerLevel[foundUser.Id] = 100 + powerLevelEventContent := event.PowerLevelsEventContent{ + Ban: 50, + EventsDefault: 0, + Invite: 50, + Kick: 50, + Redact: 50, + StateDefault: 50, + Users: userPowerLevel, + UsersDefault: 0, + Notifications: event.Notifications{ + Room: 50, + }, + } + powerLevelEventContentBytes, _ := json.Marshal(powerLevelEventContent) + err, powerLevelEvent := event.New( + newRoom.Id, + foundUser.Id, + config.Homeserver, + time.Now().Unix(), + "m.room.power_levels", + "", + string(powerLevelEventContentBytes), + "", + ) + if err == nil { + err = event.CreateEvent(powerLevelEvent, "") + } + if err != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Error Event-Creation: %s", err)}); err != nil { + panic(err) + } + return + } + joinRule := "invite" + historyVisibilty := "shared" + guestAccess := "can_join" + if request.Preset == "public_chat" { + joinRule = "public" + guestAccess = "forbidden" + } + joinRuleEventContent := event.JoinRuleEventContent{ + JoinRule: joinRule, + } + joinRuleEventContentBytes, _ := json.Marshal(joinRuleEventContent) + err, joinRulesEvent := event.New( + newRoom.Id, + foundUser.Id, + config.Homeserver, + time.Now().Unix(), + "m.room.join_rules", + "", + string(joinRuleEventContentBytes), + "", + ) + if err == nil { + err = event.CreateEvent(joinRulesEvent, "") + } + if err != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Error Event-Creation: %s", err)}); err != nil { + panic(err) + } + return + } + historyVisiblilityEventContent := event.HistoryVisibilityEventContent{ + HistoryVisibility: historyVisibilty, + } + historyVisiblilityContentBytes, _ := json.Marshal(historyVisiblilityEventContent) + err, historyVisibilityEvent := event.New( + newRoom.Id, + foundUser.Id, + config.Homeserver, + time.Now().Unix(), + "m.room.history_visibility", + "", + string(historyVisiblilityContentBytes), + "", + ) + if err == nil { + err = event.CreateEvent(historyVisibilityEvent, "") + } + if err != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Error Event-Creation: %s", err)}); err != nil { + panic(err) + } + return + } + guestAccessEventContent := event.GuestAccessEventContent{ + GuestAccess: guestAccess, + } + guestAccessContentBytes, _ := json.Marshal(guestAccessEventContent) + err, guestAccessEvent := event.New( + newRoom.Id, + foundUser.Id, + config.Homeserver, + time.Now().Unix(), + "m.room.guest_access", + "", + string(guestAccessContentBytes), + "", + ) + if err == nil { + err = event.CreateEvent(guestAccessEvent, "") + } + if err != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Error Event-Creation: %s", err)}); err != nil { + panic(err) + } + return + } + nameEventContent := event.NameEventContent{ + Name: newRoom.Name, + } + nameContentBytes, _ := json.Marshal(nameEventContent) + err, nameEvent := event.New( + newRoom.Id, + foundUser.Id, + config.Homeserver, + time.Now().Unix(), + "m.room.name", + "", + string(nameContentBytes), + "", + ) + if err == nil { + err = event.CreateEvent(nameEvent, "") + } + if err != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Error Event-Creation: %s", err)}); err != nil { + panic(err) + } + return + } + topicEventContent := event.TopicEventContent{ + Topic: newRoom.Topic, + } + topicContentBytes, _ := json.Marshal(topicEventContent) + err, topicEvent := event.New( + newRoom.Id, + foundUser.Id, + config.Homeserver, + time.Now().Unix(), + "m.room.topic", + "", + string(topicContentBytes), + "", + ) + if err == nil { + err = event.CreateEvent(topicEvent, "") + } + if err != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Error Event-Creation: %s", err)}); err != nil { + panic(err) + } + return + } + + response := createRoomResponse{RoomId: newRoom.Id} + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(response); err != nil { + panic(err) + } +} + +func GetRoomMemberHandler(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json; charset=UTF-8") + request := getRoomMemberRequest{} + errResponse := utils.CheckRequest(r) + if errResponse != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(errResponse); err != nil { + panic(err) + } + return + } + token, errResponse := utils.GetAccessToken(r) + if errResponse != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(errResponse); err != nil { + panic(err) + } + return + } + foundUser, err := user.ReadUserFromAccessToken(token) + if err != nil || foundUser == nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorCode: "M_UNKNOWN_TOKEN", ErrorMessage: fmt.Sprintf("%s", err)}); err != nil { + panic(err) + } + return + } + decoder := json.NewDecoder(r.Body) + err = decoder.Decode(&request) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Could not parse JSON: %s", err)}); err != nil { + panic(err) + } + return + } + + vars := mux.Vars(r) + roomId := vars["roomId"] + if roomId == "" { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: "Missing Params"}); err != nil { + panic(err) + } + return + } + event.ReadStateEvents(roomId, "m.room.member") + + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode("Not Implemented"); err != nil { + panic(err) + } +} + +func JoinRoomUserHandler(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json; charset=UTF-8") + request := joinRoomUserRequest{} + errResponse := utils.CheckRequest(r) + if errResponse != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(errResponse); err != nil { + panic(err) + } + return + } + token, errResponse := utils.GetAccessToken(r) + if errResponse != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(errResponse); err != nil { + panic(err) + } + return + } + foundUser, err := user.ReadUserFromAccessToken(token) + if err != nil || foundUser == nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorCode: "M_UNKNOWN_TOKEN", ErrorMessage: fmt.Sprintf("%s", err)}); err != nil { + panic(err) + } + return + } + decoder := json.NewDecoder(r.Body) + err = decoder.Decode(&request) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Could not parse JSON: %s", err)}); err != nil { + panic(err) + } + return + } + vars := mux.Vars(r) + roomId := vars["roomId"] + if roomId == "" { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: "Missing Parameter"}); err != nil { + panic(err) + } + return + } + foundRoom, err := ReadRoom(roomId) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Database Error: %s", err)}); err != nil { + panic(err) + } + return + } + var joinEvent *event.Event + if foundRoom == nil { + memberEventContent := event.MemberEventContent{ + DisplayName: foundUser.Name, + IsDirect: true, + Membership: "join", + } + memberEventContentBytes, _ := json.Marshal(memberEventContent) + err, memberEvent := event.New( + roomId, + foundUser.Id, + config.Homeserver, + time.Now().Unix(), + "m.room.member", + foundUser.Id, + string(memberEventContentBytes), + "", + ) + if err == nil { + err = event.CreateEvent(memberEvent, "") + } + if err != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Error Event-Creation: %s", err)}); err != nil { + panic(err) + } + return + } + httpString := "https" + server := strings.Split(roomId, ":")[1] + requestUrl := fmt.Sprintf("%s://%s/_matrix/federation/v1/make_join/%s/%s", httpString, server, roomId, foundUser.Id) + res, err := http.Get(requestUrl) + makeJoinRes := makeJoinResponse{} + decoder = json.NewDecoder(res.Body) + err = decoder.Decode(&makeJoinRes) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Could not parse JSON: %s", err)}); err != nil { + panic(err) + } + return + } + err = CreateRoom(&Room{Id: roomId, Version: makeJoinRes.RoomVersion}) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Database Error: %s", err)}); err != nil { + panic(err) + } + return + } + err, joinEvent = event.New( + roomId, + makeJoinRes.Event.Sender, + makeJoinRes.Event.Origin, + makeJoinRes.Event.Timestamp, + makeJoinRes.Event.EventType, + makeJoinRes.Event.StateKey, + makeJoinRes.Event.Content, + "", + ) + requestUrl = fmt.Sprintf("%s://%s/_matrix/federation/v2/send_join/%s/%s", httpString, server, roomId, joinEvent.Id) + reqBody, err := json.Marshal(joinEvent) + if err != nil { + return + } + client := &http.Client{} + req, err := http.NewRequest(http.MethodPut, requestUrl, bytes.NewBuffer(reqBody)) + if err != nil { + return + } + res, err = client.Do(req) + joinRes := joinRoomServerResponse{} + decoder = json.NewDecoder(res.Body) + err = decoder.Decode(&joinRes) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Could not parse JSON: %s", err)}); err != nil { + panic(err) + } + return + } + err = event.HandleEvents(joinRes.State) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Error Handling Events: %s", err)}); err != nil { + panic(err) + } + return + } + } else { + memberEventContent := event.MemberEventContent{ + DisplayName: foundUser.Name, + IsDirect: true, + Membership: "join", + } + memberEventContentBytes, _ := json.Marshal(memberEventContent) + err, joinEvent = event.New( + roomId, + foundUser.Id, + config.Homeserver, + time.Now().Unix(), + "m.room.member", + foundUser.Id, + string(memberEventContentBytes), + "", + ) + } + err, txnId := utils.CreateUUID() + err = event.CreateEvent(joinEvent, txnId) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Database Error: %s", err)}); err != nil { + panic(err) + } + return + } + transaction := &event.Transaction{ + Id: txnId, + Origin: config.Homeserver, + Timestamp: time.Now().Unix(), + PDUS: []*event.Event{joinEvent}, + } + servers, err := event.ReadServers(roomId) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Database Error: %s", err)}); err != nil { + panic(err) + } + return + } + for _, server := range servers { + //if server != config.Homeserver { + operation := func() error { + return event.SendTransaction(transaction, server) + } + notify := func(err error, duration time.Duration) { + log.Printf("Error Sending Transaction, retrying in %ss: %s", duration/1000000000, err) + } + go backoff.RetryNotify(operation, backoff.NewExponentialBackOff(), notify) + //} + } + err = CreateRoomMember(roomId, foundUser.Id) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Database Error: %s", err)}); err != nil { + panic(err) + } + return + } + response := joinRoomUserResponse{RoomId: roomId} + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(response); err != nil { + panic(err) + } +} + +func GetPrepInfoToJoinHandler(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json; charset=UTF-8") + request := makeJoinRequest{} + errResponse := utils.CheckRequest(r) + if errResponse != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(errResponse); err != nil { + panic(err) + } + return + } + decoder := json.NewDecoder(r.Body) + err := decoder.Decode(&request) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Could not parse JSON: %s", err)}); err != nil { + panic(err) + } + return + } + vars := mux.Vars(r) + roomId := vars["roomId"] + userId := vars["userId"] + if roomId == "" || userId == "" { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: "Missing Parameter"}); err != nil { + panic(err) + } + return + } + homeserver := strings.Split(userId, ":") + if len(homeserver) <= 1 { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: "Missing Homeserver in UserId"}); err != nil { + panic(err) + } + return + } + foundRoom, err := ReadRoom(roomId) + if err != nil || foundRoom == nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorCode: "MISSING_ROOM", ErrorMessage: fmt.Sprintf("%s", err)}); err != nil { + panic(err) + } + return + } + + memberEventContent := event.MemberEventContent{ + Membership: "join", + } + + memberEventContentBytes, err := json.Marshal(memberEventContent) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Could not parse JSON: %s", err)}); err != nil { + panic(err) + } + return + } + + response := makeJoinResponse{ + RoomVersion: foundRoom.Version, + Event: event.Event{ + Sender: userId, + Origin: homeserver[1], + Timestamp: time.Now().Unix(), + EventType: "m.room.member", + StateKey: userId, + Content: string(memberEventContentBytes), + }, + } + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(response); err != nil { + panic(err) + } +} + +// TODO: TEST +func JoinRoomServerHandler(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json; charset=UTF-8") + request := event.Event{} + errResponse := utils.CheckRequest(r) + if errResponse != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(errResponse); err != nil { + panic(err) + } + return + } + decoder := json.NewDecoder(r.Body) + err := decoder.Decode(&request) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Could not parse JSON: %s", err)}); err != nil { + panic(err) + } + return + } + vars := mux.Vars(r) + roomId := vars["roomId"] + eventId := vars["eventId"] + if roomId == "" || eventId == "" { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: "Missing Parameter"}); err != nil { + panic(err) + } + return + } + foundRoom, err := ReadRoom(roomId) + if err != nil || foundRoom == nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorCode: "MISSING_ROOM", ErrorMessage: fmt.Sprintf("%s", err)}); err != nil { + panic(err) + } + return + } + request.RoomId = roomId + request.Id = eventId + + memberEventContent := event.MemberEventContent{} + err = json.Unmarshal([]byte(request.Content), &memberEventContent) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Could not parse JSON: %s", err)}); err != nil { + panic(err) + } + return + } + if memberEventContent.Membership != "join" { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: "Wrong Membership"}); err != nil { + panic(err) + } + return + } + if err == nil { + err = event.CreateEvent(&request, "") + } + if err != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Error Event-Creation: %s", err)}); err != nil { + panic(err) + } + return + } + CreateRoomMember(roomId, request.StateKey) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Database Error: %s", err)}); err != nil { + panic(err) + } + return + } + authChain, err := event.GetAuthChain(&request) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Error Creating Auth Chain: %s", err)}); err != nil { + panic(err) + } + return + } + stateEvents, err := event.ReadStateEventsFromRoom(roomId) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Database Error: %s", err)}); err != nil { + panic(err) + } + return + } + response := joinRoomServerResponse{ + Origin: config.Homeserver, + AuthChain: authChain, + State: stateEvents, + } + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(response); err != nil { + panic(err) + } +} + +func generateRoomId(id string) string { + return fmt.Sprintf("!%s:%s", id, config.Homeserver) +} diff --git a/entities/room/roomDatabaseConnector.go b/entities/room/roomDatabaseConnector.go index a62db57..7102774 100644 --- a/entities/room/roomDatabaseConnector.go +++ b/entities/room/roomDatabaseConnector.go @@ -7,11 +7,11 @@ import ( "nutfactory.org/Matrix/utils/database" ) -func CreateRoom(room *Room, userId string) (err error) { +func CreateRoom(room *Room) (err error) { sqlStmt := fmt.Sprintf(`INSERT INTO room - (id, version) + (id, version, visibility, name, topic, isDirect, federated) VALUES - (?, ?)`) + (?, ?, ?, ?, ?, ?, ?)`) tx, err := database.DB.Begin() if err != nil { @@ -24,42 +24,31 @@ func CreateRoom(room *Room, userId string) (err error) { } defer stmt.Close() - _, err = stmt.Exec(room.Id, room.Version) + _, err = stmt.Exec( + room.Id, + room.Version, + room.Visibility, + room.Name, + room.Topic, + room.IsDirect, + room.Federated, + ) if err != nil { return } tx.Commit() - err = CreateRoomMember(room.Id, userId) + for _, userId := range room.Members { + err = CreateRoomMember(room.Id, userId) + } return } func CreateRoomMember(roomId string, userId string) (err error) { - sqlStmt := fmt.Sprintf(`INSERT INTO roomMember - (roomId, userId) - VALUES - (?, ?)`) - - tx, err := database.DB.Begin() - if err != nil { - return - } - - stmt, err := tx.Prepare(sqlStmt) - if err != nil { - return - } - defer stmt.Close() - - _, err = stmt.Exec(roomId, userId) - if err != nil { - return - } - tx.Commit() - return + return event.CreateRoomMember(roomId, userId) } func ReadRoom(id string) (foundRoom *Room, err error) { - queryStmt := fmt.Sprintf(`SELECT id, version + queryStmt := fmt.Sprintf(`SELECT id, version, visibility, name, topic, isDirect, federated FROM room WHERE id = '%s'`, id) @@ -72,11 +61,19 @@ func ReadRoom(id string) (foundRoom *Room, err error) { if rows.Next() { foundRoom = &Room{} - err = rows.Scan(&foundRoom.Id, &foundRoom.Version) + err = rows.Scan( + &foundRoom.Id, + &foundRoom.Version, + &foundRoom.Visibility, + &foundRoom.Name, + &foundRoom.Topic, + &foundRoom.IsDirect, + &foundRoom.Federated, + ) if err != nil { return } - foundRoom.Messages, err = event.ReadEventsFromRoom(foundRoom.Id) + foundRoom.Events, err = event.ReadEventsFromRoom(foundRoom.Id) if err != nil { return } @@ -90,34 +87,17 @@ func ReadRoom(id string) (foundRoom *Room, err error) { } func ReadRoomMembers(roomId string) (roomMembers []string, err error) { - queryStmt := fmt.Sprintf(`SELECT userId - FROM roomMember - WHERE roomId = '%s'`, roomId) - - rows, err := database.DB.Query(queryStmt) - if err != nil { - return - } - - defer rows.Close() - - roomMembers = []string{} - - for rows.Next() { - var foundUser string - err = rows.Scan(&foundUser) - if err != nil { - return - } - roomMembers = append(roomMembers, foundUser) - } - - return + return event.ReadRoomMembers(roomId) } func UpdateRoom(room *Room) (err error) { sqlStmt := fmt.Sprintf(`UPDATE room SET - version = ? + version = ?, + visibility = ?, + name = ?, + topic = ?, + isDirect = ?, + federated = ? WHERE id = ?`) tx, err := database.DB.Begin() @@ -131,7 +111,15 @@ func UpdateRoom(room *Room) (err error) { } defer stmt.Close() - _, err = stmt.Exec(room.Version, room.Id) + _, err = stmt.Exec( + room.Version, + room.Visibility, + room.Name, + room.Topic, + room.IsDirect, + room.Federated, + room.Id, + ) if err != nil { return } diff --git a/entities/transaction/transaction.go b/entities/transaction/transaction.go deleted file mode 100644 index 3c0cb46..0000000 --- a/entities/transaction/transaction.go +++ /dev/null @@ -1,13 +0,0 @@ -package transaction - -import ( - "nutfactory.org/Matrix/entities/event" -) - -type Transaction struct { - Id string `json:"id,omitempty"` - Origin string `json:"origin,omitempty"` - Timestamp int `json:"timestamp,omitempty"` - PDUS map[string]*event.Event `json:"pdus,omitempty"` - //EDUS []event.EDU `json:"edus,omitempty"` -} diff --git a/entities/transaction/transactionController.go b/entities/transaction/transactionController.go deleted file mode 100644 index fd8610a..0000000 --- a/entities/transaction/transactionController.go +++ /dev/null @@ -1,5 +0,0 @@ -package transaction - -func New() (transaction *Transaction) { - return -} diff --git a/entities/user/user.go b/entities/user/user.go index 1f455e8..eefa559 100644 --- a/entities/user/user.go +++ b/entities/user/user.go @@ -11,6 +11,14 @@ type User struct { Devices map[string]*device.Device `json:"devices,omitempty"` } +type availableRequest struct { + Username string `json:"username,omitempty"` +} + +type availableResponse struct { + Available bool `json:"available,omitempty"` +} + type registerRequest struct { Auth authentificationData `json:"auth,omitempty"` Username string `json:"username,omitempty"` @@ -58,12 +66,6 @@ type changePasswordRequest struct { Auth authentificationData } -type errorResponse struct { - ErrorCode string `json:"errcode,omitempty"` - ErrorMessage string `json:"error,omitempty"` - RetryTime int `json:"retry_after_ms,omitempty"` -} - type identifier struct { IdentifierType string `json:"type,omitempty"` User string `json:"user,omitempty"` diff --git a/entities/user/userController.go b/entities/user/userController.go index 1199026..ffa9e86 100644 --- a/entities/user/userController.go +++ b/entities/user/userController.go @@ -2,36 +2,72 @@ package user import ( "encoding/json" - "log" + "fmt" "net/http" + "nutfactory.org/Matrix/config" "nutfactory.org/Matrix/entities/device" "nutfactory.org/Matrix/utils" ) -func New(id string, name, string, password string, devices map[string]*device.Device) (err error, newUser *User) { +func New(username string, name string, password string) (err error, newUser *User) { err, hashedPassword := utils.Hash([]byte(password)) if err != nil { return } + id := generateUserId(username) newUser = &User{ Id: id, Name: name, - Password: password, - Devices: devices, + Password: hashedPassword, + Devices: make(map[string]*device.Device), } return } -func CheckUsernameAvailability(w http.ResponseWriter, r *http.Request) { +func CheckUsernameAvailabilityHandler(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json; charset=UTF-8") + request := availableRequest{} + errResponse := utils.CheckRequest(r) + if errResponse != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(errResponse); err != nil { + panic(err) + } + return + } + decoder := json.NewDecoder(r.Body) + err := decoder.Decode(&request) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Could not parse JSON: %s", err)}); err != nil { + panic(err) + } + return + } + userId := generateUserId(request.Username) + foundUser, err := ReadUser(userId) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Database Error: %s", err)}); err != nil { + panic(err) + } + return + } + if foundUser != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorCode: "M_USER_IN_USE", ErrorMessage: "Username already in use"}); err != nil { + panic(err) + } + return + } w.WriteHeader(http.StatusOK) - if err := json.NewEncoder(w).Encode("Test"); err != nil { + if err := json.NewEncoder(w).Encode(availableResponse{Available: true}); err != nil { panic(err) } } -func Register(w http.ResponseWriter, r *http.Request) { +func RegisterHandler(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json; charset=UTF-8") request := registerRequest{} errResponse := utils.CheckRequest(r) @@ -46,7 +82,7 @@ func Register(w http.ResponseWriter, r *http.Request) { err := decoder.Decode(&request) if err != nil { w.WriteHeader(http.StatusBadRequest) - if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: "Could not parse JSON"}); err != nil { + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Could not parse JSON: %s", err)}); err != nil { panic(err) } return @@ -59,12 +95,7 @@ func Register(w http.ResponseWriter, r *http.Request) { } return } - // TODO: Use New Function - newUser := &User{ - Id: request.Username, - Name: request.Username, - Password: request.Password, - } + err, newUser := New(request.Username, request.Username, request.Password) foundUser, err := ReadUser(newUser.Id) if foundUser != nil { w.WriteHeader(http.StatusBadRequest) @@ -76,40 +107,18 @@ func Register(w http.ResponseWriter, r *http.Request) { err = CreateUser(newUser) if err != nil { w.WriteHeader(http.StatusBadRequest) - if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: "Database Error"}); err != nil { + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Database Error: %s", err)}); err != nil { panic(err) } return } - userDevice, err := device.ReadDevice(request.DeviceId) - if userDevice != nil { - err = userDevice.RenewAccesToken() - if err != nil { - log.Fatalf("Unable to renew AccesToken: %s", err) - return - } - err = device.UpdateDevice(userDevice, newUser.Id) - if err != nil { - w.WriteHeader(http.StatusBadRequest) - if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: "Database Error"}); err != nil { - panic(err) - } - return - } - } else { - err, userDevice = device.New(request.DeviceName) - if err != nil { - log.Fatalf("Unable to create device: %s", err) - return - } - err = device.CreateDevice(userDevice, newUser.Id) - if err != nil { - w.WriteHeader(http.StatusBadRequest) - if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: "Database Error"}); err != nil { - panic(err) - } - return + userDevice, errResponse := createUserDevice(request.DeviceId, request.DeviceName, newUser.Id) + if errResponse != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(errResponse); err != nil { + panic(err) } + return } response := registerResponse{ UserId: newUser.Id, @@ -122,7 +131,7 @@ func Register(w http.ResponseWriter, r *http.Request) { } } -func Login(w http.ResponseWriter, r *http.Request) { +func LoginHandler(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json; charset=UTF-8") request := loginRequest{} errResponse := utils.CheckRequest(r) @@ -137,7 +146,7 @@ func Login(w http.ResponseWriter, r *http.Request) { err := decoder.Decode(&request) if err != nil { w.WriteHeader(http.StatusBadRequest) - if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: "Could not parse JSON"}); err != nil { + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Could not parse JSON: %s", err)}); err != nil { panic(err) } return @@ -150,40 +159,91 @@ func Login(w http.ResponseWriter, r *http.Request) { } return } - + if request.Identifier.IdentifierType != "m.id.user" && request.Identifier.User == "" { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: "Username missing"}); err != nil { + panic(err) + } + return + } + userId := generateUserId(request.Identifier.User) + foundUser, err := ReadUser(userId) + if err != nil || foundUser == nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorCode: "M_FORBIDDEN"}); err != nil { + panic(err) + } + return + } + err, hashedPassword := utils.Hash([]byte(request.Password)) + if err != nil || foundUser.Password != hashedPassword { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorCode: "M_FORBIDDEN"}); err != nil { + panic(err) + } + return + } + userDevice, errResponse := createUserDevice(request.DeviceId, request.DeviceName, request.Identifier.User) + if errResponse != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(errResponse); err != nil { + panic(err) + } + return + } + response := loginResponse{ + UserId: foundUser.Id, + AccessToken: userDevice.AccessToken, + DeviceId: userDevice.Id, + } + response.DiscoveryInfo.Homeserver.BaseUrl = config.Homeserver w.Header().Set("Content-Type", "application/json; charset=UTF-8") w.WriteHeader(http.StatusOK) - if err := json.NewEncoder(w).Encode("Test"); err != nil { + if err := json.NewEncoder(w).Encode(response); err != nil { panic(err) } } -func Logout(w http.ResponseWriter, r *http.Request) { +func LogoutHandler(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json; charset=UTF-8") + errResponse := utils.CheckRequest(r) + if errResponse != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(errResponse); err != nil { + panic(err) + } + return + } + accessToken, errResponse := utils.GetAccessToken(r) + if errResponse != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(errResponse); err != nil { + panic(err) + } + return + } + foundDevice, err := device.ReadDeviceFromAccessToken(accessToken) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Database Error: %s", err)}); err != nil { + panic(err) + } + return + } + foundDevice.AccessToken = "" + err = device.UpdateDevice(foundDevice) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Database Error: %s", err)}); err != nil { + panic(err) + } + return + } w.WriteHeader(http.StatusOK) - if err := json.NewEncoder(w).Encode("Test"); err != nil { - panic(err) - } -} - -func Deactivate(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json; charset=UTF-8") - w.WriteHeader(http.StatusBadRequest) - if err := json.NewEncoder(w).Encode("Not Implemented"); err != nil { - panic(err) - } -} - -func ChangePassword(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json; charset=UTF-8") - w.WriteHeader(http.StatusBadRequest) - if err := json.NewEncoder(w).Encode("Not Implemented"); err != nil { - panic(err) - } } //TODO: Check if necessary -func Sync(w http.ResponseWriter, r *http.Request) { +func SyncHandler(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json; charset=UTF-8") w.WriteHeader(http.StatusBadRequest) if err := json.NewEncoder(w).Encode("Not Implemented"); err != nil { @@ -193,8 +253,47 @@ func Sync(w http.ResponseWriter, r *http.Request) { func checkLoginType(loginType string) (errResponse *utils.ErrorResponse) { if loginType != "m.login.password" { - errResponse = &utils.ErrorResponse{ErrorCode: "M_FORBIDDEN", ErrorMessage: "Unsupported Auth Type"} + errResponse = &utils.ErrorResponse{ErrorCode: "M_UNKNOWN", ErrorMessage: "Bad login type."} return } return } + +func generateUserId(username string) string { + return fmt.Sprintf("@%s:%s", username, config.Homeserver) +} + +func createUserDevice(id string, name string, userId string) (userDevice *device.Device, errResponse *utils.ErrorResponse) { + userDevice, err := device.ReadDevice(id) + if err != nil { + errResponse = &utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Database Error: %s", err)} + return + } + if userDevice != nil { + err = userDevice.RenewAccesToken() + if name != "" { + userDevice.Name = name + } + if err != nil { + errResponse = &utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Unable to renew AccesToken: %s", err)} + return + } + err = device.UpdateDevice(userDevice) + if err != nil { + errResponse = &utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Database Error: %s", err)} + return + } + } else { + err, userDevice = device.New(name) + if err != nil { + errResponse = &utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Unable to create device: %s", err)} + return + } + err = device.CreateDevice(userDevice, userId) + if err != nil { + errResponse = &utils.ErrorResponse{ErrorMessage: fmt.Sprintf("Database Error: %s", err)} + return + } + } + return +} diff --git a/entities/user/userDatabaseConnector.go b/entities/user/userDatabaseConnector.go index 5436f98..2d65035 100644 --- a/entities/user/userDatabaseConnector.go +++ b/entities/user/userDatabaseConnector.go @@ -56,6 +56,31 @@ func ReadUser(id string) (foundUser *User, err error) { return } +func ReadUserFromAccessToken(accessToken string) (foundUser *User, err error) { + queryStmt := fmt.Sprintf(`SELECT u.id, u.name, u.password + FROM user as u + join device as d on u.id = d.userId + WHERE d.accessToken = '%s'`, accessToken) + + rows, err := database.DB.Query(queryStmt) + if err != nil { + return + } + + defer rows.Close() + + if rows.Next() { + foundUser = &User{} + err = rows.Scan(&foundUser.Id, &foundUser.Name, &foundUser.Password) + if err != nil { + return + } + foundUser.Devices, err = device.ReadDevicesForUser(foundUser.Id) + } + + return +} + func UpdateUser(user *User) (err error) { sqlStmt := fmt.Sprintf(`UPDATE user SET name = ?, diff --git a/go.mod b/go.mod index 1e0233e..f8fbd6f 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module nutfactory.org/Matrix go 1.14 require ( + github.com/cenkalti/backoff/v4 v4.1.0 github.com/gorilla/mux v1.8.0 github.com/mattn/go-sqlite3 v1.14.3 ) diff --git a/go.sum b/go.sum index 1f2f40a..7653a8d 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,33 @@ +github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/cenkalti/backoff v1.1.0 h1:QnvVp8ikKCDWOsFheytRCoYWYPO/ObCTBGxT19Hc+yE= +github.com/cenkalti/backoff v2.2.1+incompatible h1:tNowT99t7UNflLxfYYSlKYsBpXdEet03Pg2g16Swow4= +github.com/cenkalti/backoff v2.2.1+incompatible/go.mod h1:90ReRw6GdpyfrHakVjL/QHaoyV4aDUVVkXQJJJ3NXXM= +github.com/cenkalti/backoff/v4 v4.1.0 h1:c8LkOFQTzuO0WBM/ae5HdGQuZPfPxp7lqBRwQRm4fSc= +github.com/cenkalti/backoff/v4 v4.1.0/go.mod h1:scbssz8iZGpm3xbr14ovlUdkxfGXNInqkPWOWmG2CLw= +github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/mattn/go-sqlite3 v1.14.3 h1:j7a/xn1U6TKA/PHHxqZuzh64CdtRc7rU9M+AvkOl5bA= github.com/mattn/go-sqlite3 v1.14.3/go.mod h1:WVKg1VTActs4Qso6iwGbiFih2UIHo0ENGwNd0Lj+XmI= +github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/tools v0.0.0-20191130070609-6e064ea0cf2d h1:/iIZNFGxc/a7C3yWjGcnboV+Tkc7mxr+p6fDztwoxuM= +golang.org/x/tools v0.0.0-20191130070609-6e064ea0cf2d/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +honnef.co/go/tools v0.0.1-2020.1.5 h1:nI5egYTGJakVyOryqLs1cQO5dO0ksin5XXs2pspk75k= +honnef.co/go/tools v0.0.1-2020.1.5/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= diff --git a/main.go b/main.go index 02e4cc3..ec41e4f 100644 --- a/main.go +++ b/main.go @@ -1,11 +1,16 @@ package main import ( - "encoding/json" + "crypto/tls" "log" "net/http" "os" + "nutfactory.org/Matrix/config" + "nutfactory.org/Matrix/entities/device" + "nutfactory.org/Matrix/entities/event" + "nutfactory.org/Matrix/entities/general" + "nutfactory.org/Matrix/entities/room" "nutfactory.org/Matrix/entities/user" "nutfactory.org/Matrix/utils/database" "nutfactory.org/Matrix/utils/router" @@ -14,73 +19,65 @@ import ( var keyPath = "./ssl.key" var certPath = "./ssl.crt" -//var htmlPath = "./html/" - var routes = router.Routes{ + // General - router.Route{"ResolveServerName", "GET", "/.well-known/matrix/server", Test}, - router.Route{"GetServerImplementation", "GET", "/_matrix/federation/v1/version", Test}, + router.Route{"ResolveServerName", "GET", "/.well-known/matrix/server", general.ResolveServerName}, + router.Route{"GetServerImplementation", "GET", "/_matrix/federation/v1/version", general.GetServerImplementation}, + router.Route{"Reset", "GET", "/reset", general.Reset}, // Keys - router.Route{"GetSigningKey", "GET", "/_matrix/key/v2/server/{keyId}", Test}, - router.Route{"GetSigningKeyFromServer", "GET", "/_matrix/key/v2/query/{serverName}/{keyId}", Test}, - router.Route{"GetSigningKeyFromMultipleServer", "GET", "/_matrix/key/v2/query", Test}, + router.Route{"GetSigningKey", "GET", "/_matrix/key/v2/server/{keyId}", device.GetServerSigningKeyHandler}, + router.Route{"GetSigningKey", "GET", "/_matrix/key/v2/server", device.GetServerSigningKeyHandler}, // Users - router.Route{"CheckUsernameAvailability", "GET", "/_matrix/client/r0/register/available", user.CheckUsernameAvailability}, - router.Route{"Register", "POST", "/_matrix/client/r0/register", user.Register}, - router.Route{"Login", "POST", "/_matrix/client/r0/login", user.Login}, - router.Route{"Logout", "POST", "/_matrix/client/r0/logout", user.Logout}, - router.Route{"Deactivate", "POST", "/_matrix/client/r0/account/deactivate", user.Deactivate}, - router.Route{"ChangePassword", "POST", "/_matrix/client/r0/account/password", user.ChangePassword}, - router.Route{"Sync", "GET", "/_matrix/client/r0/sync", user.Sync}, + router.Route{"CheckUsernameAvailability", "GET", "/_matrix/client/r0/register/available", user.CheckUsernameAvailabilityHandler}, + router.Route{"Register", "POST", "/_matrix/client/r0/register", user.RegisterHandler}, + router.Route{"Login", "POST", "/_matrix/client/r0/login", user.LoginHandler}, + router.Route{"Logout", "POST", "/_matrix/client/r0/logout", user.LogoutHandler}, + router.Route{"Sync", "GET", "/_matrix/client/r0/sync", user.SyncHandler}, // Rooms - router.Route{"CreateRoom", "POST", "/_matrix/client/r0/createRoom", Test}, - router.Route{"GetRoomMembers", "GET", "/_matrix/client/r0/rooms/{roomId}/members", Test}, - router.Route{"JoinRoomUser", "POST", "/_matrix/client/r0/rooms/{roomId}/join", Test}, - router.Route{"LeaveRoomUser", "POST", "/_matrix/client/r0/rooms/{roomId}/leave", Test}, + router.Route{"CreateRoom", "POST", "/_matrix/client/r0/createRoom", room.CreateRoomHandler}, + router.Route{"GetRoomMembers", "GET", "/_matrix/client/r0/rooms/{roomId}/members", room.GetRoomMemberHandler}, + router.Route{"JoinRoomUser", "POST", "/_matrix/client/r0/rooms/{roomId}/join", room.JoinRoomUserHandler}, - router.Route{"GetPrepInfoToJoin", "GET", "/_matrix/federation/v1/make_join/{roomId}/{userId}", Test}, - router.Route{"JoinRoomServer", "PUT", "/_matrix/federation/v2/send_join/{roomId}/{eventId}", Test}, - router.Route{"GetPrepInfoToLeave", "GET", "/_matrix/federation/v1/make_leave/{roomId}/{userId}", Test}, - router.Route{"LeaveRoomServer", "PUT", "/_matrix/federation/v2/send_leave/{roomId}/{eventId}", Test}, + router.Route{"GetPrepInfoToJoin", "GET", "/_matrix/federation/v1/make_join/{roomId}/{userId}", room.GetPrepInfoToJoinHandler}, + router.Route{"JoinRoomServer", "PUT", "/_matrix/federation/v2/send_join/{roomId}/{eventId}", room.JoinRoomServerHandler}, // Events - router.Route{"CreateEvent", "PUT", "/_matrix/client/r0/rooms/{roomId}/send/{eventType}/{txnId}", Test}, - router.Route{"CreateStateEvent", "PUT", "/_matrix/client/r0/rooms/{roomId}/state/{eventType}/{stateKey}", Test}, - router.Route{"ChangeEvent", "PUT", "/_matrix/client/r0/rooms/{roomId}/redact/{eventId}/{txnId}", Test}, - router.Route{"GetEvents", "GET", "/_matrix/client/r0/rooms/{roomId}/messages", Test}, - router.Route{"GetStateEventsUser", "GET", "/_matrix/client/r0/rooms/{roomId}/state", Test}, - router.Route{"GetEventUser", "GET", "/_matrix/client/r0/rooms/{roomId}/event/{eventId}", Test}, - router.Route{"GetStateEvent", "GET", "/_matrix/client/r0/rooms/{roomId}/state/{eventType}/{stateKey}", Test}, + router.Route{"CreateEvent", "PUT", "/_matrix/client/r0/rooms/{roomId}/send/{eventType}/{txnId}", event.SendMessageHandler}, + router.Route{"CreateStateEvent", "PUT", "/_matrix/client/r0/rooms/{roomId}/state/{eventType}/{stateKey}", event.CreateStateEventHandler}, + router.Route{"GetEventUser", "GET", "/_matrix/client/r0/rooms/{roomId}/event/{eventId}", event.GetEventUserHandler}, + router.Route{"GetStateEvent", "GET", "/_matrix/client/r0/rooms/{roomId}/state/{eventType}/{stateKey}", event.GetStateEventHandler}, - router.Route{"GetStateEventsServer", "GET", "/_matrix/federation/v1/state/{roomId}", Test}, - router.Route{"GetEventServer", "GET", "/_matrix/federation/v1/event/{eventId}", Test}, - router.Route{"SyncEventsServer", "PUT", "/_matrix/federation/v1/send/{txnId}", Test}, - router.Route{"Backfill", "GET", "/_matrix/federation/v1/backfill/{roomId}", Test}, - router.Route{"GetMissingEvents", "POST", "/_matrix/federation/v1/get_missing_events/{roomId}", Test}, -} - -func Test(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json; charset=UTF-8") - w.WriteHeader(http.StatusOK) - if err := json.NewEncoder(w).Encode("Test"); err != nil { - panic(err) - } + router.Route{"SyncEventsServer", "PUT", "/_matrix/federation/v1/send/{txnId}", event.SyncEventsServerHandler}, + router.Route{"Backfill", "GET", "/_matrix/federation/v1/backfill/{roomId}", event.BackfillHandler}, } func main() { - // TODO: Remove later + // TODO: Change to something variable --> cli + // TODO: Implement Message Counter + http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true} + config.Homeserver = "localhost" + config.Port = "80" + if err := device.InitServerSigningKey(); err != nil { + log.Fatal(err) + } + config.VerifyKeys = make(map[string]map[string][]byte) os.Remove("sqlite.db") - _ = database.InitDB("sqlite.db") + if err := database.InitDB("sqlite.db"); err != nil { + log.Fatal(err) + } defer database.DB.Close() + // TODO: Set Default Config Params here + config.HttpString = "https" + router := router.NewRouter(routes) - //router.PathPrefix("/").Handler(http.FileServer(http.Dir(htmlPath))) - + // TODO: Serve on Port 443 and 80 without Redirect httpErr := http.ListenAndServeTLS(":443", certPath, keyPath, router) if httpErr != nil { log.Fatal(httpErr) diff --git a/sqlite.db b/sqlite.db index c4f2a52..3f148af 100644 Binary files a/sqlite.db and b/sqlite.db differ diff --git a/utils/database/databaseConnector.go b/utils/database/databaseConnector.go index d777b90..58dea63 100644 --- a/utils/database/databaseConnector.go +++ b/utils/database/databaseConnector.go @@ -34,7 +34,8 @@ func initDeviceTable() (err error) { log.Printf("Init Device Table") statement, err := DB.Prepare(`CREATE TABLE IF NOT EXISTS device ( id TEXT PRIMARY KEY, - name TEXT, + name TEXT, + accessToken TEXT, userId TEXT )`) if err != nil { @@ -137,10 +138,35 @@ func initEventTable() (err error) { id TEXT PRIMARY KEY, roomId TEXT, txnId TEXT, + sender TEXT, + origin TEXT, + timestamp INTEGER, eventType TEXT, - content TEXT, - parentId TEXT, - depth INTEGER + stateKey TEXT, + content TEXT, + depth INTEGER, + hash TEXT, + signature TEXT + )`) + if err != nil { + return + } + statement.Exec() + + statement, err = DB.Prepare(`CREATE TABLE IF NOT EXISTS parent ( + eventId TEXT, + parentId TEXT, + PRIMARY KEY (eventId, parentId) + )`) + if err != nil { + return + } + statement.Exec() + + statement, err = DB.Prepare(`CREATE TABLE IF NOT EXISTS authEvent ( + eventId TEXT, + authEventId TEXT, + PRIMARY KEY (eventId, authEventId) )`) if err != nil { return @@ -202,7 +228,12 @@ func initRoomTable() (err error) { log.Printf("Init Room Table") statement, err := DB.Prepare(`CREATE TABLE IF NOT EXISTS room ( id TEXT PRIMARY KEY, - version TEXT + version TEXT, + visibility TEXT, + name TEXT, + topic TEXT, + isDirect INT, + federated INT )`) if err != nil { return @@ -211,6 +242,7 @@ func initRoomTable() (err error) { statement, err = DB.Prepare(`CREATE TABLE IF NOT EXISTS roomMember ( userId TEXT, roomId TEXT, + server TEXT, PRIMARY KEY (userId, roomId) )`) if err != nil { diff --git a/utils/encryptionService.go b/utils/encryptionService.go index 8b89ddf..4997028 100644 --- a/utils/encryptionService.go +++ b/utils/encryptionService.go @@ -1,11 +1,14 @@ package utils import ( + "crypto/ed25519" "crypto/rand" "crypto/sha256" "encoding/base64" "fmt" "log" + + "nutfactory.org/Matrix/config" ) func CreateToken() (err error, token string) { @@ -15,7 +18,7 @@ func CreateToken() (err error, token string) { log.Fatal(err) return } - token = string(b) + token = fmt.Sprintf("%x", b) return } @@ -26,7 +29,7 @@ func CreateUUID() (err error, uuid string) { log.Fatal(err) return } - uuid = fmt.Sprintf("%x-%x-%x-%x-%x", b[0:4], b[4:6], b[6:8], b[8:10], b[10:]) + uuid = fmt.Sprintf("%x_%x_%x_%x_%x", b[0:4], b[4:6], b[6:8], b[8:10], b[10:]) return } @@ -40,3 +43,23 @@ func Hash(s []byte) (err error, hashString string) { hashString = base64.StdEncoding.EncodeToString(hash) return } + +func GenerateKeyPair() (publicKey ed25519.PublicKey, privateKey ed25519.PrivateKey, err error) { + publicKey, privateKey, err = ed25519.GenerateKey(nil) + return +} + +func Sign(message []byte) []byte { + return ed25519.Sign(config.PrivateKey, message) +} + +func SignContent(content []byte) (signatures map[string]map[string]string) { + signatures = make(map[string]map[string]string) + signatures[config.Homeserver] = make(map[string]string) + signatures[config.Homeserver][config.KeyId] = string(Sign(content)) + return +} + +func VerifySignature(publicKey []byte, message []byte, signature []byte) bool { + return ed25519.Verify(publicKey, message, signature) +} diff --git a/utils/requestChecker.go b/utils/requestChecker.go index 7c3c530..32366ea 100644 --- a/utils/requestChecker.go +++ b/utils/requestChecker.go @@ -1,11 +1,24 @@ package utils import ( + "bytes" "encoding/json" + "fmt" "net/http" "strings" + + "nutfactory.org/Matrix/config" ) +type RequestSummary struct { + Method string `json:"method,omitempty"` + Uri string `json:"uri,omitempty"` + Origin string `json:"origin,omitempty"` + Destination string `json:"destination,omitempty"` + Content string `json:"content,omitempty"` + Signatures map[string]map[string]string `json:"signatures,omitempty"` +} + type ErrorResponse struct { ErrorCode string `json:"errcode,omitempty"` ErrorMessage string `json:"error,omitempty"` @@ -19,6 +32,72 @@ func CheckRequest(r *http.Request) (response *ErrorResponse) { return } +func CheckAuthHeader(r *http.Request) (response *ErrorResponse) { + authHeader := r.Header.Get("Authorization") + if authHeader == "" || !strings.Contains(authHeader, "X-Matrix") { + response = &ErrorResponse{ErrorMessage: "Missing Authorization Header"} + return + } + keys := strings.Split(authHeader, ",") + origin := strings.Split(keys[0], "=")[1] + if !strings.Contains(keys[2], "ed25519") { + response = &ErrorResponse{ErrorMessage: "Missing ed25519 Signature Key"} + return + } + key := strings.Split(strings.Replace(strings.Split(keys[2], "=")[1], "\"", "", 2), ":")[1] + signature := strings.Replace(strings.Split(keys[2], "=")[1], "\"", "", 2) + buf := new(bytes.Buffer) + buf.ReadFrom(r.Body) + content := buf.String() + requestSummary := RequestSummary{ + Method: r.Method, + Uri: r.RequestURI, + Origin: origin, + Destination: config.Homeserver, + Content: content, + } + requestSummaryString, err := json.Marshal(requestSummary) + if err != nil { + response = &ErrorResponse{ErrorMessage: "Error Creating Auth JSON String"} + return + } + correct := VerifySignature([]byte(key), requestSummaryString, []byte(signature)) + if !correct { + response = &ErrorResponse{ErrorMessage: "Signature in Auth Header is incorrect"} + return + } + return +} + +func CreateAuthHeader(method string, uri string, destination string, content string) (authHeader string, err error) { + requestSummary := RequestSummary{ + Method: method, + Uri: uri, + Origin: config.Homeserver, + Destination: destination, + Content: content, + } + SigningContent, err := json.Marshal(requestSummary) + if err != nil { + return + } + authHeader = fmt.Sprintf("X-Matrix origin=%s,key=\"%s\",sig=\"%s\"", config.Homeserver, config.KeyId, Sign(SigningContent)) + return +} + +func GetAccessToken(r *http.Request) (token string, response *ErrorResponse) { + token = r.URL.Query().Get("access_token") + if token == "" { + token = r.Header.Get("Authorization") + if token == "" || !strings.Contains(token, "Bearer") { + response = &ErrorResponse{ErrorCode: "M_MISSING_TOKEN"} + } else { + token = strings.Split(token, " ")[1] + } + } + return +} + func IsJSONString(s string) bool { var js string return json.Unmarshal([]byte(s), &js) == nil