mirror of
https://github.com/hoernschen/dendrite.git
synced 2024-12-29 08:18:27 +00:00
Implement filter POSTing and GETting. (#296)
* Implement filter POSTing and GETting. Signed-off-by: Jan Christian Grünhage <jan.christian@gruenhage.xyz> * Add missing '}' typo introduced during merge * Still trying to fix that merge... * Fix linting
This commit is contained in:
parent
e9314e5b30
commit
f6bda82366
5 changed files with 278 additions and 16 deletions
|
@ -0,0 +1,90 @@
|
||||||
|
// Copyright 2017 Jan Christian Grünhage
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package accounts
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"context"
|
||||||
|
)
|
||||||
|
|
||||||
|
const filterSchema = `
|
||||||
|
-- Stores data about filters
|
||||||
|
CREATE TABLE IF NOT EXISTS account_filter (
|
||||||
|
-- The filter
|
||||||
|
filter TEXT NOT NULL,
|
||||||
|
-- The ID
|
||||||
|
id SERIAL UNIQUE,
|
||||||
|
-- The localpart of the Matrix user ID associated to this filter
|
||||||
|
localpart TEXT NOT NULL,
|
||||||
|
|
||||||
|
PRIMARY KEY(id, localpart)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS account_filter_localpart ON account_filter(localpart);
|
||||||
|
`
|
||||||
|
|
||||||
|
const selectFilterSQL = "" +
|
||||||
|
"SELECT filter FROM account_filter WHERE localpart = $1 AND id = $2"
|
||||||
|
|
||||||
|
const insertFilterSQL = "" +
|
||||||
|
"INSERT INTO account_filter (filter, id, localpart) VALUES ($1, DEFAULT, $2) RETURNING id"
|
||||||
|
|
||||||
|
const findMaxIDSQL = "" +
|
||||||
|
"SELECT MAX(id) FROM account_filter WHERE localpart = $1"
|
||||||
|
|
||||||
|
type filterStatements struct {
|
||||||
|
selectFilterStmt *sql.Stmt
|
||||||
|
insertFilterStmt *sql.Stmt
|
||||||
|
findMaxIDStmt *sql.Stmt
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *filterStatements) prepare(db *sql.DB) (err error) {
|
||||||
|
_, err = db.Exec(filterSchema)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.insertFilterStmt, err = db.Prepare(insertFilterSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.findMaxIDStmt, err = db.Prepare(findMaxIDSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *filterStatements) selectFilter(
|
||||||
|
ctx context.Context, localpart string, filterID string,
|
||||||
|
) (filter string, err error) {
|
||||||
|
err = s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filter)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *filterStatements) insertFilter(
|
||||||
|
ctx context.Context, filter string, localpart string,
|
||||||
|
) (err error) {
|
||||||
|
_, err = s.insertFilterStmt.ExecContext(ctx, filter, localpart)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *filterStatements) findMaxID(
|
||||||
|
ctx context.Context, localpart string,
|
||||||
|
) (id string, err error) {
|
||||||
|
err = s.findMaxIDStmt.QueryRowContext(ctx, localpart).Scan(&id)
|
||||||
|
return
|
||||||
|
}
|
|
@ -36,6 +36,7 @@ type Database struct {
|
||||||
memberships membershipStatements
|
memberships membershipStatements
|
||||||
accountDatas accountDataStatements
|
accountDatas accountDataStatements
|
||||||
threepids threepidStatements
|
threepids threepidStatements
|
||||||
|
filter filterStatements
|
||||||
serverName gomatrixserverlib.ServerName
|
serverName gomatrixserverlib.ServerName
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -70,7 +71,11 @@ func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName)
|
||||||
if err = t.prepare(db); err != nil {
|
if err = t.prepare(db); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &Database{db, partitions, a, p, m, ac, t, serverName}, nil
|
f := filterStatements{}
|
||||||
|
if err = f.prepare(db); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &Database{db, partitions, a, p, m, ac, t, f, serverName}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountByPassword returns the account associated with the given localpart and password.
|
// GetAccountByPassword returns the account associated with the given localpart and password.
|
||||||
|
@ -315,6 +320,26 @@ func (d *Database) GetThreePIDsForLocalpart(
|
||||||
return d.threepids.selectThreePIDsForLocalpart(ctx, localpart)
|
return d.threepids.selectThreePIDsForLocalpart(ctx, localpart)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetFilter looks up the filter associated with a given local user and filter ID.
|
||||||
|
// Returns an error if no such filter exists or if there was an error taling to the database.
|
||||||
|
func (d *Database) GetFilter(
|
||||||
|
ctx context.Context, localpart string, filterID string,
|
||||||
|
) (string, error) {
|
||||||
|
return d.filter.selectFilter(ctx, localpart, filterID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PutFilter puts the passed filter into the database.
|
||||||
|
// Returns an error if something goes wrong.
|
||||||
|
func (d *Database) PutFilter(
|
||||||
|
ctx context.Context, localpart, filter string,
|
||||||
|
) (string, error) {
|
||||||
|
err := d.filter.insertFilter(ctx, filter, localpart)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return d.filter.findMaxID(ctx, localpart)
|
||||||
|
}
|
||||||
|
|
||||||
// CheckAccountAvailability checks if the username/localpart is already present in the database.
|
// CheckAccountAvailability checks if the username/localpart is already present in the database.
|
||||||
// If the DB returns sql.ErrNoRows the Localpart isn't taken.
|
// If the DB returns sql.ErrNoRows the Localpart isn't taken.
|
||||||
func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) {
|
func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) {
|
||||||
|
|
|
@ -0,0 +1,72 @@
|
||||||
|
// Copyright 2017 Jan Christian Grünhage
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package readers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts"
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/httputil"
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||||
|
"github.com/matrix-org/gomatrix"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/matrix-org/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetFilter implements GET /_matrix/client/r0/user/{userId}/filter/{filterId}
|
||||||
|
func GetFilter(
|
||||||
|
req *http.Request, device *authtypes.Device, accountDB *accounts.Database, userID string, filterID string,
|
||||||
|
) util.JSONResponse {
|
||||||
|
if req.Method != http.MethodGet {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: 405,
|
||||||
|
JSON: jsonerror.NotFound("Bad method"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if userID != device.UserID {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: 403,
|
||||||
|
JSON: jsonerror.Forbidden("Cannot get filters for other users"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
localpart, _, err := gomatrixserverlib.SplitID('@', userID)
|
||||||
|
if err != nil {
|
||||||
|
return httputil.LogThenError(req, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err := accountDB.GetFilter(req.Context(), localpart, filterID)
|
||||||
|
if err != nil {
|
||||||
|
//TODO better error handling. This error message is *probably* right,
|
||||||
|
// but if there are obscure db errors, this will also be returned,
|
||||||
|
// even though it is not correct.
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: 400,
|
||||||
|
JSON: jsonerror.NotFound("No such filter"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
filter := gomatrix.Filter{}
|
||||||
|
err = json.Unmarshal([]byte(res), &filter)
|
||||||
|
if err != nil {
|
||||||
|
httputil.LogThenError(req, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: 200,
|
||||||
|
JSON: filter,
|
||||||
|
}
|
||||||
|
}
|
|
@ -189,23 +189,18 @@ func Setup(
|
||||||
}),
|
}),
|
||||||
).Methods("GET")
|
).Methods("GET")
|
||||||
|
|
||||||
r0mux.Handle("/user/{userID}/filter",
|
|
||||||
common.MakeExternalAPI("make_filter", func(req *http.Request) util.JSONResponse {
|
r0mux.Handle("/user/{userId}/filter",
|
||||||
// TODO: Persist filter and return filter ID
|
common.MakeAuthAPI("put_filter", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
|
||||||
return util.JSONResponse{
|
vars := mux.Vars(req)
|
||||||
Code: 200,
|
return writers.PutFilter(req, device, accountDB, vars["userId"])
|
||||||
JSON: struct{}{},
|
|
||||||
}
|
|
||||||
}),
|
}),
|
||||||
).Methods("POST", "OPTIONS")
|
).Methods("POST", "OPTIONS")
|
||||||
|
|
||||||
r0mux.Handle("/user/{userID}/filter/{filterID}",
|
r0mux.Handle("/user/{userId}/filter/{filterId}",
|
||||||
common.MakeExternalAPI("filter", func(req *http.Request) util.JSONResponse {
|
common.MakeAuthAPI("get_filter", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
|
||||||
// TODO: Retrieve filter based on ID
|
vars := mux.Vars(req)
|
||||||
return util.JSONResponse{
|
return readers.GetFilter(req, device, accountDB, vars["userId"], vars["filterId"])
|
||||||
Code: 200,
|
|
||||||
JSON: struct{}{},
|
|
||||||
}
|
|
||||||
}),
|
}),
|
||||||
).Methods("GET")
|
).Methods("GET")
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,80 @@
|
||||||
|
// Copyright 2017 Jan Christian Grünhage
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package writers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts"
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
|
"github.com/matrix-org/util"
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||||
|
"github.com/matrix-org/gomatrix"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/httputil"
|
||||||
|
"encoding/json"
|
||||||
|
)
|
||||||
|
|
||||||
|
type filterResponse struct {
|
||||||
|
FilterID string `json:"filter_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
//PutFilter implements POST /_matrix/client/r0/user/{userId}/filter
|
||||||
|
func PutFilter(
|
||||||
|
req *http.Request, device *authtypes.Device, accountDB *accounts.Database, userID string,
|
||||||
|
) util.JSONResponse {
|
||||||
|
if req.Method != "POST" {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: 405,
|
||||||
|
JSON: jsonerror.NotFound("Bad method"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if userID != device.UserID {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: 403,
|
||||||
|
JSON: jsonerror.Forbidden("Cannot create filters for other users"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
localpart, _, err := gomatrixserverlib.SplitID('@', userID)
|
||||||
|
if err != nil {
|
||||||
|
return httputil.LogThenError(req, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var filter gomatrix.Filter
|
||||||
|
|
||||||
|
if reqErr := httputil.UnmarshalJSONRequest(req, &filter); reqErr != nil {
|
||||||
|
return *reqErr
|
||||||
|
}
|
||||||
|
|
||||||
|
filterArray, err := json.Marshal(filter)
|
||||||
|
if err != nil {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: 400,
|
||||||
|
JSON: jsonerror.BadJSON("Filter is malformed"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
filterID, err := accountDB.PutFilter(req.Context(), localpart, string(filterArray))
|
||||||
|
if err != nil {
|
||||||
|
return httputil.LogThenError(req, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: 200,
|
||||||
|
JSON: filterResponse{FilterID: filterID},
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in a new issue