AMAZING CHANGEROOS

This commit is contained in:
Andrew Morgan 2018-06-12 23:32:45 +01:00
parent cc9cae60cb
commit bb408c53ad
5 changed files with 27 additions and 18 deletions

View file

@ -65,12 +65,6 @@ type Data struct {
func VerifyUserFromRequest( func VerifyUserFromRequest(
req *http.Request, data Data, req *http.Request, data Data,
) (*authtypes.Device, *util.JSONResponse) { ) (*authtypes.Device, *util.JSONResponse) {
// Try to find local user from device database
dev, devErr := verifyAccessToken(req, data.DeviceDB)
if devErr == nil {
return dev, nil
}
// Try to find the Application Service user // Try to find the Application Service user
token, err := extractAccessToken(req) token, err := extractAccessToken(req)
if err != nil { if err != nil {
@ -89,8 +83,8 @@ func VerifyUserFromRequest(
} }
} }
if appService != nil { userID := req.URL.Query().Get("user_id")
userID := req.URL.Query().Get("user_id") if appService != nil && userID != "" {
localpart, err := userutil.ParseUsernameParam(userID, nil) localpart, err := userutil.ParseUsernameParam(userID, nil)
if err != nil { if err != nil {
return nil, &util.JSONResponse{ return nil, &util.JSONResponse{
@ -98,6 +92,7 @@ func VerifyUserFromRequest(
JSON: jsonerror.InvalidUsername(err.Error()), JSON: jsonerror.InvalidUsername(err.Error()),
} }
} }
fmt.Println("APPSERVICE MASQUERADING AS:", localpart)
// Verify that the user is registered // Verify that the user is registered
account, err := data.AccountDB.GetAccountByLocalpart(req.Context(), localpart) account, err := data.AccountDB.GetAccountByLocalpart(req.Context(), localpart)
@ -123,6 +118,13 @@ func VerifyUserFromRequest(
} }
} }
// Try to find local user from device database
dev, devErr := verifyAccessToken(req, data.DeviceDB)
if devErr == nil {
fmt.Println("Found local device:", dev)
return dev, nil
}
return nil, &util.JSONResponse{ return nil, &util.JSONResponse{
Code: http.StatusUnauthorized, Code: http.StatusUnauthorized,
JSON: jsonerror.UnknownToken("Unrecognized access token"), JSON: jsonerror.UnknownToken("Unrecognized access token"),

View file

@ -103,6 +103,7 @@ func (r joinRoomReq) joinRoomByID(roomID string) util.JSONResponse {
queryReq := api.QueryInvitesForUserRequest{ queryReq := api.QueryInvitesForUserRequest{
RoomID: roomID, TargetUserID: r.userID, RoomID: roomID, TargetUserID: r.userID,
} }
fmt.Println(queryReq)
var queryRes api.QueryInvitesForUserResponse var queryRes api.QueryInvitesForUserResponse
if err := r.queryAPI.QueryInvitesForUser(r.req.Context(), &queryReq, &queryRes); err != nil { if err := r.queryAPI.QueryInvitesForUser(r.req.Context(), &queryReq, &queryRes); err != nil {
return httputil.LogThenError(r.req, err) return httputil.LogThenError(r.req, err)

View file

@ -17,6 +17,7 @@ package routing
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"net/http" "net/http"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
@ -164,6 +165,7 @@ func loadProfile(
} }
var profile *authtypes.Profile var profile *authtypes.Profile
fmt.Println("Getting by localpart:", localpart)
if serverName == cfg.Matrix.ServerName { if serverName == cfg.Matrix.ServerName {
profile, err = accountDB.GetProfileByLocalpart(ctx, localpart) profile, err = accountDB.GetProfileByLocalpart(ctx, localpart)
} else { } else {

View file

@ -188,7 +188,7 @@ func validateUserName(username string) *util.JSONResponse {
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername("User ID can only contain characters a-z, 0-9, or '_-./'"), JSON: jsonerror.InvalidUsername("User ID can only contain characters a-z, 0-9, or '_-./'"),
} }
} else if username[0] == '_' { // Regex checks its not a zero length string } else if username[0] == '_' && false { // Regex checks its not a zero length string
return &util.JSONResponse{ return &util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.InvalidUsername("User ID can't start with a '_'"), JSON: jsonerror.InvalidUsername("User ID can't start with a '_'"),
@ -293,6 +293,7 @@ func UsernameIsWithinApplicationServiceNamespace(
// Loop through given application service's namespaces and see if any match // Loop through given application service's namespaces and see if any match
for _, namespace := range appservice.NamespaceMap["users"] { for _, namespace := range appservice.NamespaceMap["users"] {
// AS namespaces are checked for validity in config // AS namespaces are checked for validity in config
fmt.Println("Checking", username, "against", namespace.RegexpObject)
if namespace.RegexpObject.MatchString(username) { if namespace.RegexpObject.MatchString(username) {
return true return true
} }
@ -357,7 +358,8 @@ func validateApplicationService(
} }
// Ensure the desired username is within at least one of the application service's namespaces. // Ensure the desired username is within at least one of the application service's namespaces.
if !UsernameIsWithinApplicationServiceNamespace(cfg, username, matchedApplicationService) { usernameWithID := "@" + username
if !UsernameIsWithinApplicationServiceNamespace(cfg, usernameWithID, matchedApplicationService) {
// If we didn't find any matches, return M_EXCLUSIVE // If we didn't find any matches, return M_EXCLUSIVE
return "", &util.JSONResponse{ return "", &util.JSONResponse{
Code: http.StatusUnauthorized, Code: http.StatusUnauthorized,
@ -419,7 +421,7 @@ func Register(
} }
// If no auth type is specified by the client, send back the list of available flows // If no auth type is specified by the client, send back the list of available flows
if r.Auth.Type == "" { if r.Auth.Type == "" && false {
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusUnauthorized, Code: http.StatusUnauthorized,
JSON: newUserInteractiveResponse(sessionID, JSON: newUserInteractiveResponse(sessionID,
@ -481,6 +483,7 @@ func handleRegistrationFlow(
return util.MessageResponse(http.StatusForbidden, "Registration has been disabled") return util.MessageResponse(http.StatusForbidden, "Registration has been disabled")
} }
fmt.Println("Tried?")
switch r.Auth.Type { switch r.Auth.Type {
case authtypes.LoginTypeRecaptcha: case authtypes.LoginTypeRecaptcha:
// Check given captcha response // Check given captcha response
@ -505,7 +508,8 @@ func handleRegistrationFlow(
// Add SharedSecret to the list of completed registration stages // Add SharedSecret to the list of completed registration stages
sessions.AddCompletedStage(sessionID, authtypes.LoginTypeSharedSecret) sessions.AddCompletedStage(sessionID, authtypes.LoginTypeSharedSecret)
case authtypes.LoginTypeApplicationService: default:
fmt.Println("You tried!")
// Check application service register user request is valid. // Check application service register user request is valid.
// The application service's ID is returned if so. // The application service's ID is returned if so.
appserviceID, err := validateApplicationService(cfg, req, r.Username) appserviceID, err := validateApplicationService(cfg, req, r.Username)
@ -525,11 +529,11 @@ func handleRegistrationFlow(
// Add Dummy to the list of completed registration stages // Add Dummy to the list of completed registration stages
sessions.AddCompletedStage(sessionID, authtypes.LoginTypeDummy) sessions.AddCompletedStage(sessionID, authtypes.LoginTypeDummy)
default: //default:
return util.JSONResponse{ // return util.JSONResponse{
Code: http.StatusNotImplemented, // Code: http.StatusNotImplemented,
JSON: jsonerror.Unknown("unknown/unimplemented auth type"), // JSON: jsonerror.Unknown("unknown/unimplemented auth type"),
} // }
} }
// Check if the user's registration flow has been completed successfully // Check if the user's registration flow has been completed successfully

View file

@ -130,7 +130,7 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID(
func (s *eventStateKeyStatements) bulkSelectEventStateKey( func (s *eventStateKeyStatements) bulkSelectEventStateKey(
ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID, ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID,
) (map[types.EventStateKeyNID]string, error) { ) (map[types.EventStateKeyNID]string, error) {
var nIDs pq.Int64Array nIDs := make(pq.Int64Array, len(eventStateKeyNIDs))
for i := range eventStateKeyNIDs { for i := range eventStateKeyNIDs {
nIDs[i] = int64(eventStateKeyNIDs[i]) nIDs[i] = int64(eventStateKeyNIDs[i])
} }