From 1f61537d661abebb99d74f43fa66005eb5c114c4 Mon Sep 17 00:00:00 2001
From: Anant Prakash <anantprakashjsr@gmail.com>
Date: Fri, 20 Apr 2018 20:22:21 +0530
Subject: [PATCH] Refactor username parsing function of clientapi:login (#432)

* Refactor username parse function of login

Signed-off-by: Anant Prakash <anantprakashjsr@gmail.com>

* Add tests for userutil

Signed-off-by: Anant Prakash <anantprakashjsr@gmail.com>
---
 .../dendrite/clientapi/routing/login.go       | 25 ++-----
 .../dendrite/clientapi/userutil/userutil.go   | 43 +++++++++++
 .../clientapi/userutil/userutil_test.go       | 71 +++++++++++++++++++
 3 files changed, 120 insertions(+), 19 deletions(-)
 create mode 100644 src/github.com/matrix-org/dendrite/clientapi/userutil/userutil.go
 create mode 100644 src/github.com/matrix-org/dendrite/clientapi/userutil/userutil_test.go

diff --git a/src/github.com/matrix-org/dendrite/clientapi/routing/login.go b/src/github.com/matrix-org/dendrite/clientapi/routing/login.go
index e0a4e632..3804da47 100644
--- a/src/github.com/matrix-org/dendrite/clientapi/routing/login.go
+++ b/src/github.com/matrix-org/dendrite/clientapi/routing/login.go
@@ -16,13 +16,13 @@ package routing
 
 import (
 	"net/http"
-	"strings"
 
 	"github.com/matrix-org/dendrite/clientapi/auth"
 	"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts"
 	"github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
 	"github.com/matrix-org/dendrite/clientapi/httputil"
 	"github.com/matrix-org/dendrite/clientapi/jsonerror"
+	"github.com/matrix-org/dendrite/clientapi/userutil"
 	"github.com/matrix-org/dendrite/common/config"
 	"github.com/matrix-org/gomatrixserverlib"
 	"github.com/matrix-org/util"
@@ -82,24 +82,11 @@ func Login(
 
 		util.GetLogger(req.Context()).WithField("user", r.User).Info("Processing login request")
 
-		// r.User can either be a user ID or just the localpart... or other things maybe.
-		localpart := r.User
-		if strings.HasPrefix(r.User, "@") {
-			var domain gomatrixserverlib.ServerName
-			var err error
-			localpart, domain, err = gomatrixserverlib.SplitID('@', r.User)
-			if err != nil {
-				return util.JSONResponse{
-					Code: http.StatusBadRequest,
-					JSON: jsonerror.InvalidUsername("Invalid username"),
-				}
-			}
-
-			if domain != cfg.Matrix.ServerName {
-				return util.JSONResponse{
-					Code: http.StatusBadRequest,
-					JSON: jsonerror.InvalidUsername("User ID not ours"),
-				}
+		localpart, err := userutil.ParseUsernameParam(r.User, &cfg.Matrix.ServerName)
+		if err != nil {
+			return util.JSONResponse{
+				Code: http.StatusBadRequest,
+				JSON: jsonerror.InvalidUsername(err.Error()),
 			}
 		}
 
diff --git a/src/github.com/matrix-org/dendrite/clientapi/userutil/userutil.go b/src/github.com/matrix-org/dendrite/clientapi/userutil/userutil.go
new file mode 100644
index 00000000..de2d1959
--- /dev/null
+++ b/src/github.com/matrix-org/dendrite/clientapi/userutil/userutil.go
@@ -0,0 +1,43 @@
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package userutil
+
+import (
+	"errors"
+	"strings"
+
+	"github.com/matrix-org/gomatrixserverlib"
+)
+
+// ParseUsernameParam extracts localpart from usernameParam.
+// usernameParam can either be a user ID or just the localpart/username.
+// If serverName is passed, it is verified against the domain obtained from usernameParam (if present)
+// Returns error in case of invalid usernameParam.
+func ParseUsernameParam(usernameParam string, expectedServerName *gomatrixserverlib.ServerName) (string, error) {
+	localpart := usernameParam
+
+	if strings.HasPrefix(usernameParam, "@") {
+		lp, domain, err := gomatrixserverlib.SplitID('@', usernameParam)
+
+		if err != nil {
+			return "", errors.New("Invalid username")
+		}
+
+		if expectedServerName != nil && domain != *expectedServerName {
+			return "", errors.New("User ID does not belong to this server")
+		}
+
+		localpart = lp
+	}
+	return localpart, nil
+}
diff --git a/src/github.com/matrix-org/dendrite/clientapi/userutil/userutil_test.go b/src/github.com/matrix-org/dendrite/clientapi/userutil/userutil_test.go
new file mode 100644
index 00000000..2628642f
--- /dev/null
+++ b/src/github.com/matrix-org/dendrite/clientapi/userutil/userutil_test.go
@@ -0,0 +1,71 @@
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package userutil
+
+import (
+	"testing"
+
+	"github.com/matrix-org/gomatrixserverlib"
+)
+
+var (
+	localpart                                      = "somelocalpart"
+	serverName        gomatrixserverlib.ServerName = "someservername"
+	invalidServerName gomatrixserverlib.ServerName = "invalidservername"
+	goodUserID                                     = "@" + localpart + ":" + string(serverName)
+	badUserID                                      = "@bad:user:name@noservername:"
+)
+
+// TestGoodUserID checks that correct localpart is returned for a valid user ID.
+func TestGoodUserID(t *testing.T) {
+	lp, err := ParseUsernameParam(goodUserID, &serverName)
+
+	if err != nil {
+		t.Error("User ID Parsing failed for ", goodUserID, " with error: ", err.Error())
+	}
+
+	if lp != localpart {
+		t.Error("Incorrect username, returned: ", lp, " should be: ", localpart)
+	}
+}
+
+// TestWithLocalpartOnly checks that localpart is returned when usernameParam contains only localpart.
+func TestWithLocalpartOnly(t *testing.T) {
+	lp, err := ParseUsernameParam(localpart, &serverName)
+
+	if err != nil {
+		t.Error("User ID Parsing failed for ", localpart, " with error: ", err.Error())
+	}
+
+	if lp != localpart {
+		t.Error("Incorrect username, returned: ", lp, " should be: ", localpart)
+	}
+}
+
+// TestIncorrectDomain checks for error when there's server name mismatch.
+func TestIncorrectDomain(t *testing.T) {
+	_, err := ParseUsernameParam(goodUserID, &invalidServerName)
+
+	if err == nil {
+		t.Error("Invalid Domain should return an error")
+	}
+}
+
+// TestBadUserID checks that ParseUsernameParam fails for invalid user ID
+func TestBadUserID(t *testing.T) {
+	_, err := ParseUsernameParam(badUserID, &serverName)
+
+	if err == nil {
+		t.Error("Illegal User ID should return an error")
+	}
+}