Merge branch 'master' of github.com:matrix-org/dendrite into erikj/pagination

This commit is contained in:
Erik Johnston 2017-12-15 15:57:33 +00:00
commit 718a765ba8
89 changed files with 3161 additions and 963 deletions

View file

@ -1,7 +1,7 @@
language: go language: go
go: go:
- 1.8 - 1.8.x
- 1.9 - 1.9.x
env: env:
- TEST_SUITE="lint" - TEST_SUITE="lint"
@ -25,10 +25,3 @@ install:
script: script:
- ./scripts/travis-test.sh - ./scripts/travis-test.sh
notifications:
webhooks:
urls:
- "https://scalar.vector.im/api/neb/services/hooks/dHJhdmlzLWNpLyU0MGtlZ2FuJTNBbWF0cml4Lm9yZy8lMjFhWmthbkFuV0VkeGNSSVFrV24lM0FtYXRyaXgub3Jn"
on_success: change # always|never|change
on_failure: always
on_start: never

View file

@ -7,7 +7,6 @@
"deadcode", "deadcode",
"gocyclo", "gocyclo",
"ineffassign", "ineffassign",
"gas",
"misspell", "misspell",
"errcheck", "errcheck",
"vet", "vet",

View file

@ -10,7 +10,6 @@
"varcheck", "varcheck",
"structcheck", "structcheck",
"ineffassign", "ineffassign",
"gas",
"misspell", "misspell",
"unparam", "unparam",
"errcheck", "errcheck",

View file

@ -11,10 +11,15 @@
# when running the linters, speeding them up but using much more memory. # when running the linters, speeding them up but using much more memory.
set -eu set -eux
cd `dirname $0`/..
export GOPATH="$(pwd):$(pwd)/vendor" export GOPATH="$(pwd):$(pwd)/vendor"
export PATH="$PATH:$(pwd)/bin"
# prefer the versions of gometalinter and the linters that we install
# to anythign that ends up on the PATH.
export PATH="$(pwd)/bin:$PATH"
args="" args=""
if [ ${1:-""} = "fast" ] if [ ${1:-""} = "fast" ]

View file

@ -81,7 +81,7 @@ type txnReq struct {
func (t *txnReq) processTransaction() (*gomatrixserverlib.RespSend, error) { func (t *txnReq) processTransaction() (*gomatrixserverlib.RespSend, error) {
// Check the event signatures // Check the event signatures
if err := gomatrixserverlib.VerifyEventSignatures(t.context, t.PDUs, t.keys); err != nil { if err := gomatrixserverlib.VerifyAllEventSignatures(t.context, t.PDUs, t.keys); err != nil {
return nil, err return nil, err
} }

View file

@ -69,6 +69,11 @@ type OutputRoomEventWriter interface {
WriteOutputEvents(ctx context.Context, roomID string, updates []api.OutputEvent) error WriteOutputEvents(ctx context.Context, roomID string, updates []api.OutputEvent) error
} }
// processRoomEvent can only be called once at a time
//
// TODO(#375): This should be rewritten to allow concurrent calls. The
// difficulty is in ensuring that we correctly annotate events with the correct
// state deltas when sending to kafka streams
func processRoomEvent( func processRoomEvent(
ctx context.Context, ctx context.Context,
db RoomEventDatabase, db RoomEventDatabase,

View file

@ -19,6 +19,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"net/http" "net/http"
"sync"
"github.com/opentracing/opentracing-go" "github.com/opentracing/opentracing-go"
@ -35,6 +36,8 @@ type RoomserverInputAPI struct {
// The kafkaesque topic to output new room events to. // The kafkaesque topic to output new room events to.
// This is the name used in kafka to identify the stream to write events to. // This is the name used in kafka to identify the stream to write events to.
OutputRoomEventTopic string OutputRoomEventTopic string
// Protects calls to processRoomEvent
mutex sync.Mutex
} }
// WriteOutputEvents implements OutputRoomEventWriter // WriteOutputEvents implements OutputRoomEventWriter
@ -63,6 +66,10 @@ func (r *RoomserverInputAPI) InputRoomEvents(
response *api.InputRoomEventsResponse, response *api.InputRoomEventsResponse,
) error { ) error {
for i := range request.InputRoomEvents { for i := range request.InputRoomEvents {
// We lock as processRoomEvent can ony be called once at a time
r.mutex.Lock()
defer r.mutex.Unlock()
if err := processRoomEvent(ctx, r.DB, r, request.InputRoomEvents[i]); err != nil { if err := processRoomEvent(ctx, r.DB, r, request.InputRoomEvents[i]); err != nil {
return err return err
} }

View file

@ -42,6 +42,7 @@ import (
// | // |
// 7 <----- latest // 7 <----- latest
// //
// Can only be called once at a time
func updateLatestEvents( func updateLatestEvents(
ctx context.Context, ctx context.Context,
db RoomEventDatabase, db RoomEventDatabase,

View file

@ -19,6 +19,9 @@ import (
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
// Import the postgres database driver. // Import the postgres database driver.
_ "github.com/lib/pq" _ "github.com/lib/pq"
@ -86,13 +89,17 @@ func (d *SyncServerDatabase) AllJoinedUsersInRooms(ctx context.Context) (map[str
// Events lookups a list of event by their event ID. // Events lookups a list of event by their event ID.
// Returns a list of events matching the requested IDs found in the database. // Returns a list of events matching the requested IDs found in the database.
// If an event is not found in the database then it will be omitted from the list. // If an event is not found in the database then it will be omitted from the list.
// Returns an error if there was a problem talking with the database // Returns an error if there was a problem talking with the database.
// Does not include any transaction IDs in the returned events.
func (d *SyncServerDatabase) Events(ctx context.Context, eventIDs []string) ([]gomatrixserverlib.Event, error) { func (d *SyncServerDatabase) Events(ctx context.Context, eventIDs []string) ([]gomatrixserverlib.Event, error) {
streamEvents, err := d.events.selectEvents(ctx, nil, eventIDs) streamEvents, err := d.events.selectEvents(ctx, nil, eventIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return streamEventsToEvents(streamEvents), nil
// We don't include a device here as we only include transaction IDs in
// incremental syncs.
return streamEventsToEvents(nil, streamEvents), nil
} }
// WriteEvent into the database. It is not safe to call this function from multiple goroutines, as it would create races // WriteEvent into the database. It is not safe to call this function from multiple goroutines, as it would create races
@ -208,10 +215,14 @@ func (d *SyncServerDatabase) syncStreamPositionTx(
return types.StreamPosition(maxID), nil return types.StreamPosition(maxID), nil
} }
// IncrementalSync returns all the data needed in order to create an incremental sync response. // IncrementalSync returns all the data needed in order to create an incremental
// sync response for the given user. Events returned will include any client
// transaction IDs associated with the given device. These transaction IDs come
// from when the device sent the event via an API that included a transaction
// ID.
func (d *SyncServerDatabase) IncrementalSync( func (d *SyncServerDatabase) IncrementalSync(
ctx context.Context, ctx context.Context,
userID string, device authtypes.Device,
fromPos, toPos types.StreamPosition, fromPos, toPos types.StreamPosition,
numRecentEventsPerRoom int, numRecentEventsPerRoom int,
) (*types.Response, error) { ) (*types.Response, error) {
@ -226,21 +237,21 @@ func (d *SyncServerDatabase) IncrementalSync(
// joined rooms, but also which rooms have membership transitions for this user between the 2 stream positions. // joined rooms, but also which rooms have membership transitions for this user between the 2 stream positions.
// This works out what the 'state' key should be for each room as well as which membership block // This works out what the 'state' key should be for each room as well as which membership block
// to put the room into. // to put the room into.
deltas, err := d.getStateDeltas(ctx, txn, fromPos, toPos, userID) deltas, err := d.getStateDeltas(ctx, &device, txn, fromPos, toPos, device.UserID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
res := types.NewResponse(toPos) res := types.NewResponse(toPos)
for _, delta := range deltas { for _, delta := range deltas {
err = d.addRoomDeltaToResponse(ctx, txn, fromPos, toPos, delta, numRecentEventsPerRoom, res) err = d.addRoomDeltaToResponse(ctx, &device, txn, fromPos, toPos, delta, numRecentEventsPerRoom, res)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
// TODO: This should be done in getStateDeltas // TODO: This should be done in getStateDeltas
if err = d.addInvitesToResponse(ctx, txn, userID, fromPos, toPos, res); err != nil { if err = d.addInvitesToResponse(ctx, txn, device.UserID, fromPos, toPos, res); err != nil {
return nil, err return nil, err
} }
@ -292,7 +303,10 @@ func (d *SyncServerDatabase) CompleteSync(
if err != nil { if err != nil {
return nil, err return nil, err
} }
recentEvents := streamEventsToEvents(recentStreamEvents)
// We don't include a device here as we don't need to send down
// transaction IDs for complete syncs
recentEvents := streamEventsToEvents(nil, recentStreamEvents)
stateEvents = removeDuplicates(stateEvents, recentEvents) stateEvents = removeDuplicates(stateEvents, recentEvents)
jr := types.NewJoinResponse() jr := types.NewJoinResponse()
@ -391,7 +405,9 @@ func (d *SyncServerDatabase) addInvitesToResponse(
// addRoomDeltaToResponse adds a room state delta to a sync response // addRoomDeltaToResponse adds a room state delta to a sync response
func (d *SyncServerDatabase) addRoomDeltaToResponse( func (d *SyncServerDatabase) addRoomDeltaToResponse(
ctx context.Context, txn *sql.Tx, ctx context.Context,
device *authtypes.Device,
txn *sql.Tx,
fromPos, toPos types.StreamPosition, fromPos, toPos types.StreamPosition,
delta stateDelta, delta stateDelta,
numRecentEventsPerRoom int, numRecentEventsPerRoom int,
@ -413,7 +429,7 @@ func (d *SyncServerDatabase) addRoomDeltaToResponse(
if err != nil { if err != nil {
return err return err
} }
recentEvents := streamEventsToEvents(recentStreamEvents) recentEvents := streamEventsToEvents(device, recentStreamEvents)
delta.stateEvents = removeDuplicates(delta.stateEvents, recentEvents) // roll back delta.stateEvents = removeDuplicates(delta.stateEvents, recentEvents) // roll back
// Don't bother appending empty room entries // Don't bother appending empty room entries
@ -531,7 +547,7 @@ func (d *SyncServerDatabase) fetchMissingStateEvents(
} }
func (d *SyncServerDatabase) getStateDeltas( func (d *SyncServerDatabase) getStateDeltas(
ctx context.Context, txn *sql.Tx, ctx context.Context, device *authtypes.Device, txn *sql.Tx,
fromPos, toPos types.StreamPosition, userID string, fromPos, toPos types.StreamPosition, userID string,
) ([]stateDelta, error) { ) ([]stateDelta, error) {
// Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821 // Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821
@ -580,7 +596,7 @@ func (d *SyncServerDatabase) getStateDeltas(
deltas = append(deltas, stateDelta{ deltas = append(deltas, stateDelta{
membership: membership, membership: membership,
membershipPos: ev.streamPosition, membershipPos: ev.streamPosition,
stateEvents: streamEventsToEvents(stateStreamEvents), stateEvents: streamEventsToEvents(device, stateStreamEvents),
roomID: roomID, roomID: roomID,
}) })
break break
@ -596,7 +612,7 @@ func (d *SyncServerDatabase) getStateDeltas(
for _, joinedRoomID := range joinedRoomIDs { for _, joinedRoomID := range joinedRoomIDs {
deltas = append(deltas, stateDelta{ deltas = append(deltas, stateDelta{
membership: "join", membership: "join",
stateEvents: streamEventsToEvents(state[joinedRoomID]), stateEvents: streamEventsToEvents(device, state[joinedRoomID]),
roomID: joinedRoomID, roomID: joinedRoomID,
}) })
} }
@ -604,10 +620,25 @@ func (d *SyncServerDatabase) getStateDeltas(
return deltas, nil return deltas, nil
} }
func streamEventsToEvents(in []streamEvent) []gomatrixserverlib.Event { // streamEventsToEvents converts streamEvent to Event. If device is non-nil and
// matches the streamevent.transactionID device then the transaction ID gets
// added to the unsigned section of the output event.
func streamEventsToEvents(device *authtypes.Device, in []streamEvent) []gomatrixserverlib.Event {
out := make([]gomatrixserverlib.Event, len(in)) out := make([]gomatrixserverlib.Event, len(in))
for i := 0; i < len(in); i++ { for i := 0; i < len(in); i++ {
out[i] = in[i].Event out[i] = in[i].Event
if device != nil && in[i].transactionID != nil {
if device.UserID == in[i].Sender() && device.ID == in[i].transactionID.DeviceID {
err := out[i].SetUnsignedField(
"transaction_id", in[i].transactionID.TransactionID,
)
if err != nil {
logrus.WithFields(logrus.Fields{
"event_id": out[i].EventID(),
}).WithError(err).Warnf("Failed to add transaction ID to event")
}
}
}
} }
return out return out
} }

View file

@ -123,7 +123,7 @@ func (n *Notifier) GetListener(req syncRequest) UserStreamListener {
n.removeEmptyUserStreams() n.removeEmptyUserStreams()
return n.fetchUserStream(req.userID, true).GetListener(req.ctx) return n.fetchUserStream(req.device.UserID, true).GetListener(req.ctx)
} }
// Load the membership states required to notify users correctly. // Load the membership states required to notify users correctly.

View file

@ -21,6 +21,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
@ -262,7 +264,7 @@ func waitForEvents(n *Notifier, req syncRequest) (types.StreamPosition, error) {
select { select {
case <-time.After(5 * time.Second): case <-time.After(5 * time.Second):
return types.StreamPosition(0), fmt.Errorf( return types.StreamPosition(0), fmt.Errorf(
"waitForEvents timed out waiting for %s (pos=%d)", req.userID, req.since, "waitForEvents timed out waiting for %s (pos=%d)", req.device.UserID, req.since,
) )
case <-listener.GetNotifyChannel(*req.since): case <-listener.GetNotifyChannel(*req.since):
p := listener.GetStreamPosition() p := listener.GetStreamPosition()
@ -280,7 +282,7 @@ func waitForBlocking(s *UserStream, numBlocking uint) {
func newTestSyncRequest(userID string, since types.StreamPosition) syncRequest { func newTestSyncRequest(userID string, since types.StreamPosition) syncRequest {
return syncRequest{ return syncRequest{
userID: userID, device: authtypes.Device{UserID: userID},
timeout: 1 * time.Minute, timeout: 1 * time.Minute,
since: &since, since: &since,
wantFullState: false, wantFullState: false,

View file

@ -20,6 +20,8 @@ import (
"strconv" "strconv"
"time" "time"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/util" "github.com/matrix-org/util"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -31,7 +33,7 @@ const defaultTimelineLimit = 20
// syncRequest represents a /sync request, with sensible defaults/sanity checks applied. // syncRequest represents a /sync request, with sensible defaults/sanity checks applied.
type syncRequest struct { type syncRequest struct {
ctx context.Context ctx context.Context
userID string device authtypes.Device
limit int limit int
timeout time.Duration timeout time.Duration
since *types.StreamPosition // nil means that no since token was supplied since *types.StreamPosition // nil means that no since token was supplied
@ -39,7 +41,7 @@ type syncRequest struct {
log *log.Entry log *log.Entry
} }
func newSyncRequest(req *http.Request, userID string) (*syncRequest, error) { func newSyncRequest(req *http.Request, device authtypes.Device) (*syncRequest, error) {
timeout := getTimeout(req.URL.Query().Get("timeout")) timeout := getTimeout(req.URL.Query().Get("timeout"))
fullState := req.URL.Query().Get("full_state") fullState := req.URL.Query().Get("full_state")
wantFullState := fullState != "" && fullState != "false" wantFullState := fullState != "" && fullState != "false"
@ -50,7 +52,7 @@ func newSyncRequest(req *http.Request, userID string) (*syncRequest, error) {
// TODO: Additional query params: set_presence, filter // TODO: Additional query params: set_presence, filter
return &syncRequest{ return &syncRequest{
ctx: req.Context(), ctx: req.Context(),
userID: userID, device: device,
timeout: timeout, timeout: timeout,
since: since, since: since,
wantFullState: wantFullState, wantFullState: wantFullState,

View file

@ -48,7 +48,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype
// Extract values from request // Extract values from request
logger := util.GetLogger(req.Context()) logger := util.GetLogger(req.Context())
userID := device.UserID userID := device.UserID
syncReq, err := newSyncRequest(req, userID) syncReq, err := newSyncRequest(req, *device)
if err != nil { if err != nil {
return util.JSONResponse{ return util.JSONResponse{
Code: 400, Code: 400,
@ -122,16 +122,16 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype
func (rp *RequestPool) currentSyncForUser(req syncRequest, currentPos types.StreamPosition) (res *types.Response, err error) { func (rp *RequestPool) currentSyncForUser(req syncRequest, currentPos types.StreamPosition) (res *types.Response, err error) {
// TODO: handle ignored users // TODO: handle ignored users
if req.since == nil { if req.since == nil {
res, err = rp.db.CompleteSync(req.ctx, req.userID, req.limit) res, err = rp.db.CompleteSync(req.ctx, req.device.UserID, req.limit)
} else { } else {
res, err = rp.db.IncrementalSync(req.ctx, req.userID, *req.since, currentPos, req.limit) res, err = rp.db.IncrementalSync(req.ctx, req.device, *req.since, currentPos, req.limit)
} }
if err != nil { if err != nil {
return return
} }
res, err = rp.appendAccountData(res, req.userID, req, currentPos) res, err = rp.appendAccountData(res, req.device.UserID, req, currentPos)
return return
} }

4
vendor/manifest vendored
View file

@ -10,7 +10,7 @@
{ {
"importpath": "github.com/alecthomas/gometalinter", "importpath": "github.com/alecthomas/gometalinter",
"repository": "https://github.com/alecthomas/gometalinter", "repository": "https://github.com/alecthomas/gometalinter",
"revision": "0262fb20957a4c2d3bb7c834a6a125ae3884a2c6", "revision": "b8b1f84ae8cb72e7870785840eab2d6c6355aa9f",
"branch": "master" "branch": "master"
}, },
{ {
@ -135,7 +135,7 @@
{ {
"importpath": "github.com/matrix-org/gomatrixserverlib", "importpath": "github.com/matrix-org/gomatrixserverlib",
"repository": "https://github.com/matrix-org/gomatrixserverlib", "repository": "https://github.com/matrix-org/gomatrixserverlib",
"revision": "8540d3dfc13c797cd3200640bc06e0286ab355aa", "revision": "afa71391f946312c40639a419045e06b8ff2309a",
"branch": "master" "branch": "master"
}, },
{ {

View file

@ -110,6 +110,21 @@ Here is an example configuration file:
} }
``` ```
#### `Format` key
The default `Format` key places the different fields of an `Issue` into a template. this
corresponds to the `--format` option command-line flag.
Default `Format`:
```
Format: "{{.Path}}:{{.Line}}:{{if .Col}}{{.Col}}{{end}}:{{.Severity}}: {{.Message}} ({{.Linter}})"
```
#### Format Methods
* `{{.Path.Relative}}` - equivalent to `{{.Path}}` which outputs a relative path to the file
* `{{.Path.Abs}}` - outputs an absolute path to the file
### Adding Custom linters ### Adding Custom linters
Linters can be added and customized from the config file using the `Linters` field. Linters can be added and customized from the config file using the `Linters` field.

View file

@ -102,6 +102,9 @@ func (v *visitor) Visit(node ast.Node) ast.Visitor {
for _, val := range node.Values { for _, val := range node.Values {
ast.Walk(v, val) ast.Walk(v, val)
} }
if node.Type != nil {
ast.Walk(v, node.Type)
}
return nil return nil
case *ast.FuncDecl: case *ast.FuncDecl:

View file

@ -0,0 +1,27 @@
Copyright (c) 2009 The Go Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View file

@ -0,0 +1,167 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package importgraph computes the forward and reverse import
// dependency graphs for all packages in a Go workspace.
package importgraph // import "golang.org/x/tools/refactor/importgraph"
import (
"go/build"
"sync"
"golang.org/x/tools/go/buildutil"
)
// A Graph is an import dependency graph, either forward or reverse.
//
// The graph maps each node (a package import path) to the set of its
// successors in the graph. For a forward graph, this is the set of
// imported packages (prerequisites); for a reverse graph, it is the set
// of importing packages (clients).
//
// Graph construction inspects all imports in each package's directory,
// including those in _test.go files, so the resulting graph may be cyclic.
type Graph map[string]map[string]bool
func (g Graph) addEdge(from, to string) {
edges := g[from]
if edges == nil {
edges = make(map[string]bool)
g[from] = edges
}
edges[to] = true
}
// Search returns all the nodes of the graph reachable from
// any of the specified roots, by following edges forwards.
// Relationally, this is the reflexive transitive closure.
func (g Graph) Search(roots ...string) map[string]bool {
seen := make(map[string]bool)
var visit func(x string)
visit = func(x string) {
if !seen[x] {
seen[x] = true
for y := range g[x] {
visit(y)
}
}
}
for _, root := range roots {
visit(root)
}
return seen
}
// Build scans the specified Go workspace and builds the forward and
// reverse import dependency graphs for all its packages.
// It also returns a mapping from canonical import paths to errors for packages
// whose loading was not entirely successful.
// A package may appear in the graph and in the errors mapping.
// All package paths are canonical and may contain "/vendor/".
func Build(ctxt *build.Context) (forward, reverse Graph, errors map[string]error) {
type importEdge struct {
from, to string
}
type pathError struct {
path string
err error
}
ch := make(chan interface{})
go func() {
sema := make(chan int, 20) // I/O concurrency limiting semaphore
var wg sync.WaitGroup
buildutil.ForEachPackage(ctxt, func(path string, err error) {
if err != nil {
ch <- pathError{path, err}
return
}
wg.Add(1)
go func() {
defer wg.Done()
sema <- 1
bp, err := ctxt.Import(path, "", 0)
<-sema
if err != nil {
if _, ok := err.(*build.NoGoError); ok {
// empty directory is not an error
} else {
ch <- pathError{path, err}
}
// Even in error cases, Import usually returns a package.
}
// absolutize resolves an import path relative
// to the current package bp.
// The absolute form may contain "vendor".
//
// The vendoring feature slows down Build by 3×.
// Here are timings from a 1400 package workspace:
// 1100ms: current code (with vendor check)
// 880ms: with a nonblocking cache around ctxt.IsDir
// 840ms: nonblocking cache with duplicate suppression
// 340ms: original code (no vendor check)
// TODO(adonovan): optimize, somehow.
memo := make(map[string]string)
absolutize := func(path string) string {
canon, ok := memo[path]
if !ok {
sema <- 1
bp2, _ := ctxt.Import(path, bp.Dir, build.FindOnly)
<-sema
if bp2 != nil {
canon = bp2.ImportPath
} else {
canon = path
}
memo[path] = canon
}
return canon
}
if bp != nil {
for _, imp := range bp.Imports {
ch <- importEdge{path, absolutize(imp)}
}
for _, imp := range bp.TestImports {
ch <- importEdge{path, absolutize(imp)}
}
for _, imp := range bp.XTestImports {
ch <- importEdge{path, absolutize(imp)}
}
}
}()
})
wg.Wait()
close(ch)
}()
forward = make(Graph)
reverse = make(Graph)
for e := range ch {
switch e := e.(type) {
case pathError:
if errors == nil {
errors = make(map[string]error)
}
errors[e.path] = e.err
case importEdge:
if e.to == "C" {
continue // "C" is fake
}
forward.addEdge(e.from, e.to)
reverse.addEdge(e.to, e.from)
}
}
return forward, reverse, errors
}

View file

@ -0,0 +1,16 @@
package main // import "honnef.co/go/tools/cmd/errcheck-ng"
import (
"os"
"honnef.co/go/tools/errcheck"
"honnef.co/go/tools/lint/lintutil"
)
func main() {
c := lintutil.CheckerConfig{
Checker: errcheck.NewChecker(),
ExitNonZero: true,
}
lintutil.ProcessArgs("errcheck-ng", []lintutil.CheckerConfig{c}, os.Args[1:])
}

View file

@ -1,20 +0,0 @@
Copyright (c) 2016 Dominik Honnef
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View file

@ -13,6 +13,9 @@ func main() {
fs.Parse(os.Args[1:]) fs.Parse(os.Args[1:])
c := simple.NewChecker() c := simple.NewChecker()
c.CheckGenerated = *gen c.CheckGenerated = *gen
cfg := lintutil.CheckerConfig{
lintutil.ProcessFlagSet(c, fs) Checker: c,
ExitNonZero: true,
}
lintutil.ProcessFlagSet([]lintutil.CheckerConfig{cfg}, fs)
} }

View file

@ -0,0 +1,401 @@
// keyify transforms unkeyed struct literals into a keyed ones.
package main
import (
"bytes"
"encoding/json"
"flag"
"fmt"
"go/ast"
"go/build"
"go/constant"
"go/printer"
"go/token"
"go/types"
"log"
"os"
"path/filepath"
"honnef.co/go/tools/version"
"golang.org/x/tools/go/ast/astutil"
"golang.org/x/tools/go/buildutil"
"golang.org/x/tools/go/loader"
)
var (
fRecursive bool
fOneLine bool
fJSON bool
fMinify bool
fModified bool
fVersion bool
)
func init() {
flag.BoolVar(&fRecursive, "r", false, "keyify struct initializers recursively")
flag.BoolVar(&fOneLine, "o", false, "print new struct initializer on a single line")
flag.BoolVar(&fJSON, "json", false, "print new struct initializer as JSON")
flag.BoolVar(&fMinify, "m", false, "omit fields that are set to their zero value")
flag.BoolVar(&fModified, "modified", false, "read an archive of modified files from standard input")
flag.BoolVar(&fVersion, "version", false, "Print version and exit")
}
func usage() {
fmt.Printf("Usage: %s [flags] <position>\n\n", os.Args[0])
flag.PrintDefaults()
}
func main() {
log.SetFlags(0)
flag.Usage = usage
flag.Parse()
if fVersion {
version.Print()
os.Exit(0)
}
if flag.NArg() != 1 {
flag.Usage()
os.Exit(2)
}
pos := flag.Args()[0]
name, start, _, err := parsePos(pos)
if err != nil {
log.Fatal(err)
}
eval, err := filepath.EvalSymlinks(name)
if err != nil {
log.Fatal(err)
}
name, err = filepath.Abs(eval)
if err != nil {
log.Fatal(err)
}
cwd, err := os.Getwd()
if err != nil {
log.Fatal(err)
}
ctx := &build.Default
if fModified {
overlay, err := buildutil.ParseOverlayArchive(os.Stdin)
if err != nil {
log.Fatal(err)
}
ctx = buildutil.OverlayContext(ctx, overlay)
}
bpkg, err := buildutil.ContainingPackage(ctx, cwd, name)
if err != nil {
log.Fatal(err)
}
conf := &loader.Config{
Build: ctx,
}
conf.TypeCheckFuncBodies = func(s string) bool {
return s == bpkg.ImportPath || s == bpkg.ImportPath+"_test"
}
conf.ImportWithTests(bpkg.ImportPath)
lprog, err := conf.Load()
if err != nil {
log.Fatal(err)
}
var tf *token.File
var af *ast.File
pkg := lprog.InitialPackages()[0]
for _, ff := range pkg.Files {
file := lprog.Fset.File(ff.Pos())
if file.Name() == name {
af = ff
tf = file
break
}
}
tstart, tend, err := fileOffsetToPos(tf, start, start)
if err != nil {
log.Fatal(err)
}
path, _ := astutil.PathEnclosingInterval(af, tstart, tend)
var complit *ast.CompositeLit
for _, p := range path {
if p, ok := p.(*ast.CompositeLit); ok {
complit = p
break
}
}
if complit == nil {
log.Fatal("no composite literal found near point")
}
if len(complit.Elts) == 0 {
printComplit(complit, complit, lprog.Fset, lprog.Fset)
return
}
if _, ok := complit.Elts[0].(*ast.KeyValueExpr); ok {
lit := complit
if fOneLine {
lit = copyExpr(complit, 1).(*ast.CompositeLit)
}
printComplit(complit, lit, lprog.Fset, lprog.Fset)
return
}
_, ok := pkg.TypeOf(complit).Underlying().(*types.Struct)
if !ok {
log.Fatal("not a struct initialiser")
return
}
newComplit, lines := keyify(pkg, complit)
newFset := token.NewFileSet()
newFile := newFset.AddFile("", -1, lines)
for i := 1; i <= lines; i++ {
newFile.AddLine(i)
}
printComplit(complit, newComplit, lprog.Fset, newFset)
}
func keyify(
pkg *loader.PackageInfo,
complit *ast.CompositeLit,
) (*ast.CompositeLit, int) {
var calcPos func(int) token.Pos
if fOneLine {
calcPos = func(int) token.Pos { return token.Pos(1) }
} else {
calcPos = func(i int) token.Pos { return token.Pos(2 + i) }
}
st, _ := pkg.TypeOf(complit).Underlying().(*types.Struct)
newComplit := &ast.CompositeLit{
Type: complit.Type,
Lbrace: 1,
Rbrace: token.Pos(st.NumFields() + 2),
}
if fOneLine {
newComplit.Rbrace = 1
}
numLines := 2 + st.NumFields()
n := 0
for i := 0; i < st.NumFields(); i++ {
field := st.Field(i)
val := complit.Elts[i]
if fRecursive {
if val2, ok := val.(*ast.CompositeLit); ok {
if _, ok := pkg.TypeOf(val2.Type).Underlying().(*types.Struct); ok {
var lines int
numLines += lines
val, lines = keyify(pkg, val2)
}
}
}
_, isIface := st.Field(i).Type().Underlying().(*types.Interface)
if fMinify && (isNil(val, pkg) || (!isIface && isZero(val, pkg))) {
continue
}
elt := &ast.KeyValueExpr{
Key: &ast.Ident{NamePos: calcPos(n), Name: field.Name()},
Value: copyExpr(val, calcPos(n)),
}
newComplit.Elts = append(newComplit.Elts, elt)
n++
}
return newComplit, numLines
}
func isNil(val ast.Expr, pkg *loader.PackageInfo) bool {
ident, ok := val.(*ast.Ident)
if !ok {
return false
}
if _, ok := pkg.ObjectOf(ident).(*types.Nil); ok {
return true
}
if c, ok := pkg.ObjectOf(ident).(*types.Const); ok {
if c.Val().Kind() != constant.Bool {
return false
}
return !constant.BoolVal(c.Val())
}
return false
}
func isZero(val ast.Expr, pkg *loader.PackageInfo) bool {
switch val := val.(type) {
case *ast.BasicLit:
switch val.Value {
case `""`, "``", "0", "0.0", "0i", "0.":
return true
default:
return false
}
case *ast.Ident:
return isNil(val, pkg)
case *ast.CompositeLit:
typ := pkg.TypeOf(val.Type)
if typ == nil {
return false
}
isIface := false
switch typ := typ.Underlying().(type) {
case *types.Struct:
case *types.Array:
_, isIface = typ.Elem().Underlying().(*types.Interface)
default:
return false
}
for _, elt := range val.Elts {
if isNil(elt, pkg) || (!isIface && !isZero(elt, pkg)) {
return false
}
}
return true
}
return false
}
func printComplit(oldlit, newlit *ast.CompositeLit, oldfset, newfset *token.FileSet) {
buf := &bytes.Buffer{}
cfg := printer.Config{Mode: printer.UseSpaces | printer.TabIndent, Tabwidth: 8}
_ = cfg.Fprint(buf, newfset, newlit)
if fJSON {
output := struct {
Start int `json:"start"`
End int `json:"end"`
Replacement string `json:"replacement"`
}{
oldfset.Position(oldlit.Pos()).Offset,
oldfset.Position(oldlit.End()).Offset,
buf.String(),
}
_ = json.NewEncoder(os.Stdout).Encode(output)
} else {
fmt.Println(buf.String())
}
}
func copyExpr(expr ast.Expr, line token.Pos) ast.Expr {
switch expr := expr.(type) {
case *ast.BasicLit:
cp := *expr
cp.ValuePos = 0
return &cp
case *ast.BinaryExpr:
cp := *expr
cp.X = copyExpr(cp.X, line)
cp.OpPos = 0
cp.Y = copyExpr(cp.Y, line)
return &cp
case *ast.CallExpr:
cp := *expr
cp.Fun = copyExpr(cp.Fun, line)
cp.Lparen = 0
for i, v := range cp.Args {
cp.Args[i] = copyExpr(v, line)
}
if cp.Ellipsis != 0 {
cp.Ellipsis = line
}
cp.Rparen = 0
return &cp
case *ast.CompositeLit:
cp := *expr
cp.Type = copyExpr(cp.Type, line)
cp.Lbrace = 0
for i, v := range cp.Elts {
cp.Elts[i] = copyExpr(v, line)
}
cp.Rbrace = 0
return &cp
case *ast.Ident:
cp := *expr
cp.NamePos = 0
return &cp
case *ast.IndexExpr:
cp := *expr
cp.X = copyExpr(cp.X, line)
cp.Lbrack = 0
cp.Index = copyExpr(cp.Index, line)
cp.Rbrack = 0
return &cp
case *ast.KeyValueExpr:
cp := *expr
cp.Key = copyExpr(cp.Key, line)
cp.Colon = 0
cp.Value = copyExpr(cp.Value, line)
return &cp
case *ast.ParenExpr:
cp := *expr
cp.Lparen = 0
cp.X = copyExpr(cp.X, line)
cp.Rparen = 0
return &cp
case *ast.SelectorExpr:
cp := *expr
cp.X = copyExpr(cp.X, line)
cp.Sel = copyExpr(cp.Sel, line).(*ast.Ident)
return &cp
case *ast.SliceExpr:
cp := *expr
cp.X = copyExpr(cp.X, line)
cp.Lbrack = 0
cp.Low = copyExpr(cp.Low, line)
cp.High = copyExpr(cp.High, line)
cp.Max = copyExpr(cp.Max, line)
cp.Rbrack = 0
return &cp
case *ast.StarExpr:
cp := *expr
cp.Star = 0
cp.X = copyExpr(cp.X, line)
return &cp
case *ast.TypeAssertExpr:
cp := *expr
cp.X = copyExpr(cp.X, line)
cp.Lparen = 0
cp.Type = copyExpr(cp.Type, line)
cp.Rparen = 0
return &cp
case *ast.UnaryExpr:
cp := *expr
cp.OpPos = 0
cp.X = copyExpr(cp.X, line)
return &cp
case *ast.MapType:
cp := *expr
cp.Map = 0
cp.Key = copyExpr(cp.Key, line)
cp.Value = copyExpr(cp.Value, line)
return &cp
case *ast.ArrayType:
cp := *expr
cp.Lbrack = 0
cp.Len = copyExpr(cp.Len, line)
cp.Elt = copyExpr(cp.Elt, line)
return &cp
case *ast.Ellipsis:
cp := *expr
cp.Elt = copyExpr(cp.Elt, line)
cp.Ellipsis = line
return &cp
case *ast.InterfaceType:
cp := *expr
cp.Interface = 0
return &cp
case *ast.StructType:
cp := *expr
cp.Struct = 0
return &cp
case *ast.FuncLit:
return expr
case *ast.ChanType:
cp := *expr
cp.Arrow = 0
cp.Begin = 0
cp.Value = copyExpr(cp.Value, line)
return &cp
case nil:
return nil
default:
panic(fmt.Sprintf("shouldn't happen: unknown ast.Expr of type %T", expr))
}
return nil
}

View file

@ -0,0 +1,71 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"fmt"
"go/token"
"strconv"
"strings"
)
func parseOctothorpDecimal(s string) int {
if s != "" && s[0] == '#' {
if s, err := strconv.ParseInt(s[1:], 10, 32); err == nil {
return int(s)
}
}
return -1
}
func parsePos(pos string) (filename string, startOffset, endOffset int, err error) {
if pos == "" {
err = fmt.Errorf("no source position specified")
return
}
colon := strings.LastIndex(pos, ":")
if colon < 0 {
err = fmt.Errorf("bad position syntax %q", pos)
return
}
filename, offset := pos[:colon], pos[colon+1:]
startOffset = -1
endOffset = -1
if hyphen := strings.Index(offset, ","); hyphen < 0 {
// e.g. "foo.go:#123"
startOffset = parseOctothorpDecimal(offset)
endOffset = startOffset
} else {
// e.g. "foo.go:#123,#456"
startOffset = parseOctothorpDecimal(offset[:hyphen])
endOffset = parseOctothorpDecimal(offset[hyphen+1:])
}
if startOffset < 0 || endOffset < 0 {
err = fmt.Errorf("invalid offset %q in query position", offset)
return
}
return
}
func fileOffsetToPos(file *token.File, startOffset, endOffset int) (start, end token.Pos, err error) {
// Range check [start..end], inclusive of both end-points.
if 0 <= startOffset && startOffset <= file.Size() {
start = file.Pos(int(startOffset))
} else {
err = fmt.Errorf("start position is beyond end of file")
return
}
if 0 <= endOffset && endOffset <= file.Size() {
end = file.Pos(int(endOffset))
} else {
err = fmt.Errorf("end position is beyond end of file")
return
}
return
}

View file

@ -1,20 +0,0 @@
Copyright (c) 2016 Dominik Honnef
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View file

@ -4,42 +4,23 @@ package main // import "honnef.co/go/tools/cmd/megacheck"
import ( import (
"os" "os"
"honnef.co/go/tools/lint"
"honnef.co/go/tools/lint/lintutil" "honnef.co/go/tools/lint/lintutil"
"honnef.co/go/tools/simple" "honnef.co/go/tools/simple"
"honnef.co/go/tools/staticcheck" "honnef.co/go/tools/staticcheck"
"honnef.co/go/tools/unused" "honnef.co/go/tools/unused"
) )
type Checker struct {
Checkers []lint.Checker
}
func (c *Checker) Init(prog *lint.Program) {
for _, cc := range c.Checkers {
cc.Init(prog)
}
}
func (c *Checker) Funcs() map[string]lint.Func {
fns := map[string]lint.Func{}
for _, cc := range c.Checkers {
for k, v := range cc.Funcs() {
fns[k] = v
}
}
return fns
}
func main() { func main() {
var flags struct { var flags struct {
staticcheck struct { staticcheck struct {
enabled bool enabled bool
generated bool generated bool
exitNonZero bool
} }
gosimple struct { gosimple struct {
enabled bool enabled bool
generated bool generated bool
exitNonZero bool
} }
unused struct { unused struct {
enabled bool enabled bool
@ -51,6 +32,7 @@ func main() {
debug string debug string
wholeProgram bool wholeProgram bool
reflection bool reflection bool
exitNonZero bool
} }
} }
fs := lintutil.FlagSet("megacheck") fs := lintutil.FlagSet("megacheck")
@ -58,11 +40,15 @@ func main() {
"simple.enabled", true, "Run gosimple") "simple.enabled", true, "Run gosimple")
fs.BoolVar(&flags.gosimple.generated, fs.BoolVar(&flags.gosimple.generated,
"simple.generated", false, "Check generated code") "simple.generated", false, "Check generated code")
fs.BoolVar(&flags.gosimple.exitNonZero,
"simple.exit-non-zero", false, "Exit non-zero if any problems were found")
fs.BoolVar(&flags.staticcheck.enabled, fs.BoolVar(&flags.staticcheck.enabled,
"staticcheck.enabled", true, "Run staticcheck") "staticcheck.enabled", true, "Run staticcheck")
fs.BoolVar(&flags.staticcheck.generated, fs.BoolVar(&flags.staticcheck.generated,
"staticcheck.generated", false, "Check generated code (only applies to a subset of checks)") "staticcheck.generated", false, "Check generated code (only applies to a subset of checks)")
fs.BoolVar(&flags.staticcheck.exitNonZero,
"staticcheck.exit-non-zero", true, "Exit non-zero if any problems were found")
fs.BoolVar(&flags.unused.enabled, fs.BoolVar(&flags.unused.enabled,
"unused.enabled", true, "Run unused") "unused.enabled", true, "Run unused")
@ -78,22 +64,31 @@ func main() {
"unused.vars", true, "Report unused variables") "unused.vars", true, "Report unused variables")
fs.BoolVar(&flags.unused.wholeProgram, fs.BoolVar(&flags.unused.wholeProgram,
"unused.exported", false, "Treat arguments as a program and report unused exported identifiers") "unused.exported", false, "Treat arguments as a program and report unused exported identifiers")
fs.BoolVar(&flags.unused.reflection, "unused.reflect", true, "Consider identifiers as used when it's likely they'll be accessed via reflection") fs.BoolVar(&flags.unused.reflection,
"unused.reflect", true, "Consider identifiers as used when it's likely they'll be accessed via reflection")
fs.BoolVar(&flags.unused.exitNonZero,
"unused.exit-non-zero", true, "Exit non-zero if any problems were found")
fs.Parse(os.Args[1:]) fs.Parse(os.Args[1:])
c := &Checker{} var checkers []lintutil.CheckerConfig
if flags.staticcheck.enabled { if flags.staticcheck.enabled {
sac := staticcheck.NewChecker() sac := staticcheck.NewChecker()
sac.CheckGenerated = flags.staticcheck.generated sac.CheckGenerated = flags.staticcheck.generated
c.Checkers = append(c.Checkers, sac) checkers = append(checkers, lintutil.CheckerConfig{
Checker: sac,
ExitNonZero: flags.staticcheck.exitNonZero,
})
} }
if flags.gosimple.enabled { if flags.gosimple.enabled {
sc := simple.NewChecker() sc := simple.NewChecker()
sc.CheckGenerated = flags.gosimple.generated sc.CheckGenerated = flags.gosimple.generated
c.Checkers = append(c.Checkers, sc) checkers = append(checkers, lintutil.CheckerConfig{
Checker: sc,
ExitNonZero: flags.gosimple.exitNonZero,
})
} }
if flags.unused.enabled { if flags.unused.enabled {
@ -116,8 +111,12 @@ func main() {
uc := unused.NewChecker(mode) uc := unused.NewChecker(mode)
uc.WholeProgram = flags.unused.wholeProgram uc.WholeProgram = flags.unused.wholeProgram
uc.ConsiderReflection = flags.unused.reflection uc.ConsiderReflection = flags.unused.reflection
c.Checkers = append(c.Checkers, unused.NewLintChecker(uc)) checkers = append(checkers, lintutil.CheckerConfig{
Checker: unused.NewLintChecker(uc),
ExitNonZero: flags.unused.exitNonZero,
})
} }
lintutil.ProcessFlagSet(c, fs) lintutil.ProcessFlagSet(checkers, fs)
} }

View file

@ -0,0 +1,86 @@
// rdeps scans GOPATH for all reverse dependencies of a set of Go
// packages.
//
// rdeps will not sort its output, and the order of the output is
// undefined. Pipe its output through sort if you need stable output.
package main
import (
"bufio"
"flag"
"fmt"
"go/build"
"os"
"honnef.co/go/tools/version"
"github.com/kisielk/gotool"
"golang.org/x/tools/go/buildutil"
"golang.org/x/tools/refactor/importgraph"
)
func main() {
var tags buildutil.TagsFlag
flag.Var(&tags, "tags", "List of build tags")
stdin := flag.Bool("stdin", false, "Read packages from stdin instead of the command line")
recursive := flag.Bool("r", false, "Print reverse dependencies recursively")
printVersion := flag.Bool("version", false, "Print version and exit")
flag.Parse()
if *printVersion {
version.Print()
os.Exit(0)
}
ctx := build.Default
ctx.BuildTags = tags
var args []string
if *stdin {
s := bufio.NewScanner(os.Stdin)
for s.Scan() {
args = append(args, s.Text())
}
} else {
args = flag.Args()
}
if len(args) == 0 {
return
}
wd, err := os.Getwd()
if err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
pkgs := gotool.ImportPaths(args)
for i, pkg := range pkgs {
bpkg, err := ctx.Import(pkg, wd, build.FindOnly)
if err != nil {
continue
}
pkgs[i] = bpkg.ImportPath
}
_, reverse, errors := importgraph.Build(&ctx)
_ = errors
seen := map[string]bool{}
var printRDeps func(pkg string)
printRDeps = func(pkg string) {
for rdep := range reverse[pkg] {
if seen[rdep] {
continue
}
seen[rdep] = true
fmt.Println(rdep)
if *recursive {
printRDeps(rdep)
}
}
}
for _, pkg := range pkgs {
printRDeps(pkg)
}
for pkg, err := range errors {
fmt.Fprintf(os.Stderr, "error in package %s: %s\n", pkg, err)
}
}

View file

@ -1,20 +0,0 @@
Copyright (c) 2016 Dominik Honnef
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View file

@ -15,5 +15,9 @@ func main() {
fs.Parse(os.Args[1:]) fs.Parse(os.Args[1:])
c := staticcheck.NewChecker() c := staticcheck.NewChecker()
c.CheckGenerated = *gen c.CheckGenerated = *gen
lintutil.ProcessFlagSet(c, fs) cfg := lintutil.CheckerConfig{
Checker: c,
ExitNonZero: true,
}
lintutil.ProcessFlagSet([]lintutil.CheckerConfig{cfg}, fs)
} }

View file

@ -0,0 +1,205 @@
// structlayout-optimize reorders struct fields to minimize the amount
// of padding.
package main
import (
"encoding/json"
"flag"
"fmt"
"log"
"os"
"sort"
"strings"
st "honnef.co/go/tools/structlayout"
"honnef.co/go/tools/version"
)
var (
fJSON bool
fRecurse bool
fVersion bool
)
func init() {
flag.BoolVar(&fJSON, "json", false, "Format data as JSON")
flag.BoolVar(&fRecurse, "r", false, "Break up structs and reorder their fields freely")
flag.BoolVar(&fVersion, "version", false, "Print version and exit")
}
func main() {
log.SetFlags(0)
flag.Parse()
if fVersion {
version.Print()
os.Exit(0)
}
var in []st.Field
if err := json.NewDecoder(os.Stdin).Decode(&in); err != nil {
log.Fatal(err)
}
if len(in) == 0 {
return
}
if !fRecurse {
in = combine(in)
}
var fields []st.Field
for _, field := range in {
if field.IsPadding {
continue
}
fields = append(fields, field)
}
optimize(fields)
fields = pad(fields)
if fJSON {
json.NewEncoder(os.Stdout).Encode(fields)
} else {
for _, field := range fields {
fmt.Println(field)
}
}
}
func combine(fields []st.Field) []st.Field {
new := st.Field{}
cur := ""
var out []st.Field
wasPad := true
for _, field := range fields {
var prefix string
if field.IsPadding {
wasPad = true
continue
}
p := strings.Split(field.Name, ".")
prefix = strings.Join(p[:2], ".")
if field.Align > new.Align {
new.Align = field.Align
}
if !wasPad {
new.End = field.Start
new.Size = new.End - new.Start
}
if prefix != cur {
if cur != "" {
out = append(out, new)
}
cur = prefix
new = field
new.Name = prefix
} else {
new.Type = "struct"
}
wasPad = false
}
new.Size = new.End - new.Start
out = append(out, new)
return out
}
func optimize(fields []st.Field) {
sort.Sort(&byAlignAndSize{fields})
}
func pad(fields []st.Field) []st.Field {
if len(fields) == 0 {
return nil
}
var out []st.Field
pos := int64(0)
offsets := offsetsof(fields)
alignment := int64(1)
for i, field := range fields {
if field.Align > alignment {
alignment = field.Align
}
if offsets[i] > pos {
padding := offsets[i] - pos
out = append(out, st.Field{
IsPadding: true,
Start: pos,
End: pos + padding,
Size: padding,
})
pos += padding
}
field.Start = pos
field.End = pos + field.Size
out = append(out, field)
pos += field.Size
}
sz := size(out)
pad := align(sz, alignment) - sz
if pad > 0 {
field := out[len(out)-1]
out = append(out, st.Field{
IsPadding: true,
Start: field.End,
End: field.End + pad,
Size: pad,
})
}
return out
}
func size(fields []st.Field) int64 {
n := int64(0)
for _, field := range fields {
n += field.Size
}
return n
}
type byAlignAndSize struct {
fields []st.Field
}
func (s *byAlignAndSize) Len() int { return len(s.fields) }
func (s *byAlignAndSize) Swap(i, j int) {
s.fields[i], s.fields[j] = s.fields[j], s.fields[i]
}
func (s *byAlignAndSize) Less(i, j int) bool {
// Place zero sized objects before non-zero sized objects.
if s.fields[i].Size == 0 && s.fields[j].Size != 0 {
return true
}
if s.fields[j].Size == 0 && s.fields[i].Size != 0 {
return false
}
// Next, place more tightly aligned objects before less tightly aligned objects.
if s.fields[i].Align != s.fields[j].Align {
return s.fields[i].Align > s.fields[j].Align
}
// Lastly, order by size.
if s.fields[i].Size != s.fields[j].Size {
return s.fields[i].Size > s.fields[j].Size
}
return false
}
func offsetsof(fields []st.Field) []int64 {
offsets := make([]int64, len(fields))
var o int64
for i, f := range fields {
a := f.Align
o = align(o, a)
offsets[i] = o
o += f.Size
}
return offsets
}
// align returns the smallest y >= x such that y % a == 0.
func align(x, a int64) int64 {
y := x + a - 1
return y - y%a
}

View file

@ -0,0 +1,72 @@
// structlayout-pretty formats the output of structlayout with ASCII
// art.
package main
import (
"encoding/json"
"flag"
"fmt"
"log"
"os"
"strings"
st "honnef.co/go/tools/structlayout"
"honnef.co/go/tools/version"
)
var (
fVerbose bool
fVersion bool
)
func init() {
flag.BoolVar(&fVerbose, "v", false, "Do not compact consecutive bytes of fields")
flag.BoolVar(&fVersion, "version", false, "Print version and exit")
}
func main() {
log.SetFlags(0)
flag.Parse()
if fVersion {
version.Print()
os.Exit(0)
}
var fields []st.Field
if err := json.NewDecoder(os.Stdin).Decode(&fields); err != nil {
log.Fatal(err)
}
if len(fields) == 0 {
return
}
max := fields[len(fields)-1].End
maxLength := len(fmt.Sprintf("%d", max))
padding := strings.Repeat(" ", maxLength+2)
format := fmt.Sprintf(" %%%dd ", maxLength)
pos := int64(0)
fmt.Println(padding + "+--------+")
for _, field := range fields {
name := field.Name + " " + field.Type
if field.IsPadding {
name = "padding"
}
fmt.Printf(format+"| | <- %s (size %d, align %d)\n", pos, name, field.Size, field.Align)
fmt.Println(padding + "+--------+")
if fVerbose {
for i := int64(0); i < field.Size-1; i++ {
fmt.Printf(format+"| |\n", pos+i+1)
fmt.Println(padding + "+--------+")
}
} else {
if field.Size > 2 {
fmt.Println(padding + "-........-")
fmt.Println(padding + "+--------+")
fmt.Printf(format+"| |\n", pos+field.Size-1)
fmt.Println(padding + "+--------+")
}
}
pos += field.Size
}
}

View file

@ -0,0 +1,149 @@
// structlayout displays the layout (field sizes and padding) of structs.
package main
import (
"encoding/json"
"flag"
"fmt"
"go/build"
"go/types"
"log"
"os"
"honnef.co/go/tools/gcsizes"
st "honnef.co/go/tools/structlayout"
"honnef.co/go/tools/version"
"golang.org/x/tools/go/loader"
)
var (
fJSON bool
fVersion bool
)
func init() {
flag.BoolVar(&fJSON, "json", false, "Format data as JSON")
flag.BoolVar(&fVersion, "version", false, "Print version and exit")
}
func main() {
log.SetFlags(0)
flag.Parse()
if fVersion {
version.Print()
os.Exit(0)
}
if len(flag.Args()) != 2 {
flag.Usage()
os.Exit(1)
}
conf := loader.Config{
Build: &build.Default,
}
var pkg string
var typName string
pkg = flag.Args()[0]
typName = flag.Args()[1]
conf.Import(pkg)
lprog, err := conf.Load()
if err != nil {
log.Fatal(err)
}
var typ types.Type
obj := lprog.Package(pkg).Pkg.Scope().Lookup(typName)
if obj == nil {
log.Fatal("couldn't find type")
}
typ = obj.Type()
st, ok := typ.Underlying().(*types.Struct)
if !ok {
log.Fatal("identifier is not a struct type")
}
fields := sizes(st, typ.(*types.Named).Obj().Name(), 0, nil)
if fJSON {
emitJSON(fields)
} else {
emitText(fields)
}
}
func emitJSON(fields []st.Field) {
if fields == nil {
fields = []st.Field{}
}
json.NewEncoder(os.Stdout).Encode(fields)
}
func emitText(fields []st.Field) {
for _, field := range fields {
fmt.Println(field)
}
}
func sizes(typ *types.Struct, prefix string, base int64, out []st.Field) []st.Field {
s := gcsizes.ForArch(build.Default.GOARCH)
n := typ.NumFields()
var fields []*types.Var
for i := 0; i < n; i++ {
fields = append(fields, typ.Field(i))
}
offsets := s.Offsetsof(fields)
for i := range offsets {
offsets[i] += base
}
pos := base
for i, field := range fields {
if offsets[i] > pos {
padding := offsets[i] - pos
out = append(out, st.Field{
IsPadding: true,
Start: pos,
End: pos + padding,
Size: padding,
})
pos += padding
}
size := s.Sizeof(field.Type())
if typ2, ok := field.Type().Underlying().(*types.Struct); ok && typ2.NumFields() != 0 {
out = sizes(typ2, prefix+"."+field.Name(), pos, out)
} else {
out = append(out, st.Field{
Name: prefix + "." + field.Name(),
Type: field.Type().String(),
Start: offsets[i],
End: offsets[i] + size,
Size: size,
Align: s.Alignof(field.Type()),
})
}
pos += size
}
if len(out) == 0 {
return out
}
field := &out[len(out)-1]
if field.Size == 0 {
field.Size = 1
field.End++
}
pad := s.Sizeof(typ) - field.End
if pad > 0 {
out = append(out, st.Field{
IsPadding: true,
Start: field.End,
End: field.End + pad,
Size: pad,
})
}
return out
}

View file

@ -1,20 +0,0 @@
Copyright (c) 2016 Dominik Honnef
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View file

@ -70,5 +70,9 @@ func main() {
checker := newChecker(mode) checker := newChecker(mode)
l := unused.NewLintChecker(checker) l := unused.NewLintChecker(checker)
lintutil.ProcessFlagSet(l, fs) cfg := lintutil.CheckerConfig{
Checker: l,
ExitNonZero: true,
}
lintutil.ProcessFlagSet([]lintutil.CheckerConfig{cfg}, fs)
} }

View file

@ -0,0 +1,54 @@
package deprecated
type Deprecation struct {
DeprecatedSince int
AlternativeAvailableSince int
}
var Stdlib = map[string]Deprecation{
"image/jpeg.Reader": {4, 0},
// FIXME(dh): AllowBinary isn't being detected as deprecated
// because the comment has a newline right after "Deprecated:"
"go/build.AllowBinary": {7, 7},
"(archive/zip.FileHeader).CompressedSize": {1, 1},
"(archive/zip.FileHeader).UncompressedSize": {1, 1},
"(go/doc.Package).Bugs": {1, 1},
"os.SEEK_SET": {7, 7},
"os.SEEK_CUR": {7, 7},
"os.SEEK_END": {7, 7},
"(net.Dialer).Cancel": {7, 7},
"runtime.CPUProfile": {9, 0},
"compress/flate.ReadError": {6, 6},
"compress/flate.WriteError": {6, 6},
"path/filepath.HasPrefix": {0, 0},
"(net/http.Transport).Dial": {7, 7},
"(*net/http.Transport).CancelRequest": {6, 5},
"net/http.ErrWriteAfterFlush": {7, 0},
"net/http.ErrHeaderTooLong": {8, 0},
"net/http.ErrShortBody": {8, 0},
"net/http.ErrMissingContentLength": {8, 0},
"net/http/httputil.ErrPersistEOF": {0, 0},
"net/http/httputil.ErrClosed": {0, 0},
"net/http/httputil.ErrPipeline": {0, 0},
"net/http/httputil.ServerConn": {0, 0},
"net/http/httputil.NewServerConn": {0, 0},
"net/http/httputil.ClientConn": {0, 0},
"net/http/httputil.NewClientConn": {0, 0},
"net/http/httputil.NewProxyClientConn": {0, 0},
"(net/http.Request).Cancel": {7, 7},
"(text/template/parse.PipeNode).Line": {1, 1},
"(text/template/parse.ActionNode).Line": {1, 1},
"(text/template/parse.BranchNode).Line": {1, 1},
"(text/template/parse.TemplateNode).Line": {1, 1},
"database/sql/driver.ColumnConverter": {9, 9},
"database/sql/driver.Execer": {8, 8},
"database/sql/driver.Queryer": {8, 8},
"(database/sql/driver.Conn).Begin": {8, 8},
"(database/sql/driver.Stmt).Exec": {8, 8},
"(database/sql/driver.Stmt).Query": {8, 8},
"syscall.StringByteSlice": {1, 1},
"syscall.StringBytePtr": {1, 1},
"syscall.StringSlicePtr": {1, 1},
"syscall.StringToUTF16": {1, 1},
"syscall.StringToUTF16Ptr": {1, 1},
}

View file

@ -0,0 +1,157 @@
package errcheck
import (
"go/types"
"honnef.co/go/tools/functions"
"honnef.co/go/tools/lint"
"honnef.co/go/tools/ssa"
)
type Checker struct {
funcDescs *functions.Descriptions
}
func NewChecker() *Checker {
return &Checker{}
}
func (*Checker) Name() string { return "errcheck" }
func (*Checker) Prefix() string { return "ERR" }
func (c *Checker) Funcs() map[string]lint.Func {
return map[string]lint.Func{
"ERR1000": c.CheckErrcheck,
}
}
func (c *Checker) Init(prog *lint.Program) {
c.funcDescs = functions.NewDescriptions(prog.SSA)
}
func (c *Checker) CheckErrcheck(j *lint.Job) {
for _, ssafn := range j.Program.InitialFunctions {
for _, b := range ssafn.Blocks {
for _, ins := range b.Instrs {
ssacall, ok := ins.(ssa.CallInstruction)
if !ok {
continue
}
switch lint.CallName(ssacall.Common()) {
case "fmt.Print", "fmt.Println", "fmt.Printf":
continue
}
isRecover := false
if builtin, ok := ssacall.Common().Value.(*ssa.Builtin); ok {
isRecover = ok && builtin.Name() == "recover"
}
switch ins := ins.(type) {
case ssa.Value:
refs := ins.Referrers()
if refs == nil || len(lint.FilterDebug(*refs)) != 0 {
continue
}
case ssa.Instruction:
// will be a 'go' or 'defer', neither of which has usable return values
default:
// shouldn't happen
continue
}
if ssacall.Common().IsInvoke() {
if sc, ok := ssacall.Common().Value.(*ssa.Call); ok {
// TODO(dh): support multiple levels of
// interfaces, not just one
ssafn := sc.Common().StaticCallee()
if ssafn != nil {
ct := c.funcDescs.Get(ssafn).ConcreteReturnTypes
// TODO(dh): support >1 concrete types
if ct != nil && len(ct) == 1 {
// TODO(dh): do we have access to a
// cached method set somewhere?
ms := types.NewMethodSet(ct[0].At(ct[0].Len() - 1).Type())
// TODO(dh): where can we get the pkg
// for Lookup? Passing nil works fine
// for exported methods, but will fail
// on unexported ones
// TODO(dh): holy nesting and poor
// variable names, clean this up
fn, _ := ms.Lookup(nil, ssacall.Common().Method.Name()).Obj().(*types.Func)
if fn != nil {
ssafn := j.Program.SSA.FuncValue(fn)
if ssafn != nil {
if c.funcDescs.Get(ssafn).NilError {
continue
}
}
}
}
}
}
} else {
ssafn := ssacall.Common().StaticCallee()
if ssafn != nil {
if c.funcDescs.Get(ssafn).NilError {
// Don't complain when the error is known to be nil
continue
}
}
}
switch lint.CallName(ssacall.Common()) {
case "(*os.File).Close":
recv := ssacall.Common().Args[0]
if isReadOnlyFile(recv, nil) {
continue
}
}
res := ssacall.Common().Signature().Results()
if res.Len() == 0 {
continue
}
if !isRecover {
last := res.At(res.Len() - 1)
if types.TypeString(last.Type(), nil) != "error" {
continue
}
}
j.Errorf(ins, "unchecked error")
}
}
}
}
func isReadOnlyFile(val ssa.Value, seen map[ssa.Value]bool) bool {
if seen == nil {
seen = map[ssa.Value]bool{}
}
if seen[val] {
return true
}
seen[val] = true
switch val := val.(type) {
case *ssa.Phi:
for _, edge := range val.Edges {
if !isReadOnlyFile(edge, seen) {
return false
}
}
return true
case *ssa.Extract:
call, ok := val.Tuple.(*ssa.Call)
if !ok {
return false
}
switch lint.CallName(call.Common()) {
case "os.Open":
return true
case "os.OpenFile":
flags, ok := call.Common().Args[1].(*ssa.Const)
return ok && flags.Uint64() == 0
}
return false
}
return false
}

View file

@ -1,20 +0,0 @@
Copyright (c) 2016 Dominik Honnef
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View file

@ -48,6 +48,8 @@ var stdlibDescs = map[string]Description{
type Description struct { type Description struct {
// The function is known to be pure // The function is known to be pure
Pure bool Pure bool
// The function is known to be a stub
Stub bool
// The function is known to never return (panics notwithstanding) // The function is known to never return (panics notwithstanding)
Infinite bool Infinite bool
// Variable ranges // Variable ranges
@ -90,6 +92,7 @@ func (d *Descriptions) Get(fn *ssa.Function) Description {
{ {
fd.result = stdlibDescs[fn.RelString(nil)] fd.result = stdlibDescs[fn.RelString(nil)]
fd.result.Pure = fd.result.Pure || d.IsPure(fn) fd.result.Pure = fd.result.Pure || d.IsPure(fn)
fd.result.Stub = fd.result.Stub || d.IsStub(fn)
fd.result.Infinite = fd.result.Infinite || !terminates(fn) fd.result.Infinite = fd.result.Infinite || !terminates(fn)
fd.result.Ranges = vrp.BuildGraph(fn).Solve() fd.result.Ranges = vrp.BuildGraph(fn).Solve()
fd.result.Loops = findLoops(fn) fd.result.Loops = findLoops(fn)

View file

@ -5,9 +5,41 @@ import (
"go/types" "go/types"
"honnef.co/go/tools/callgraph" "honnef.co/go/tools/callgraph"
"honnef.co/go/tools/lint"
"honnef.co/go/tools/ssa" "honnef.co/go/tools/ssa"
) )
// IsStub reports whether a function is a stub. A function is
// considered a stub if it has no instructions or exactly one
// instruction, which must be either returning only constant values or
// a panic.
func (d *Descriptions) IsStub(fn *ssa.Function) bool {
if len(fn.Blocks) == 0 {
return true
}
if len(fn.Blocks) > 1 {
return false
}
instrs := lint.FilterDebug(fn.Blocks[0].Instrs)
if len(instrs) != 1 {
return false
}
switch instrs[0].(type) {
case *ssa.Return:
// Since this is the only instruction, the return value must
// be a constant. We consider all constants as stubs, not just
// the zero value. This does not, unfortunately, cover zero
// initialised structs, as these cause additional
// instructions.
return true
case *ssa.Panic:
return true
default:
return false
}
}
func (d *Descriptions) IsPure(fn *ssa.Function) bool { func (d *Descriptions) IsPure(fn *ssa.Function) bool {
if fn.Signature.Results().Len() == 0 { if fn.Signature.Results().Len() == 0 {
// A function with no return values is empty or is doing some // A function with no return values is empty or is doing some

View file

@ -1,20 +0,0 @@
Copyright (c) 2016 Dominik Honnef
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View file

@ -1,20 +0,0 @@
Copyright (c) 2016 Dominik Honnef
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View file

@ -1,20 +0,0 @@
Copyright (c) 2016 Dominik Honnef
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View file

@ -11,6 +11,7 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"go/ast" "go/ast"
"go/build"
"go/constant" "go/constant"
"go/printer" "go/printer"
"go/token" "go/token"
@ -20,6 +21,7 @@ import (
"sort" "sort"
"strings" "strings"
"sync" "sync"
"unicode"
"golang.org/x/tools/go/ast/astutil" "golang.org/x/tools/go/ast/astutil"
"golang.org/x/tools/go/loader" "golang.org/x/tools/go/loader"
@ -30,15 +32,85 @@ import (
type Job struct { type Job struct {
Program *Program Program *Program
checker string
check string check string
problems []Problem problems []Problem
} }
type Ignore struct { type Ignore interface {
Match(p Problem) bool
}
type LineIgnore struct {
File string
Line int
Checks []string
matched bool
pos token.Pos
}
func (li *LineIgnore) Match(p Problem) bool {
if p.Position.Filename != li.File || p.Position.Line != li.Line {
return false
}
for _, c := range li.Checks {
if m, _ := filepath.Match(c, p.Check); m {
li.matched = true
return true
}
}
return false
}
func (li *LineIgnore) String() string {
matched := "not matched"
if li.matched {
matched = "matched"
}
return fmt.Sprintf("%s:%d %s (%s)", li.File, li.Line, strings.Join(li.Checks, ", "), matched)
}
type FileIgnore struct {
File string
Checks []string
}
func (fi *FileIgnore) Match(p Problem) bool {
if p.Position.Filename != fi.File {
return false
}
for _, c := range fi.Checks {
if m, _ := filepath.Match(c, p.Check); m {
return true
}
}
return false
}
type GlobIgnore struct {
Pattern string Pattern string
Checks []string Checks []string
} }
func (gi *GlobIgnore) Match(p Problem) bool {
if gi.Pattern != "*" {
pkgpath := p.Package.Path()
if strings.HasSuffix(pkgpath, "_test") {
pkgpath = pkgpath[:len(pkgpath)-len("_test")]
}
name := filepath.Join(pkgpath, filepath.Base(p.Position.Filename))
if m, _ := filepath.Match(gi.Pattern, name); !m {
return false
}
}
for _, c := range gi.Checks {
if m, _ := filepath.Match(c, p.Check); m {
return true
}
}
return false
}
type Program struct { type Program struct {
SSA *ssa.Program SSA *ssa.Program
Prog *loader.Program Prog *loader.Program
@ -58,51 +130,70 @@ type Func func(*Job)
// Problem represents a problem in some source code. // Problem represents a problem in some source code.
type Problem struct { type Problem struct {
Position token.Pos // position in source file pos token.Pos
Text string // the prose that describes the problem Position token.Position // position in source file
Text string // the prose that describes the problem
Check string
Checker string
Package *types.Package
Ignored bool
} }
func (p *Problem) String() string { func (p *Problem) String() string {
return p.Text if p.Check == "" {
return p.Text
}
return fmt.Sprintf("%s (%s)", p.Text, p.Check)
} }
type Checker interface { type Checker interface {
Name() string
Prefix() string
Init(*Program) Init(*Program)
Funcs() map[string]Func Funcs() map[string]Func
} }
// A Linter lints Go source code. // A Linter lints Go source code.
type Linter struct { type Linter struct {
Checker Checker Checker Checker
Ignores []Ignore Ignores []Ignore
GoVersion int GoVersion int
ReturnIgnored bool
automaticIgnores []Ignore
} }
func (l *Linter) ignore(j *Job, p Problem) bool { func (l *Linter) ignore(p Problem) bool {
tf := j.Program.SSA.Fset.File(p.Position) ignored := false
f := j.Program.tokenFileMap[tf] for _, ig := range l.automaticIgnores {
pkg := j.Program.astFileMap[f].Pkg // We cannot short-circuit these, as we want to record, for
// each ignore, whether it matched or not.
for _, ig := range l.Ignores { if ig.Match(p) {
pkgpath := pkg.Path() ignored = true
if strings.HasSuffix(pkgpath, "_test") {
pkgpath = pkgpath[:len(pkgpath)-len("_test")]
}
name := filepath.Join(pkgpath, filepath.Base(tf.Name()))
if m, _ := filepath.Match(ig.Pattern, name); !m {
continue
}
for _, c := range ig.Checks {
if m, _ := filepath.Match(c, j.check); m {
return true
}
} }
} }
if ignored {
// no need to execute other ignores if we've already had a
// match.
return true
}
for _, ig := range l.Ignores {
// We can short-circuit here, as we aren't tracking any
// information.
if ig.Match(p) {
return true
}
}
return false return false
} }
func (prog *Program) File(node Positioner) *ast.File {
return prog.tokenFileMap[prog.SSA.Fset.File(node.Pos())]
}
func (j *Job) File(node Positioner) *ast.File { func (j *Job) File(node Positioner) *ast.File {
return j.Program.tokenFileMap[j.Program.SSA.Fset.File(node.Pos())] return j.Program.File(node)
} }
// TODO(dh): switch to sort.Slice when Go 1.9 lands. // TODO(dh): switch to sort.Slice when Go 1.9 lands.
@ -116,7 +207,7 @@ func (ps byPosition) Len() int {
} }
func (ps byPosition) Less(i int, j int) bool { func (ps byPosition) Less(i int, j int) bool {
pi, pj := ps.fset.Position(ps.ps[i].Position), ps.fset.Position(ps.ps[j].Position) pi, pj := ps.ps[i].Position, ps.ps[j].Position
if pi.Filename != pj.Filename { if pi.Filename != pj.Filename {
return pi.Filename < pj.Filename return pi.Filename < pj.Filename
@ -135,16 +226,40 @@ func (ps byPosition) Swap(i int, j int) {
ps.ps[i], ps.ps[j] = ps.ps[j], ps.ps[i] ps.ps[i], ps.ps[j] = ps.ps[j], ps.ps[i]
} }
func (l *Linter) Lint(lprog *loader.Program) []Problem { func parseDirective(s string) (cmd string, args []string) {
if !strings.HasPrefix(s, "//lint:") {
return "", nil
}
s = strings.TrimPrefix(s, "//lint:")
fields := strings.Split(s, " ")
return fields[0], fields[1:]
}
func (l *Linter) Lint(lprog *loader.Program, conf *loader.Config) []Problem {
ssaprog := ssautil.CreateProgram(lprog, ssa.GlobalDebug) ssaprog := ssautil.CreateProgram(lprog, ssa.GlobalDebug)
ssaprog.Build() ssaprog.Build()
pkgMap := map[*ssa.Package]*Pkg{} pkgMap := map[*ssa.Package]*Pkg{}
var pkgs []*Pkg var pkgs []*Pkg
for _, pkginfo := range lprog.InitialPackages() { for _, pkginfo := range lprog.InitialPackages() {
ssapkg := ssaprog.Package(pkginfo.Pkg) ssapkg := ssaprog.Package(pkginfo.Pkg)
var bp *build.Package
if len(pkginfo.Files) != 0 {
path := lprog.Fset.Position(pkginfo.Files[0].Pos()).Filename
dir := filepath.Dir(path)
var err error
ctx := conf.Build
if ctx == nil {
ctx = &build.Default
}
bp, err = ctx.ImportDir(dir, 0)
if err != nil {
// shouldn't happen
}
}
pkg := &Pkg{ pkg := &Pkg{
Package: ssapkg, Package: ssapkg,
Info: pkginfo, Info: pkginfo,
BuildPkg: bp,
} }
pkgMap[ssapkg] = pkg pkgMap[ssapkg] = pkg
pkgs = append(pkgs, pkg) pkgs = append(pkgs, pkg)
@ -158,6 +273,7 @@ func (l *Linter) Lint(lprog *loader.Program) []Problem {
tokenFileMap: map[*token.File]*ast.File{}, tokenFileMap: map[*token.File]*ast.File{},
astFileMap: map[*ast.File]*Pkg{}, astFileMap: map[*ast.File]*Pkg{},
} }
initial := map[*types.Package]struct{}{} initial := map[*types.Package]struct{}{}
for _, pkg := range pkgs { for _, pkg := range pkgs {
initial[pkg.Info.Pkg] = struct{}{} initial[pkg.Info.Pkg] = struct{}{}
@ -176,9 +292,69 @@ func (l *Linter) Lint(lprog *loader.Program) []Problem {
ssapkg := ssaprog.Package(pkg.Info.Pkg) ssapkg := ssaprog.Package(pkg.Info.Pkg)
for _, f := range pkg.Info.Files { for _, f := range pkg.Info.Files {
prog.astFileMap[f] = pkgMap[ssapkg]
}
}
for _, pkginfo := range lprog.AllPackages {
for _, f := range pkginfo.Files {
tf := lprog.Fset.File(f.Pos()) tf := lprog.Fset.File(f.Pos())
prog.tokenFileMap[tf] = f prog.tokenFileMap[tf] = f
prog.astFileMap[f] = pkgMap[ssapkg] }
}
var out []Problem
l.automaticIgnores = nil
for _, pkginfo := range lprog.InitialPackages() {
for _, f := range pkginfo.Files {
cm := ast.NewCommentMap(lprog.Fset, f, f.Comments)
for node, cgs := range cm {
for _, cg := range cgs {
for _, c := range cg.List {
if !strings.HasPrefix(c.Text, "//lint:") {
continue
}
cmd, args := parseDirective(c.Text)
switch cmd {
case "ignore", "file-ignore":
if len(args) < 2 {
// FIXME(dh): this causes duplicated warnings when using megacheck
p := Problem{
pos: c.Pos(),
Position: prog.DisplayPosition(c.Pos()),
Text: "malformed linter directive; missing the required reason field?",
Check: "",
Checker: l.Checker.Name(),
Package: nil,
}
out = append(out, p)
continue
}
default:
// unknown directive, ignore
continue
}
checks := strings.Split(args[0], ",")
pos := prog.DisplayPosition(node.Pos())
var ig Ignore
switch cmd {
case "ignore":
ig = &LineIgnore{
File: pos.Filename,
Line: pos.Line,
Checks: checks,
pos: c.Pos(),
}
case "file-ignore":
ig = &FileIgnore{
File: pos.Filename,
Checks: checks,
}
}
l.automaticIgnores = append(l.automaticIgnores, ig)
}
}
}
} }
} }
@ -237,6 +413,7 @@ func (l *Linter) Lint(lprog *loader.Program) []Problem {
for _, k := range keys { for _, k := range keys {
j := &Job{ j := &Job{
Program: prog, Program: prog,
checker: l.Checker.Name(),
check: k, check: k,
} }
jobs = append(jobs, j) jobs = append(jobs, j)
@ -255,15 +432,47 @@ func (l *Linter) Lint(lprog *loader.Program) []Problem {
} }
wg.Wait() wg.Wait()
var out []Problem
for _, j := range jobs { for _, j := range jobs {
for _, p := range j.problems { for _, p := range j.problems {
if !l.ignore(j, p) { p.Ignored = l.ignore(p)
if l.ReturnIgnored || !p.Ignored {
out = append(out, p) out = append(out, p)
} }
} }
} }
for _, ig := range l.automaticIgnores {
ig, ok := ig.(*LineIgnore)
if !ok {
continue
}
if ig.matched {
continue
}
for _, c := range ig.Checks {
idx := strings.IndexFunc(c, func(r rune) bool {
return unicode.IsNumber(r)
})
if idx == -1 {
// malformed check name, backing out
continue
}
if c[:idx] != l.Checker.Prefix() {
// not for this checker
continue
}
p := Problem{
pos: ig.pos,
Position: prog.DisplayPosition(ig.pos),
Text: "this linter directive didn't match anything; should it be removed?",
Check: "",
Checker: l.Checker.Name(),
Package: nil,
}
out = append(out, p)
}
}
sort.Sort(byPosition{lprog.Fset, out}) sort.Sort(byPosition{lprog.Fset, out})
return out return out
} }
@ -271,7 +480,8 @@ func (l *Linter) Lint(lprog *loader.Program) []Problem {
// Pkg represents a package being linted. // Pkg represents a package being linted.
type Pkg struct { type Pkg struct {
*ssa.Package *ssa.Package
Info *loader.PackageInfo Info *loader.PackageInfo
BuildPkg *build.Package
} }
type packager interface { type packager interface {
@ -309,10 +519,55 @@ type Positioner interface {
Pos() token.Pos Pos() token.Pos
} }
func (prog *Program) DisplayPosition(p token.Pos) token.Position {
// The //line compiler directive can be used to change the file
// name and line numbers associated with code. This can, for
// example, be used by code generation tools. The most prominent
// example is 'go tool cgo', which uses //line directives to refer
// back to the original source code.
//
// In the context of our linters, we need to treat these
// directives differently depending on context. For cgo files, we
// want to honour the directives, so that line numbers are
// adjusted correctly. For all other files, we want to ignore the
// directives, so that problems are reported at their actual
// position and not, for example, a yacc grammar file. This also
// affects the ignore mechanism, since it operates on the position
// information stored within problems. With this implementation, a
// user will ignore foo.go, not foo.y
pkg := prog.astFileMap[prog.tokenFileMap[prog.Prog.Fset.File(p)]]
bp := pkg.BuildPkg
adjPos := prog.Prog.Fset.Position(p)
if bp == nil {
// couldn't find the package for some reason (deleted? faulty
// file system?)
return adjPos
}
base := filepath.Base(adjPos.Filename)
for _, f := range bp.CgoFiles {
if f == base {
// this is a cgo file, use the adjusted position
return adjPos
}
}
// not a cgo file, ignore //line directives
return prog.Prog.Fset.PositionFor(p, false)
}
func (j *Job) Errorf(n Positioner, format string, args ...interface{}) *Problem { func (j *Job) Errorf(n Positioner, format string, args ...interface{}) *Problem {
tf := j.Program.SSA.Fset.File(n.Pos())
f := j.Program.tokenFileMap[tf]
pkg := j.Program.astFileMap[f].Pkg
pos := j.Program.DisplayPosition(n.Pos())
problem := Problem{ problem := Problem{
Position: n.Pos(), pos: n.Pos(),
Text: fmt.Sprintf(format, args...) + fmt.Sprintf(" (%s)", j.check), Position: pos,
Text: fmt.Sprintf(format, args...),
Check: j.check,
Checker: j.checker,
Package: pkg,
} }
j.problems = append(j.problems, problem) j.problems = append(j.problems, problem)
return &j.problems[len(j.problems)-1] return &j.problems[len(j.problems)-1]
@ -422,6 +677,31 @@ func IsGenerated(f *ast.File) bool {
return false return false
} }
func Preamble(f *ast.File) string {
cutoff := f.Package
if f.Doc != nil {
cutoff = f.Doc.Pos()
}
var out []string
for _, cmt := range f.Comments {
if cmt.Pos() >= cutoff {
break
}
out = append(out, cmt.Text())
}
return strings.Join(out, "\n")
}
func IsPointerLike(T types.Type) bool {
switch T := T.Underlying().(type) {
case *types.Interface, *types.Chan, *types.Map, *types.Pointer:
return true
case *types.Basic:
return T.Kind() == types.UnsafePointer
}
return false
}
func (j *Job) IsGoVersion(minor int) bool { func (j *Job) IsGoVersion(minor int) bool {
return j.Program.GoVersion >= minor return j.Program.GoVersion >= minor
} }

View file

@ -8,23 +8,70 @@
package lintutil // import "honnef.co/go/tools/lint/lintutil" package lintutil // import "honnef.co/go/tools/lint/lintutil"
import ( import (
"encoding/json"
"errors" "errors"
"flag" "flag"
"fmt" "fmt"
"go/build" "go/build"
"go/parser" "go/parser"
"go/token" "go/token"
"go/types"
"io"
"os" "os"
"path/filepath" "path/filepath"
"strconv" "strconv"
"strings" "strings"
"honnef.co/go/tools/lint" "honnef.co/go/tools/lint"
"honnef.co/go/tools/version"
"github.com/kisielk/gotool" "github.com/kisielk/gotool"
"golang.org/x/tools/go/loader" "golang.org/x/tools/go/loader"
) )
type OutputFormatter interface {
Format(p lint.Problem)
}
type TextOutput struct {
w io.Writer
}
func (o TextOutput) Format(p lint.Problem) {
fmt.Fprintf(o.w, "%v: %s\n", relativePositionString(p.Position), p.String())
}
type JSONOutput struct {
w io.Writer
}
func (o JSONOutput) Format(p lint.Problem) {
type location struct {
File string `json:"file"`
Line int `json:"line"`
Column int `json:"column"`
}
jp := struct {
Checker string `json:"checker"`
Code string `json:"code"`
Severity string `json:"severity,omitempty"`
Location location `json:"location"`
Message string `json:"message"`
Ignored bool `json:"ignored"`
}{
p.Checker,
p.Check,
"", // TODO(dh): support severity
location{
p.Position.Filename,
p.Position.Line,
p.Position.Column,
},
p.Text,
p.Ignored,
}
_ = json.NewEncoder(o.w).Encode(jp)
}
func usage(name string, flags *flag.FlagSet) func() { func usage(name string, flags *flag.FlagSet) func() {
return func() { return func() {
fmt.Fprintf(os.Stderr, "Usage of %s:\n", name) fmt.Fprintf(os.Stderr, "Usage of %s:\n", name)
@ -38,13 +85,14 @@ func usage(name string, flags *flag.FlagSet) func() {
} }
type runner struct { type runner struct {
checker lint.Checker checker lint.Checker
tags []string tags []string
ignores []lint.Ignore ignores []lint.Ignore
version int version int
returnIgnored bool
} }
func (runner runner) resolveRelative(importPaths []string) (goFiles bool, err error) { func resolveRelative(importPaths []string, tags []string) (goFiles bool, err error) {
if len(importPaths) == 0 { if len(importPaths) == 0 {
return false, nil return false, nil
} }
@ -57,7 +105,7 @@ func (runner runner) resolveRelative(importPaths []string) (goFiles bool, err er
return false, err return false, err
} }
ctx := build.Default ctx := build.Default
ctx.BuildTags = runner.tags ctx.BuildTags = tags
for i, path := range importPaths { for i, path := range importPaths {
bpkg, err := ctx.Import(path, wd, build.FindOnly) bpkg, err := ctx.Import(path, wd, build.FindOnly)
if err != nil { if err != nil {
@ -80,7 +128,7 @@ func parseIgnore(s string) ([]lint.Ignore, error) {
} }
path := p[0] path := p[0]
checks := strings.Split(p[1], ",") checks := strings.Split(p[1], ",")
out = append(out, lint.Ignore{Pattern: path, Checks: checks}) out = append(out, &lint.GlobIgnore{Pattern: path, Checks: checks})
} }
return out, nil return out, nil
} }
@ -117,6 +165,9 @@ func FlagSet(name string) *flag.FlagSet {
flags.String("tags", "", "List of `build tags`") flags.String("tags", "", "List of `build tags`")
flags.String("ignore", "", "Space separated list of checks to ignore, in the following format: 'import/path/file.go:Check1,Check2,...' Both the import path and file name sections support globbing, e.g. 'os/exec/*_test.go'") flags.String("ignore", "", "Space separated list of checks to ignore, in the following format: 'import/path/file.go:Check1,Check2,...' Both the import path and file name sections support globbing, e.g. 'os/exec/*_test.go'")
flags.Bool("tests", true, "Include tests") flags.Bool("tests", true, "Include tests")
flags.Bool("version", false, "Print version and exit")
flags.Bool("show-ignored", false, "Don't filter ignored problems")
flags.String("f", "text", "Output `format` (valid choices are 'text' and 'json')")
tags := build.Default.ReleaseTags tags := build.Default.ReleaseTags
v := tags[len(tags)-1][2:] v := tags[len(tags)-1][2:]
@ -129,67 +180,105 @@ func FlagSet(name string) *flag.FlagSet {
return flags return flags
} }
func ProcessFlagSet(c lint.Checker, fs *flag.FlagSet) { type CheckerConfig struct {
Checker lint.Checker
ExitNonZero bool
}
func ProcessFlagSet(confs []CheckerConfig, fs *flag.FlagSet) {
tags := fs.Lookup("tags").Value.(flag.Getter).Get().(string) tags := fs.Lookup("tags").Value.(flag.Getter).Get().(string)
ignore := fs.Lookup("ignore").Value.(flag.Getter).Get().(string) ignore := fs.Lookup("ignore").Value.(flag.Getter).Get().(string)
tests := fs.Lookup("tests").Value.(flag.Getter).Get().(bool) tests := fs.Lookup("tests").Value.(flag.Getter).Get().(bool)
version := fs.Lookup("go").Value.(flag.Getter).Get().(int) goVersion := fs.Lookup("go").Value.(flag.Getter).Get().(int)
format := fs.Lookup("f").Value.(flag.Getter).Get().(string)
printVersion := fs.Lookup("version").Value.(flag.Getter).Get().(bool)
showIgnored := fs.Lookup("show-ignored").Value.(flag.Getter).Get().(bool)
ps, lprog, err := Lint(c, fs.Args(), &Options{ if printVersion {
Tags: strings.Fields(tags), version.Print()
LintTests: tests, os.Exit(0)
Ignores: ignore, }
GoVersion: version,
var cs []lint.Checker
for _, conf := range confs {
cs = append(cs, conf.Checker)
}
pss, err := Lint(cs, fs.Args(), &Options{
Tags: strings.Fields(tags),
LintTests: tests,
Ignores: ignore,
GoVersion: goVersion,
ReturnIgnored: showIgnored,
}) })
if err != nil { if err != nil {
fmt.Fprintln(os.Stderr, err) fmt.Fprintln(os.Stderr, err)
os.Exit(1) os.Exit(1)
} }
unclean := false
for _, p := range ps { var ps []lint.Problem
unclean = true for _, p := range pss {
pos := lprog.Fset.Position(p.Position) ps = append(ps, p...)
fmt.Printf("%v: %s\n", relativePositionString(pos), p.Text)
} }
if unclean {
os.Exit(1) var f OutputFormatter
switch format {
case "text":
f = TextOutput{os.Stdout}
case "json":
f = JSONOutput{os.Stdout}
default:
fmt.Fprintf(os.Stderr, "unsupported output format %q\n", format)
os.Exit(2)
}
for _, p := range ps {
f.Format(p)
}
for i, p := range pss {
if len(p) != 0 && confs[i].ExitNonZero {
os.Exit(1)
}
} }
} }
type Options struct { type Options struct {
Tags []string Tags []string
LintTests bool LintTests bool
Ignores string Ignores string
GoVersion int GoVersion int
ReturnIgnored bool
} }
func Lint(c lint.Checker, pkgs []string, opt *Options) ([]lint.Problem, *loader.Program, error) { func Lint(cs []lint.Checker, pkgs []string, opt *Options) ([][]lint.Problem, error) {
// TODO(dh): Instead of returning the loader.Program, we should
// store token.Position instead of token.Pos in lint.Problem.
if opt == nil { if opt == nil {
opt = &Options{} opt = &Options{}
} }
ignores, err := parseIgnore(opt.Ignores) ignores, err := parseIgnore(opt.Ignores)
if err != nil { if err != nil {
return nil, nil, err return nil, err
}
runner := &runner{
checker: c,
tags: opt.Tags,
ignores: ignores,
version: opt.GoVersion,
} }
paths := gotool.ImportPaths(pkgs) paths := gotool.ImportPaths(pkgs)
goFiles, err := runner.resolveRelative(paths) goFiles, err := resolveRelative(paths, opt.Tags)
if err != nil { if err != nil {
return nil, nil, err return nil, err
} }
ctx := build.Default ctx := build.Default
ctx.BuildTags = runner.tags ctx.BuildTags = opt.Tags
hadError := false
conf := &loader.Config{ conf := &loader.Config{
Build: &ctx, Build: &ctx,
ParserMode: parser.ParseComments, ParserMode: parser.ParseComments,
ImportPkgs: map[string]bool{}, ImportPkgs: map[string]bool{},
TypeChecker: types.Config{
Error: func(err error) {
// Only print the first error found
if hadError {
return
}
hadError = true
fmt.Fprintln(os.Stderr, err)
},
},
} }
if goFiles { if goFiles {
conf.CreateFromFilenames("adhoc", paths...) conf.CreateFromFilenames("adhoc", paths...)
@ -200,9 +289,21 @@ func Lint(c lint.Checker, pkgs []string, opt *Options) ([]lint.Problem, *loader.
} }
lprog, err := conf.Load() lprog, err := conf.Load()
if err != nil { if err != nil {
return nil, nil, err return nil, err
} }
return runner.lint(lprog), lprog, nil
var problems [][]lint.Problem
for _, c := range cs {
runner := &runner{
checker: c,
tags: opt.Tags,
ignores: ignores,
version: opt.GoVersion,
returnIgnored: opt.ReturnIgnored,
}
problems = append(problems, runner.lint(lprog, conf))
}
return problems, nil
} }
func shortPath(path string) string { func shortPath(path string) string {
@ -230,18 +331,19 @@ func relativePositionString(pos token.Position) string {
return s return s
} }
func ProcessArgs(name string, c lint.Checker, args []string) { func ProcessArgs(name string, cs []CheckerConfig, args []string) {
flags := FlagSet(name) flags := FlagSet(name)
flags.Parse(args) flags.Parse(args)
ProcessFlagSet(c, flags) ProcessFlagSet(cs, flags)
} }
func (runner *runner) lint(lprog *loader.Program) []lint.Problem { func (runner *runner) lint(lprog *loader.Program, conf *loader.Config) []lint.Problem {
l := &lint.Linter{ l := &lint.Linter{
Checker: runner.checker, Checker: runner.checker,
Ignores: runner.ignores, Ignores: runner.ignores,
GoVersion: runner.version, GoVersion: runner.version,
ReturnIgnored: runner.returnIgnored,
} }
return l.Lint(lprog) return l.Lint(lprog, conf)
} }

View file

@ -91,7 +91,7 @@ func TestAll(t *testing.T, c lint.Checker, dir string) {
for version, fis := range files { for version, fis := range files {
l := &lint.Linter{Checker: c, GoVersion: version} l := &lint.Linter{Checker: c, GoVersion: version}
res := l.Lint(lprog) res := l.Lint(lprog, conf)
for _, fi := range fis { for _, fi := range fis {
name := fi.Name() name := fi.Name()
src := sources[name] src := sources[name]
@ -101,8 +101,7 @@ func TestAll(t *testing.T, c lint.Checker, dir string) {
for _, in := range ins { for _, in := range ins {
ok := false ok := false
for i, p := range res { for i, p := range res {
pos := lprog.Fset.Position(p.Position) if p.Position.Line != in.Line || filepath.Base(p.Position.Filename) != name {
if pos.Line != in.Line || filepath.Base(pos.Filename) != name {
continue continue
} }
if in.Match.MatchString(p.Text) { if in.Match.MatchString(p.Text) {
@ -121,11 +120,10 @@ func TestAll(t *testing.T, c lint.Checker, dir string) {
} }
} }
for _, p := range res { for _, p := range res {
pos := lprog.Fset.Position(p.Position) name := filepath.Base(p.Position.Filename)
name := filepath.Base(pos.Filename)
for _, fi := range fis { for _, fi := range fis {
if name == fi.Name() { if name == fi.Name() {
t.Errorf("Unexpected problem at %s: %v", pos, p.Text) t.Errorf("Unexpected problem at %s: %v", p.Position, p.Text)
break break
} }
} }
@ -149,7 +147,7 @@ func parseInstructions(t *testing.T, filename string, src []byte) []instruction
} }
var ins []instruction var ins []instruction
for _, cg := range f.Comments { for _, cg := range f.Comments {
ln := fset.Position(cg.Pos()).Line ln := fset.PositionFor(cg.Pos(), false).Line
raw := cg.Text() raw := cg.Text()
for _, line := range strings.Split(raw, "\n") { for _, line := range strings.Split(raw, "\n") {
if line == "" || strings.HasPrefix(line, "#") { if line == "" || strings.HasPrefix(line, "#") {

View file

@ -1,20 +0,0 @@
Copyright (c) 2016 Dominik Honnef
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View file

@ -29,6 +29,9 @@ func NewChecker() *Checker {
} }
} }
func (*Checker) Name() string { return "gosimple" }
func (*Checker) Prefix() string { return "S" }
func (c *Checker) Init(prog *lint.Program) { func (c *Checker) Init(prog *lint.Program) {
c.nodeFns = lint.NodeFns(prog.Packages) c.nodeFns = lint.NodeFns(prog.Packages)
} }
@ -61,7 +64,7 @@ func (c *Checker) Funcs() map[string]lint.Func {
"S1023": c.LintRedundantBreak, "S1023": c.LintRedundantBreak,
"S1024": c.LintTimeUntil, "S1024": c.LintTimeUntil,
"S1025": c.LintRedundantSprintf, "S1025": c.LintRedundantSprintf,
"S1026": c.LintStringCopy, "S1026": nil,
"S1027": nil, "S1027": nil,
"S1028": c.LintErrorsNewSprintf, "S1028": c.LintErrorsNewSprintf,
"S1029": c.LintRangeStringRunes, "S1029": c.LintRangeStringRunes,
@ -1022,7 +1025,9 @@ func (c *Checker) LintUnnecessaryBlank(j *lint.Job) {
fn := func(node ast.Node) bool { fn := func(node ast.Node) bool {
fn1(node) fn1(node)
fn2(node) fn2(node)
fn3(node) if j.IsGoVersion(4) {
fn3(node)
}
return true return true
} }
for _, f := range c.filterGenerated(j.Program.Files) { for _, f := range c.filterGenerated(j.Program.Files) {
@ -1702,81 +1707,6 @@ func (c *Checker) LintRedundantSprintf(j *lint.Job) {
} }
} }
func (c *Checker) LintStringCopy(j *lint.Job) {
emptyStringLit := func(e ast.Expr) bool {
bl, ok := e.(*ast.BasicLit)
return ok && bl.Value == `""`
}
fn := func(node ast.Node) bool {
switch x := node.(type) {
case *ast.BinaryExpr: // "" + s, s + ""
if x.Op != token.ADD {
break
}
l1 := j.Program.Prog.Fset.Position(x.X.Pos()).Line
l2 := j.Program.Prog.Fset.Position(x.Y.Pos()).Line
if l1 != l2 {
break
}
var want ast.Expr
switch {
case emptyStringLit(x.X):
want = x.Y
case emptyStringLit(x.Y):
want = x.X
default:
return true
}
j.Errorf(x, "should use %s instead of %s",
j.Render(want), j.Render(x))
case *ast.CallExpr:
if j.IsCallToAST(x, "fmt.Sprint") && len(x.Args) == 1 {
// fmt.Sprint(x)
argT := j.Program.Info.TypeOf(x.Args[0])
bt, ok := argT.Underlying().(*types.Basic)
if !ok || bt.Kind() != types.String {
return true
}
if c.Implements(j, argT, "fmt.Stringer") || c.Implements(j, argT, "error") {
return true
}
j.Errorf(x, "should use %s instead of %s", j.Render(x.Args[0]), j.Render(x))
return true
}
// string([]byte(s))
bt, ok := j.Program.Info.TypeOf(x.Fun).(*types.Basic)
if !ok || bt.Kind() != types.String {
break
}
nested, ok := x.Args[0].(*ast.CallExpr)
if !ok {
break
}
st, ok := j.Program.Info.TypeOf(nested.Fun).(*types.Slice)
if !ok {
break
}
et, ok := st.Elem().(*types.Basic)
if !ok || et.Kind() != types.Byte {
break
}
xt, ok := j.Program.Info.TypeOf(nested.Args[0]).(*types.Basic)
if !ok || xt.Kind() != types.String {
break
}
j.Errorf(x, "should use %s instead of %s",
j.Render(nested.Args[0]), j.Render(x))
}
return true
}
for _, f := range c.filterGenerated(j.Program.Files) {
ast.Inspect(f, fn)
}
}
func (c *Checker) LintErrorsNewSprintf(j *lint.Job) { func (c *Checker) LintErrorsNewSprintf(j *lint.Job) {
fn := func(node ast.Node) bool { fn := func(node ast.Node) bool {
if !j.IsCallToAST(node, "errors.New") { if !j.IsCallToAST(node, "errors.New") {

View file

@ -1,20 +0,0 @@
Copyright (c) 2016 Dominik Honnef
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View file

@ -1,20 +0,0 @@
Copyright (c) 2016 Dominik Honnef
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View file

@ -0,0 +1,21 @@
package staticcheck
import (
"go/ast"
"strings"
"honnef.co/go/tools/lint"
)
func buildTags(f *ast.File) [][]string {
var out [][]string
for _, line := range strings.Split(lint.Preamble(f), "\n") {
if !strings.HasPrefix(line, "+build ") {
continue
}
line = strings.TrimSpace(strings.TrimPrefix(line, "+build "))
fields := strings.Fields(line)
out = append(out, fields)
}
return out
}

View file

@ -4,20 +4,20 @@ package staticcheck // import "honnef.co/go/tools/staticcheck"
import ( import (
"fmt" "fmt"
"go/ast" "go/ast"
"go/build"
"go/constant" "go/constant"
"go/token" "go/token"
"go/types" "go/types"
htmltemplate "html/template" htmltemplate "html/template"
"net/http" "net/http"
"regexp" "regexp"
"sort"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
texttemplate "text/template" texttemplate "text/template"
"honnef.co/go/tools/deprecated"
"honnef.co/go/tools/functions" "honnef.co/go/tools/functions"
"honnef.co/go/tools/gcsizes"
"honnef.co/go/tools/internal/sharedcheck" "honnef.co/go/tools/internal/sharedcheck"
"honnef.co/go/tools/lint" "honnef.co/go/tools/lint"
"honnef.co/go/tools/ssa" "honnef.co/go/tools/ssa"
@ -111,14 +111,12 @@ var (
}, },
} }
checkSyncPoolSizeRules = map[string]CallCheck{ checkSyncPoolValueRules = map[string]CallCheck{
"(*sync.Pool).Put": func(call *Call) { "(*sync.Pool).Put": func(call *Call) {
// TODO(dh): allow users to pass in a custom build environment
sizes := gcsizes.ForArch(build.Default.GOARCH)
arg := call.Args[0] arg := call.Args[0]
typ := arg.Value.Value.Type() typ := arg.Value.Value.Type()
if !types.IsInterface(typ) && sizes.Sizeof(typ) > sizes.WordSize { if !lint.IsPointerLike(typ) {
arg.Invalid("argument should be one word large or less to avoid allocations") arg.Invalid("argument should be pointer-like to avoid allocations")
} }
}, },
} }
@ -209,6 +207,9 @@ func NewChecker() *Checker {
return &Checker{} return &Checker{}
} }
func (*Checker) Name() string { return "staticcheck" }
func (*Checker) Prefix() string { return "SA" }
func (c *Checker) Funcs() map[string]lint.Func { func (c *Checker) Funcs() map[string]lint.Func {
return map[string]lint.Func{ return map[string]lint.Func{
"SA1000": c.callChecker(checkRegexpRules), "SA1000": c.callChecker(checkRegexpRules),
@ -265,6 +266,7 @@ func (c *Checker) Funcs() map[string]lint.Func {
"SA4016": c.CheckSillyBitwiseOps, "SA4016": c.CheckSillyBitwiseOps,
"SA4017": c.CheckPureFunctions, "SA4017": c.CheckPureFunctions,
"SA4018": c.CheckSelfAssignment, "SA4018": c.CheckSelfAssignment,
"SA4019": c.CheckDuplicateBuildConstraints,
"SA5000": c.CheckNilMaps, "SA5000": c.CheckNilMaps,
"SA5001": c.CheckEarlyDefer, "SA5001": c.CheckEarlyDefer,
@ -277,7 +279,7 @@ func (c *Checker) Funcs() map[string]lint.Func {
"SA6000": c.callChecker(checkRegexpMatchLoopRules), "SA6000": c.callChecker(checkRegexpMatchLoopRules),
"SA6001": c.CheckMapBytesKey, "SA6001": c.CheckMapBytesKey,
"SA6002": c.callChecker(checkSyncPoolSizeRules), "SA6002": c.callChecker(checkSyncPoolValueRules),
"SA6003": c.CheckRangeStringRunes, "SA6003": c.CheckRangeStringRunes,
"SA6004": nil, "SA6004": nil,
@ -301,36 +303,62 @@ func (c *Checker) filterGenerated(files []*ast.File) []*ast.File {
return out return out
} }
func (c *Checker) Init(prog *lint.Program) { func (c *Checker) deprecateObject(m map[types.Object]string, prog *lint.Program, obj types.Object) {
c.funcDescs = functions.NewDescriptions(prog.SSA) if obj.Pkg() == nil {
c.deprecatedObjs = map[types.Object]string{} return
c.nodeFns = map[ast.Node]*ssa.Function{}
for _, fn := range prog.AllFunctions {
if fn.Blocks != nil {
applyStdlibKnowledge(fn)
ssa.OptimizeBlocks(fn)
}
} }
c.nodeFns = lint.NodeFns(prog.Packages) f := prog.File(obj)
if f == nil {
return
}
msg := c.deprecationMessage(f, prog.Prog.Fset, obj)
if msg != "" {
m[obj] = msg
}
}
deprecated := []map[types.Object]string{} func (c *Checker) Init(prog *lint.Program) {
wg := &sync.WaitGroup{} wg := &sync.WaitGroup{}
for _, pkginfo := range prog.Prog.AllPackages { wg.Add(3)
pkginfo := pkginfo go func() {
scope := pkginfo.Pkg.Scope() c.funcDescs = functions.NewDescriptions(prog.SSA)
names := scope.Names() for _, fn := range prog.AllFunctions {
wg.Add(1) if fn.Blocks != nil {
applyStdlibKnowledge(fn)
ssa.OptimizeBlocks(fn)
}
}
wg.Done()
}()
m := map[types.Object]string{} go func() {
deprecated = append(deprecated, m) c.nodeFns = lint.NodeFns(prog.Packages)
go func(m map[types.Object]string) { wg.Done()
for _, name := range names { }()
obj := scope.Lookup(name)
msg := c.deprecationMessage(pkginfo.Files, prog.SSA.Fset, obj) go func() {
if msg != "" { c.deprecatedObjs = map[types.Object]string{}
m[obj] = msg for _, ssapkg := range prog.SSA.AllPackages() {
ssapkg := ssapkg
for _, member := range ssapkg.Members {
obj := member.Object()
if obj == nil {
continue
}
c.deprecateObject(c.deprecatedObjs, prog, obj)
if typ, ok := obj.Type().(*types.Named); ok {
for i := 0; i < typ.NumMethods(); i++ {
meth := typ.Method(i)
c.deprecateObject(c.deprecatedObjs, prog, meth)
}
if iface, ok := typ.Underlying().(*types.Interface); ok {
for i := 0; i < iface.NumExplicitMethods(); i++ {
meth := iface.ExplicitMethod(i)
c.deprecateObject(c.deprecatedObjs, prog, meth)
}
}
} }
if typ, ok := obj.Type().Underlying().(*types.Struct); ok { if typ, ok := obj.Type().Underlying().(*types.Struct); ok {
n := typ.NumFields() n := typ.NumFields()
@ -338,51 +366,20 @@ func (c *Checker) Init(prog *lint.Program) {
// FIXME(dh): This code will not find deprecated // FIXME(dh): This code will not find deprecated
// fields in anonymous structs. // fields in anonymous structs.
field := typ.Field(i) field := typ.Field(i)
msg := c.deprecationMessage(pkginfo.Files, prog.SSA.Fset, field) c.deprecateObject(c.deprecatedObjs, prog, field)
if msg != "" {
m[field] = msg
}
} }
} }
} }
wg.Done() }
}(m) wg.Done()
} }()
wg.Wait() wg.Wait()
for _, m := range deprecated {
for k, v := range m {
c.deprecatedObjs[k] = v
}
}
} }
// TODO(adonovan): make this a method: func (*token.File) Contains(token.Pos) func (c *Checker) deprecationMessage(file *ast.File, fset *token.FileSet, obj types.Object) (message string) {
func tokenFileContainsPos(f *token.File, pos token.Pos) bool { pos := obj.Pos()
p := int(pos) path, _ := astutil.PathEnclosingInterval(file, pos, pos)
base := f.Base()
return base <= p && p < base+f.Size()
}
func pathEnclosingInterval(files []*ast.File, fset *token.FileSet, start, end token.Pos) (path []ast.Node, exact bool) {
for _, f := range files {
if f.Pos() == token.NoPos {
// This can happen if the parser saw
// too many errors and bailed out.
// (Use parser.AllErrors to prevent that.)
continue
}
if !tokenFileContainsPos(fset.File(f.Pos()), start) {
continue
}
if path, exact := astutil.PathEnclosingInterval(f, start, end); path != nil {
return path, exact
}
}
return nil, false
}
func (c *Checker) deprecationMessage(files []*ast.File, fset *token.FileSet, obj types.Object) (message string) {
path, _ := pathEnclosingInterval(files, fset, obj.Pos(), obj.Pos())
if len(path) <= 2 { if len(path) <= 2 {
return "" return ""
} }
@ -2065,7 +2062,7 @@ func (c *Checker) CheckCyclicFinalizer(j *lint.Job) {
} }
for _, b := range mc.Bindings { for _, b := range mc.Bindings {
if b == v { if b == v {
pos := j.Program.SSA.Fset.Position(mc.Fn.Pos()) pos := j.Program.DisplayPosition(mc.Fn.Pos())
j.Errorf(edge.Site, "the finalizer closes over the object, preventing the finalizer from ever running (at %s)", pos) j.Errorf(edge.Site, "the finalizer closes over the object, preventing the finalizer from ever running (at %s)", pos)
} }
} }
@ -2166,6 +2163,11 @@ func (c *Checker) CheckInfiniteRecursion(j *lint.Job) {
if edge.Callee != node { if edge.Callee != node {
continue continue
} }
if _, ok := edge.Site.(*ssa.Go); ok {
// Recursively spawning goroutines doesn't consume
// stack space infinitely, so don't flag it.
continue
}
block := edge.Site.Block() block := edge.Site.Block()
canReturn := false canReturn := false
@ -2437,7 +2439,7 @@ fnLoop:
if callee == nil { if callee == nil {
continue continue
} }
if c.funcDescs.Get(callee).Pure { if c.funcDescs.Get(callee).Pure && !c.funcDescs.Get(callee).Stub {
j.Errorf(ins, "%s is a pure function but its return value is ignored", callee.Name()) j.Errorf(ins, "%s is a pure function but its return value is ignored", callee.Name())
continue continue
} }
@ -2446,22 +2448,6 @@ fnLoop:
} }
} }
func enclosingFunction(j *lint.Job, node ast.Node) *ast.FuncDecl {
f := j.File(node)
path, _ := astutil.PathEnclosingInterval(f, node.Pos(), node.Pos())
for _, e := range path {
fn, ok := e.(*ast.FuncDecl)
if !ok {
continue
}
if fn.Name == nil {
continue
}
return fn
}
return nil
}
func (c *Checker) isDeprecated(j *lint.Job, ident *ast.Ident) (bool, string) { func (c *Checker) isDeprecated(j *lint.Job, ident *ast.Ident) (bool, string) {
obj := j.Program.Info.ObjectOf(ident) obj := j.Program.Info.ObjectOf(ident)
if obj.Pkg() == nil { if obj.Pkg() == nil {
@ -2471,19 +2457,34 @@ func (c *Checker) isDeprecated(j *lint.Job, ident *ast.Ident) (bool, string) {
return alt != "", alt return alt != "", alt
} }
func selectorName(j *lint.Job, expr *ast.SelectorExpr) string {
sel := j.Program.Info.Selections[expr]
if sel == nil {
if x, ok := expr.X.(*ast.Ident); ok {
return fmt.Sprintf("%s.%s", x.Name, expr.Sel.Name)
}
panic(fmt.Sprintf("unsupported selector: %v", expr))
}
return fmt.Sprintf("(%s).%s", sel.Recv(), sel.Obj().Name())
}
func (c *Checker) enclosingFunc(sel *ast.SelectorExpr) *ssa.Function {
fn := c.nodeFns[sel]
if fn == nil {
return nil
}
for fn.Parent() != nil {
fn = fn.Parent()
}
return fn
}
func (c *Checker) CheckDeprecated(j *lint.Job) { func (c *Checker) CheckDeprecated(j *lint.Job) {
fn := func(node ast.Node) bool { fn := func(node ast.Node) bool {
sel, ok := node.(*ast.SelectorExpr) sel, ok := node.(*ast.SelectorExpr)
if !ok { if !ok {
return true return true
} }
if fn := enclosingFunction(j, sel); fn != nil {
if ok, _ := c.isDeprecated(j, fn.Name); ok {
// functions that are deprecated may use deprecated
// symbols
return true
}
}
obj := j.Program.Info.ObjectOf(sel.Sel) obj := j.Program.Info.ObjectOf(sel.Sel)
if obj.Pkg() == nil { if obj.Pkg() == nil {
@ -2495,6 +2496,24 @@ func (c *Checker) CheckDeprecated(j *lint.Job) {
return true return true
} }
if ok, alt := c.isDeprecated(j, sel.Sel); ok { if ok, alt := c.isDeprecated(j, sel.Sel); ok {
// Look for the first available alternative, not the first
// version something was deprecated in. If a function was
// deprecated in Go 1.6, an alternative has been available
// already in 1.0, and we're targetting 1.2, it still
// makes sense to use the alternative from 1.0, to be
// future-proof.
minVersion := deprecated.Stdlib[selectorName(j, sel)].AlternativeAvailableSince
if !j.IsGoVersion(minVersion) {
return true
}
if fn := c.enclosingFunc(sel); fn != nil {
if _, ok := c.deprecatedObjs[fn.Object()]; ok {
// functions that are deprecated may use deprecated
// symbols
return true
}
}
j.Errorf(sel, "%s is deprecated: %s", j.Render(sel), alt) j.Errorf(sel, "%s is deprecated: %s", j.Render(sel), alt)
return true return true
} }
@ -2784,3 +2803,39 @@ func (c *Checker) CheckSelfAssignment(j *lint.Job) {
ast.Inspect(f, fn) ast.Inspect(f, fn)
} }
} }
func buildTagsIdentical(s1, s2 []string) bool {
if len(s1) != len(s2) {
return false
}
s1s := make([]string, len(s1))
copy(s1s, s1)
sort.Strings(s1s)
s2s := make([]string, len(s2))
copy(s2s, s2)
sort.Strings(s2s)
for i, s := range s1s {
if s != s2s[i] {
return false
}
}
return true
}
func (c *Checker) CheckDuplicateBuildConstraints(job *lint.Job) {
for _, f := range c.filterGenerated(job.Program.Files) {
constraints := buildTags(f)
for i, constraint1 := range constraints {
for j, constraint2 := range constraints {
if i >= j {
continue
}
if buildTagsIdentical(constraint1, constraint2) {
job.Errorf(f, "identical build constraints %q and %q",
strings.Join(constraint1, " "),
strings.Join(constraint2, " "))
}
}
}
}
}

View file

@ -0,0 +1,22 @@
package structlayout
import "fmt"
type Field struct {
Name string `json:"name"`
Type string `json:"type"`
Start int64 `json:"start"`
End int64 `json:"end"`
Size int64 `json:"size"`
Align int64 `json:"align"`
IsPadding bool `json:"is_padding"`
}
func (f Field) String() string {
if f.IsPadding {
return fmt.Sprintf("%s: %d-%d (size %d, align %d)",
"padding", f.Start, f.End, f.Size, f.Align)
}
return fmt.Sprintf("%s %s: %d-%d (size %d, align %d)",
f.Name, f.Type, f.Start, f.End, f.Size, f.Align)
}

View file

@ -1,20 +0,0 @@
Copyright (c) 2016 Dominik Honnef
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View file

@ -26,6 +26,9 @@ type LintChecker struct {
c *Checker c *Checker
} }
func (*LintChecker) Name() string { return "unused" }
func (*LintChecker) Prefix() string { return "U" }
func (l *LintChecker) Init(*lint.Program) {} func (l *LintChecker) Init(*lint.Program) {}
func (l *LintChecker) Funcs() map[string]lint.Func { func (l *LintChecker) Funcs() map[string]lint.Func {
return map[string]lint.Func{ return map[string]lint.Func{
@ -275,6 +278,51 @@ func (c *Checker) Check(lprog *loader.Program) []Unused {
return unused return unused
} }
// isNoCopyType reports whether a type represents the NoCopy sentinel
// type. The NoCopy type is a named struct with no fields and exactly
// one method `func Lock()` that is empty.
//
// FIXME(dh): currently we're not checking that the function body is
// empty.
func isNoCopyType(typ types.Type) bool {
st, ok := typ.Underlying().(*types.Struct)
if !ok {
return false
}
if st.NumFields() != 0 {
return false
}
named, ok := typ.(*types.Named)
if !ok {
return false
}
if named.NumMethods() != 1 {
return false
}
meth := named.Method(0)
if meth.Name() != "Lock" {
return false
}
sig := meth.Type().(*types.Signature)
if sig.Params().Len() != 0 || sig.Results().Len() != 0 {
return false
}
return true
}
func (c *Checker) useNoCopyFields(typ types.Type) {
if st, ok := typ.Underlying().(*types.Struct); ok {
n := st.NumFields()
for i := 0; i < n; i++ {
field := st.Field(i)
if isNoCopyType(field.Type()) {
c.graph.markUsedBy(field, typ)
}
}
}
}
func (c *Checker) useExportedFields(typ types.Type) { func (c *Checker) useExportedFields(typ types.Type) {
if st, ok := typ.Underlying().(*types.Struct); ok { if st, ok := typ.Underlying().(*types.Struct); ok {
n := st.NumFields() n := st.NumFields()
@ -485,6 +533,7 @@ func (c *Checker) processTypes(pkg *loader.PackageInfo) {
interfaces = append(interfaces, obj) interfaces = append(interfaces, obj)
} }
case *types.Struct: case *types.Struct:
c.useNoCopyFields(obj)
if pkg.Pkg.Name() != "main" && !c.WholeProgram { if pkg.Pkg.Name() != "main" && !c.WholeProgram {
c.useExportedFields(obj) c.useExportedFields(obj)
} }

View file

@ -0,0 +1,17 @@
package version
import (
"fmt"
"os"
"path/filepath"
)
const Version = "2017.2"
func Print() {
if Version == "devel" {
fmt.Printf("%s (no version)\n", filepath.Base(os.Args[0]))
} else {
fmt.Printf("%s %s\n", filepath.Base(os.Args[0]), Version)
}
}

View file

@ -109,7 +109,7 @@
"importpath": "github.com/opennota/check", "importpath": "github.com/opennota/check",
"repository": "https://github.com/opennota/check", "repository": "https://github.com/opennota/check",
"vcs": "git", "vcs": "git",
"revision": "11e2eec79ec4f789607e3efbf405cdca2504d4cb", "revision": "0ec6d92e97559edf84eebe7c51b069953d4b522c",
"branch": "master", "branch": "master",
"notests": true "notests": true
}, },
@ -309,120 +309,20 @@
"notests": true "notests": true
}, },
{ {
"importpath": "honnef.co/go/tools/callgraph", "importpath": "golang.org/x/tools/refactor/importgraph",
"repository": "https://github.com/dominikh/go-tools", "repository": "https://go.googlesource.com/tools",
"vcs": "git", "vcs": "git",
"revision": "49f44f893d933fd08cd7d67d65ccefa5d7c23329", "revision": "9c477bae194915bfd4bc8c314e90e28b9ec1c831",
"branch": "master", "branch": "master",
"path": "callgraph", "path": "/refactor/importgraph",
"notests": true "notests": true
}, },
{ {
"importpath": "honnef.co/go/tools/cmd/gosimple", "importpath": "honnef.co/go/tools",
"repository": "https://github.com/dominikh/go-tools", "repository": "https://github.com/dominikh/go-tools",
"vcs": "git", "vcs": "git",
"revision": "49f44f893d933fd08cd7d67d65ccefa5d7c23329", "revision": "50914165a1ae448f1608c6c325d052313396182e",
"branch": "master", "branch": "HEAD",
"path": "/cmd/gosimple",
"notests": true
},
{
"importpath": "honnef.co/go/tools/cmd/megacheck",
"repository": "https://github.com/dominikh/go-tools",
"vcs": "git",
"revision": "49f44f893d933fd08cd7d67d65ccefa5d7c23329",
"branch": "master",
"path": "/cmd/megacheck",
"notests": true
},
{
"importpath": "honnef.co/go/tools/cmd/staticcheck",
"repository": "https://github.com/dominikh/go-tools",
"vcs": "git",
"revision": "49f44f893d933fd08cd7d67d65ccefa5d7c23329",
"branch": "master",
"path": "/cmd/staticcheck",
"notests": true
},
{
"importpath": "honnef.co/go/tools/cmd/unused",
"repository": "https://github.com/dominikh/go-tools",
"vcs": "git",
"revision": "49f44f893d933fd08cd7d67d65ccefa5d7c23329",
"branch": "master",
"path": "/cmd/unused",
"notests": true
},
{
"importpath": "honnef.co/go/tools/functions",
"repository": "https://github.com/dominikh/go-tools",
"vcs": "git",
"revision": "49f44f893d933fd08cd7d67d65ccefa5d7c23329",
"branch": "master",
"path": "functions",
"notests": true
},
{
"importpath": "honnef.co/go/tools/gcsizes",
"repository": "https://github.com/dominikh/go-tools",
"vcs": "git",
"revision": "49f44f893d933fd08cd7d67d65ccefa5d7c23329",
"branch": "master",
"path": "gcsizes",
"notests": true
},
{
"importpath": "honnef.co/go/tools/internal/sharedcheck",
"repository": "https://github.com/dominikh/go-tools",
"vcs": "git",
"revision": "49f44f893d933fd08cd7d67d65ccefa5d7c23329",
"branch": "master",
"path": "/internal/sharedcheck",
"notests": true
},
{
"importpath": "honnef.co/go/tools/lint",
"repository": "https://github.com/dominikh/go-tools",
"vcs": "git",
"revision": "49f44f893d933fd08cd7d67d65ccefa5d7c23329",
"branch": "master",
"path": "lint",
"notests": true
},
{
"importpath": "honnef.co/go/tools/simple",
"repository": "https://github.com/dominikh/go-tools",
"vcs": "git",
"revision": "49f44f893d933fd08cd7d67d65ccefa5d7c23329",
"branch": "master",
"path": "simple",
"notests": true
},
{
"importpath": "honnef.co/go/tools/ssa",
"repository": "https://github.com/dominikh/go-tools",
"vcs": "git",
"revision": "49f44f893d933fd08cd7d67d65ccefa5d7c23329",
"branch": "master",
"path": "ssa",
"notests": true
},
{
"importpath": "honnef.co/go/tools/staticcheck",
"repository": "https://github.com/dominikh/go-tools",
"vcs": "git",
"revision": "49f44f893d933fd08cd7d67d65ccefa5d7c23329",
"branch": "master",
"path": "staticcheck",
"notests": true
},
{
"importpath": "honnef.co/go/tools/unused",
"repository": "https://github.com/dominikh/go-tools",
"vcs": "git",
"revision": "49f44f893d933fd08cd7d67d65ccefa5d7c23329",
"branch": "master",
"path": "unused",
"notests": true "notests": true
}, },
{ {
@ -450,4 +350,4 @@
"notests": true "notests": true
} }
] ]
} }

View file

@ -25,7 +25,7 @@ func AggregateIssueChan(issues chan *Issue) chan *Issue {
go func() { go func() {
for issue := range issues { for issue := range issues {
key := issueKey{ key := issueKey{
path: issue.Path, path: issue.Path.String(),
line: issue.Line, line: issue.Line,
col: issue.Col, col: issue.Col,
message: issue.Message, message: issue.Message,

View file

@ -33,14 +33,13 @@ func outputToCheckstyle(issues chan *Issue) int {
} }
status := 0 status := 0
for issue := range issues { for issue := range issues {
if lastFile != nil && lastFile.Name != issue.Path { path := issue.Path.Relative()
if lastFile != nil && lastFile.Name != path {
out.Files = append(out.Files, lastFile) out.Files = append(out.Files, lastFile)
lastFile = nil lastFile = nil
} }
if lastFile == nil { if lastFile == nil {
lastFile = &checkstyleFile{ lastFile = &checkstyleFile{Name: path}
Name: issue.Path,
}
} }
if config.Errors && issue.Severity != Error { if config.Errors && issue.Severity != Error {

View file

@ -38,6 +38,7 @@ type Config struct { // nolint: maligned
Vendor bool Vendor bool
Cyclo int Cyclo int
LineLength int LineLength int
MisspellLocale string
MinConfidence float64 MinConfidence float64
MinOccurrences int MinOccurrences int
MinConstLength int MinConstLength int
@ -128,6 +129,7 @@ var config = &Config{
Concurrency: runtime.NumCPU(), Concurrency: runtime.NumCPU(),
Cyclo: 10, Cyclo: 10,
LineLength: 80, LineLength: 80,
MisspellLocale: "",
MinConfidence: 0.8, MinConfidence: 0.8,
MinOccurrences: 3, MinOccurrences: 3,
MinConstLength: 3, MinConstLength: 3,

View file

@ -5,6 +5,7 @@ import (
"go/ast" "go/ast"
"go/parser" "go/parser"
"go/token" "go/token"
"os"
"sort" "sort"
"strings" "strings"
"sync" "sync"
@ -67,11 +68,12 @@ func newDirectiveParser() *directiveParser {
// IsIgnored returns true if the given linter issue is ignored by a linter directive. // IsIgnored returns true if the given linter issue is ignored by a linter directive.
func (d *directiveParser) IsIgnored(issue *Issue) bool { func (d *directiveParser) IsIgnored(issue *Issue) bool {
d.lock.Lock() d.lock.Lock()
ranges, ok := d.files[issue.Path] path := issue.Path.Relative()
ranges, ok := d.files[path]
if !ok { if !ok {
ranges = d.parseFile(issue.Path) ranges = d.parseFile(path)
sort.Sort(ranges) sort.Sort(ranges)
d.files[issue.Path] = ranges d.files[path] = ranges
} }
d.lock.Unlock() d.lock.Unlock()
for _, r := range ranges { for _, r := range ranges {
@ -204,10 +206,16 @@ func filterIssuesViaDirectives(directives *directiveParser, issues chan *Issue)
func warnOnUnusedDirective(directives *directiveParser) []*Issue { func warnOnUnusedDirective(directives *directiveParser) []*Issue {
out := []*Issue{} out := []*Issue{}
cwd, err := os.Getwd()
if err != nil {
warning("failed to get working directory %s", err)
}
for path, ranges := range directives.Unmatched() { for path, ranges := range directives.Unmatched() {
for _, ignore := range ranges { for _, ignore := range ranges {
issue, _ := NewIssue("nolint", config.formatTemplate) issue, _ := NewIssue("nolint", config.formatTemplate)
issue.Path = path issue.Path = newIssuePath(cwd, path)
issue.Line = ignore.start issue.Line = ignore.start
issue.Col = ignore.col issue.Col = ignore.col
issue.Message = "nolint directive did not match any issue" issue.Message = "nolint directive did not match any issue"

View file

@ -5,7 +5,6 @@ import (
"fmt" "fmt"
"os" "os"
"os/exec" "os/exec"
"path/filepath"
"reflect" "reflect"
"regexp" "regexp"
"strconv" "strconv"
@ -82,6 +81,7 @@ func runLinters(linters map[string]*Linter, paths []string, concurrency int, exc
"duplthreshold": fmt.Sprintf("%d", config.DuplThreshold), "duplthreshold": fmt.Sprintf("%d", config.DuplThreshold),
"mincyclo": fmt.Sprintf("%d", config.Cyclo), "mincyclo": fmt.Sprintf("%d", config.Cyclo),
"maxlinelength": fmt.Sprintf("%d", config.LineLength), "maxlinelength": fmt.Sprintf("%d", config.LineLength),
"misspelllocale": fmt.Sprintf("%s", config.MisspellLocale),
"min_confidence": fmt.Sprintf("%f", config.MinConfidence), "min_confidence": fmt.Sprintf("%f", config.MinConfidence),
"min_occurrences": fmt.Sprintf("%d", config.MinOccurrences), "min_occurrences": fmt.Sprintf("%d", config.MinOccurrences),
"min_const_length": fmt.Sprintf("%d", config.MinConstLength), "min_const_length": fmt.Sprintf("%d", config.MinConstLength),
@ -237,8 +237,10 @@ func processOutput(dbg debugFunction, state *linterState, out []byte) {
} }
switch name { switch name {
case "path": case "path":
issue.Path = relativePath(cwd, part) issue.Path, err = newIssuePathFromAbsPath(cwd, part)
if err != nil {
warning("failed to make %s a relative path: %s", part, err)
}
case "line": case "line":
n, err := strconv.ParseInt(part, 10, 32) n, err := strconv.ParseInt(part, 10, 32)
kingpin.FatalIfError(err, "line matched invalid integer") kingpin.FatalIfError(err, "line matched invalid integer")
@ -273,37 +275,6 @@ func processOutput(dbg debugFunction, state *linterState, out []byte) {
} }
} }
func relativePath(root, path string) string {
fallback := path
root = resolvePath(root)
path = resolvePath(path)
var err error
path, err = filepath.Rel(root, path)
if err != nil {
warning("failed to make %s a relative path: %s", fallback, err)
return fallback
}
return path
}
func resolvePath(path string) string {
var err error
fallback := path
if !filepath.IsAbs(path) {
path, err = filepath.Abs(path)
if err != nil {
warning("failed to make %s an absolute path: %s", fallback, err)
return fallback
}
}
path, err = filepath.EvalSymlinks(path)
if err != nil {
warning("failed to resolve symlinks in %s: %s", fallback, err)
return fallback
}
return path
}
func maybeSortIssues(issues chan *Issue) chan *Issue { func maybeSortIssues(issues chan *Issue) chan *Issue {
if reflect.DeepEqual([]string{"none"}, config.Sort) { if reflect.DeepEqual([]string{"none"}, config.Sort) {
return issues return issues

View file

@ -2,8 +2,10 @@ package main
import ( import (
"bytes" "bytes"
"encoding/json"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"path/filepath"
"sort" "sort"
"strings" "strings"
"text/template" "text/template"
@ -21,13 +23,59 @@ const (
Warning Severity = "warning" Warning Severity = "warning"
) )
type IssuePath struct {
root string
path string
}
func (i IssuePath) String() string {
return i.Relative()
}
func (i IssuePath) Relative() string {
return i.path
}
func (i IssuePath) Abs() string {
return filepath.Join(i.root, i.path)
}
func (i IssuePath) MarshalJSON() ([]byte, error) {
return json.Marshal(i.String())
}
func newIssuePath(root, path string) IssuePath {
return IssuePath{root: root, path: path}
}
// newIssuePathFromAbsPath returns a new issuePath from a path that may be
// an absolute path. root must be an absolute path.
func newIssuePathFromAbsPath(root, path string) (IssuePath, error) {
resolvedRoot, err := filepath.EvalSymlinks(root)
if err != nil {
return newIssuePath(root, path), err
}
resolvedPath, err := filepath.EvalSymlinks(path)
if err != nil {
return newIssuePath(root, path), err
}
if !filepath.IsAbs(path) {
return newIssuePath(resolvedRoot, resolvedPath), nil
}
relPath, err := filepath.Rel(resolvedRoot, resolvedPath)
return newIssuePath(resolvedRoot, relPath), err
}
type Issue struct { type Issue struct {
Linter string `json:"linter"` Linter string `json:"linter"`
Severity Severity `json:"severity"` Severity Severity `json:"severity"`
Path string `json:"path"` Path IssuePath `json:"path"`
Line int `json:"line"` Line int `json:"line"`
Col int `json:"col"` Col int `json:"col"`
Message string `json:"message"` Message string `json:"message"`
formatTmpl *template.Template formatTmpl *template.Template
} }
@ -50,7 +98,11 @@ func (i *Issue) String() string {
if i.Col != 0 { if i.Col != 0 {
col = fmt.Sprintf("%d", i.Col) col = fmt.Sprintf("%d", i.Col)
} }
return fmt.Sprintf("%s:%d:%s:%s: %s (%s)", strings.TrimSpace(i.Path), i.Line, col, i.Severity, strings.TrimSpace(i.Message), i.Linter) return fmt.Sprintf("%s:%d:%s:%s: %s (%s)",
strings.TrimSpace(i.Path.Relative()),
i.Line, col, i.Severity,
strings.TrimSpace(i.Message),
i.Linter)
} }
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
_ = i.formatTmpl.Execute(buf, i) _ = i.formatTmpl.Execute(buf, i)
@ -76,7 +128,7 @@ func CompareIssue(l, r Issue, order []string) bool {
for _, key := range order { for _, key := range order {
switch { switch {
case key == "path" && l.Path != r.Path: case key == "path" && l.Path != r.Path:
return l.Path < r.Path return l.Path.String() < r.Path.String()
case key == "line" && l.Line != r.Line: case key == "line" && l.Line != r.Line:
return l.Line < r.Line return l.Line < r.Line
case key == "column" && l.Col != r.Col: case key == "column" && l.Col != r.Col:

View file

@ -10,10 +10,10 @@ import (
func TestSortedIssues(t *testing.T) { func TestSortedIssues(t *testing.T) {
actual := []*Issue{ actual := []*Issue{
{Path: "b.go", Line: 5, Col: 1}, {Path: newIssuePath("", "b.go"), Line: 5, Col: 1},
{Path: "a.go", Line: 3, Col: 2}, {Path: newIssuePath("", "a.go"), Line: 3, Col: 2},
{Path: "b.go", Line: 1, Col: 3}, {Path: newIssuePath("", "b.go"), Line: 1, Col: 3},
{Path: "a.go", Line: 1, Col: 4}, {Path: newIssuePath("", "a.go"), Line: 1, Col: 4},
} }
issues := &sortedIssues{ issues := &sortedIssues{
issues: actual, issues: actual,
@ -21,18 +21,18 @@ func TestSortedIssues(t *testing.T) {
} }
sort.Sort(issues) sort.Sort(issues)
expected := []*Issue{ expected := []*Issue{
{Path: "a.go", Line: 1, Col: 4}, {Path: newIssuePath("", "a.go"), Line: 1, Col: 4},
{Path: "a.go", Line: 3, Col: 2}, {Path: newIssuePath("", "a.go"), Line: 3, Col: 2},
{Path: "b.go", Line: 1, Col: 3}, {Path: newIssuePath( "", "b.go"), Line: 1, Col: 3},
{Path: "b.go", Line: 5, Col: 1}, {Path: newIssuePath( "", "b.go"), Line: 5, Col: 1},
} }
require.Equal(t, expected, actual) require.Equal(t, expected, actual)
} }
func TestCompareOrderWithMessage(t *testing.T) { func TestCompareOrderWithMessage(t *testing.T) {
order := []string{"path", "line", "column", "message"} order := []string{"path", "line", "column", "message"}
issueM := Issue{Path: "file.go", Message: "message"} issueM := Issue{Path: newIssuePath("", "file.go"), Message: "message"}
issueU := Issue{Path: "file.go", Message: "unknown"} issueU := Issue{Path: newIssuePath("", "file.go"), Message: "unknown"}
assert.True(t, CompareIssue(issueM, issueU, order)) assert.True(t, CompareIssue(issueM, issueU, order))
assert.False(t, CompareIssue(issueU, issueM, order)) assert.False(t, CompareIssue(issueU, issueM, order))

View file

@ -235,7 +235,7 @@ var defaultLinters = map[string]LinterConfig{
Command: `gas -fmt=csv`, Command: `gas -fmt=csv`,
Pattern: `^(?P<path>.*?\.go),(?P<line>\d+),(?P<message>[^,]+,[^,]+,[^,]+)`, Pattern: `^(?P<path>.*?\.go),(?P<line>\d+),(?P<message>[^,]+,[^,]+,[^,]+)`,
InstallFrom: "github.com/GoASTScanner/gas", InstallFrom: "github.com/GoASTScanner/gas",
PartitionStrategy: partitionPathsAsDirectories, PartitionStrategy: partitionPathsAsFiles,
defaultEnabled: true, defaultEnabled: true,
IsFast: true, IsFast: true,
}, },
@ -328,7 +328,7 @@ var defaultLinters = map[string]LinterConfig{
defaultEnabled: true, defaultEnabled: true,
}, },
"misspell": { "misspell": {
Command: `misspell -j 1`, Command: `misspell -j 1 --locale "{misspelllocale}"`,
Pattern: `PATH:LINE:COL:MESSAGE`, Pattern: `PATH:LINE:COL:MESSAGE`,
InstallFrom: "github.com/client9/misspell/cmd/misspell", InstallFrom: "github.com/client9/misspell/cmd/misspell",
PartitionStrategy: partitionPathsAsFiles, PartitionStrategy: partitionPathsAsFiles,

View file

@ -21,12 +21,13 @@ var (
// Locations to look for vendored linters. // Locations to look for vendored linters.
vendoredSearchPaths = [][]string{ vendoredSearchPaths = [][]string{
{"github.com", "alecthomas", "gometalinter", "_linters"}, {"github.com", "alecthomas", "gometalinter", "_linters"},
{"gopkg.in", "alecthomas", "gometalinter.v1", "_linters"}, {"gopkg.in", "alecthomas", "gometalinter.v2", "_linters"},
} }
Version = "master"
) )
func setupFlags(app *kingpin.Application) { func setupFlags(app *kingpin.Application) {
app.Flag("config", "Load JSON configuration from file.").Action(loadConfig).String() app.Flag("config", "Load JSON configuration from file.").Envar("GOMETALINTER_CONFIG").Action(loadConfig).String()
app.Flag("disable", "Disable previously enabled linters.").PlaceHolder("LINTER").Short('D').Action(disableAction).Strings() app.Flag("disable", "Disable previously enabled linters.").PlaceHolder("LINTER").Short('D').Action(disableAction).Strings()
app.Flag("enable", "Enable previously disabled linters.").PlaceHolder("LINTER").Short('E').Action(enableAction).Strings() app.Flag("enable", "Enable previously disabled linters.").PlaceHolder("LINTER").Short('E').Action(enableAction).Strings()
app.Flag("linter", "Define a linter.").PlaceHolder("NAME:COMMAND:PATTERN").Action(cliLinterOverrides).StringMap() app.Flag("linter", "Define a linter.").PlaceHolder("NAME:COMMAND:PATTERN").Action(cliLinterOverrides).StringMap()
@ -35,12 +36,12 @@ func setupFlags(app *kingpin.Application) {
app.Flag("disable-all", "Disable all linters.").Action(disableAllAction).Bool() app.Flag("disable-all", "Disable all linters.").Action(disableAllAction).Bool()
app.Flag("enable-all", "Enable all linters.").Action(enableAllAction).Bool() app.Flag("enable-all", "Enable all linters.").Action(enableAllAction).Bool()
app.Flag("format", "Output format.").PlaceHolder(config.Format).StringVar(&config.Format) app.Flag("format", "Output format.").PlaceHolder(config.Format).StringVar(&config.Format)
app.Flag("vendored-linters", "Use vendored linters (recommended).").BoolVar(&config.VendoredLinters) app.Flag("vendored-linters", "Use vendored linters (recommended) (DEPRECATED - use binary packages).").BoolVar(&config.VendoredLinters)
app.Flag("fast", "Only run fast linters.").BoolVar(&config.Fast) app.Flag("fast", "Only run fast linters.").BoolVar(&config.Fast)
app.Flag("install", "Attempt to install all known linters.").Short('i').BoolVar(&config.Install) app.Flag("install", "Attempt to install all known linters (DEPRECATED - use binary packages).").Short('i').BoolVar(&config.Install)
app.Flag("update", "Pass -u to go tool when installing.").Short('u').BoolVar(&config.Update) app.Flag("update", "Pass -u to go tool when installing (DEPRECATED - use binary packages).").Short('u').BoolVar(&config.Update)
app.Flag("force", "Pass -f to go tool when installing.").Short('f').BoolVar(&config.Force) app.Flag("force", "Pass -f to go tool when installing (DEPRECATED - use binary packages).").Short('f').BoolVar(&config.Force)
app.Flag("download-only", "Pass -d to go tool when installing.").BoolVar(&config.DownloadOnly) app.Flag("download-only", "Pass -d to go tool when installing (DEPRECATED - use binary packages).").BoolVar(&config.DownloadOnly)
app.Flag("debug", "Display messages for failed linters, etc.").Short('d').BoolVar(&config.Debug) app.Flag("debug", "Display messages for failed linters, etc.").Short('d').BoolVar(&config.Debug)
app.Flag("concurrency", "Number of concurrent linters to run.").PlaceHolder(fmt.Sprintf("%d", runtime.NumCPU())).Short('j').IntVar(&config.Concurrency) app.Flag("concurrency", "Number of concurrent linters to run.").PlaceHolder(fmt.Sprintf("%d", runtime.NumCPU())).Short('j').IntVar(&config.Concurrency)
app.Flag("exclude", "Exclude messages matching these regular expressions.").Short('e').PlaceHolder("REGEXP").StringsVar(&config.Exclude) app.Flag("exclude", "Exclude messages matching these regular expressions.").Short('e').PlaceHolder("REGEXP").StringsVar(&config.Exclude)
@ -49,6 +50,7 @@ func setupFlags(app *kingpin.Application) {
app.Flag("vendor", "Enable vendoring support (skips 'vendor' directories and sets GO15VENDOREXPERIMENT=1).").BoolVar(&config.Vendor) app.Flag("vendor", "Enable vendoring support (skips 'vendor' directories and sets GO15VENDOREXPERIMENT=1).").BoolVar(&config.Vendor)
app.Flag("cyclo-over", "Report functions with cyclomatic complexity over N (using gocyclo).").PlaceHolder("10").IntVar(&config.Cyclo) app.Flag("cyclo-over", "Report functions with cyclomatic complexity over N (using gocyclo).").PlaceHolder("10").IntVar(&config.Cyclo)
app.Flag("line-length", "Report lines longer than N (using lll).").PlaceHolder("80").IntVar(&config.LineLength) app.Flag("line-length", "Report lines longer than N (using lll).").PlaceHolder("80").IntVar(&config.LineLength)
app.Flag("misspell-locale", "Specify locale to use (using misspell).").PlaceHolder("").StringVar(&config.MisspellLocale)
app.Flag("min-confidence", "Minimum confidence interval to pass to golint.").PlaceHolder(".80").FloatVar(&config.MinConfidence) app.Flag("min-confidence", "Minimum confidence interval to pass to golint.").PlaceHolder(".80").FloatVar(&config.MinConfidence)
app.Flag("min-occurrences", "Minimum occurrences to pass to goconst.").PlaceHolder("3").IntVar(&config.MinOccurrences) app.Flag("min-occurrences", "Minimum occurrences to pass to goconst.").PlaceHolder("3").IntVar(&config.MinOccurrences)
app.Flag("min-const-length", "Minimum constant length.").PlaceHolder("3").IntVar(&config.MinConstLength) app.Flag("min-const-length", "Minimum constant length.").PlaceHolder("3").IntVar(&config.MinConstLength)
@ -156,8 +158,8 @@ func formatLinters() string {
if install == "()" { if install == "()" {
install = "" install = ""
} }
fmt.Fprintf(w, " %s %s\n %s\n %s\n", fmt.Fprintf(w, " %s: %s\n\tcommand: %s\n\tregex: %s\n\tfast: %t\n\tdefault enabled: %t\n\n",
linter.Name, install, linter.Command, linter.Pattern) linter.Name, install, linter.Command, linter.Pattern, linter.IsFast, linter.defaultEnabled)
} }
return w.String() return w.String()
} }
@ -171,6 +173,7 @@ func formatSeverity() string {
} }
func main() { func main() {
kingpin.Version(Version)
pathsArg := kingpin.Arg("path", "Directories to lint. Defaults to \".\". <path>/... will recurse.").Strings() pathsArg := kingpin.Arg("path", "Directories to lint. Defaults to \".\". <path>/... will recurse.").Strings()
app := kingpin.CommandLine app := kingpin.CommandLine
setupFlags(app) setupFlags(app)

View file

@ -0,0 +1,104 @@
#!/bin/bash -e
# Only build packages for tagged releases
TAG="$(git tag -l --points-at HEAD)"
export CGO_ENABLED=0
GO_VERSION="$(go version | awk '{print $3}' | cut -d. -f1-2)"
if [ "$GO_VERSION" != "go1.9" ]; then
echo "$0: not packaging; not on Go 1.9"
exit 0
fi
if echo "$TAG" | grep -q '^v[0-9]\.[0-9]\.[0-9]\(-.*\)?$' && false; then
echo "$0: not packaging; no tag or tag not in semver format"
exit 0
fi
LINTERS="
github.com/alecthomas/gocyclo
github.com/alexkohler/nakedret
github.com/client9/misspell/cmd/misspell
github.com/dnephin/govet
github.com/GoASTScanner/gas
github.com/golang/lint/golint
github.com/gordonklaus/ineffassign
github.com/jgautheron/goconst/cmd/goconst
github.com/kisielk/errcheck
github.com/mdempsky/maligned
github.com/mdempsky/unconvert
github.com/mibk/dupl
github.com/opennota/check/cmd/structcheck
github.com/opennota/check/cmd/varcheck
github.com/stripe/safesql
github.com/tsenart/deadcode
github.com/walle/lll/cmd/lll
golang.org/x/tools/cmd/goimports
golang.org/x/tools/cmd/gotype
honnef.co/go/tools/cmd/gosimple
honnef.co/go/tools/cmd/megacheck
honnef.co/go/tools/cmd/staticcheck
honnef.co/go/tools/cmd/unused
mvdan.cc/interfacer
mvdan.cc/unparam
"
eval "$(go env | FS='' awk '{printf "REAL_%s\n", $0}')"
function install_go_binary() {
local SRC
if [ "$GOOS" = "$REAL_GOOS" -a "$GOARCH" = "$REAL_GOARCH" ]; then
SRC="${GOPATH}/bin"
else
SRC="${GOPATH}/bin/${GOOS}_${GOARCH}"
fi
install -m 755 "${SRC}/${1}${SUFFIX}" "${2}"
}
function packager() {
if [ "$GOOS" = "windows" ]; then
zip -9 -r -o "${1}".zip "${1}"
else
tar cvfj "${1}".tar.bz2 "${1}"
fi
}
rm -rf "${PWD}/dist"
for GOOS in linux darwin windows; do
SUFFIX=""
if [ "$GOOS" = "windows" ]; then
SUFFIX=".exe"
fi
for GOARCH in 386 amd64; do
export GOPATH="${REAL_GOPATH}"
DEST="${PWD}/dist/gometalinter-${TAG}-${GOOS}-${GOARCH}"
install -d -m 755 "${DEST}/linters"
install -m 644 COPYING "${DEST}"
cat << EOF > "${DEST}/README.txt"
gometalinter is a tool to normalise the output of Go linters.
See https://github.com/alecthomas/gometalinter for more information.
This is a binary distribution of gometalinter ${TAG}.
All binaries must be installed in the PATH for gometalinter to operate correctly.
EOF
echo "${DEST}"
export GOOS GOARCH
go build -i .
go build -o "${DEST}/gometalinter${SUFFIX}" -ldflags="-X main.Version=${TAG}" .
export GOPATH="$PWD/_linters"
go install -v ${LINTERS}
for LINTER in ${LINTERS}; do
install_go_binary $(basename ${LINTER}) "${DEST}/linters"
done
(cd "${DEST}/.." && packager "$(basename ${DEST})")
done
done

View file

@ -0,0 +1,37 @@
package regressiontests
import (
"fmt"
"testing"
"github.com/gotestyourself/gotestyourself/fs"
"github.com/stretchr/testify/assert"
)
func TestGas(t *testing.T) {
t.Parallel()
dir := fs.NewDir(t, "test-gas",
fs.WithFile("file.go", gasFileErrorUnhandled("root")),
fs.WithDir("sub",
fs.WithFile("file.go", gasFileErrorUnhandled("sub"))))
defer dir.Remove()
expected := Issues{
{Linter: "gas", Severity: "warning", Path: "file.go", Line: 3, Col: 0, Message: "Errors unhandled.,LOW,HIGH"},
{Linter: "gas", Severity: "warning", Path: "sub/file.go", Line: 3, Col: 0, Message: "Errors unhandled.,LOW,HIGH"},
}
actual := RunLinter(t, "gas", dir.Path())
assert.Equal(t, expected, actual)
}
func gasFileErrorUnhandled(pkg string) string {
return fmt.Sprintf(`package %s
func badFunction() string {
u, _ := ErrorHandle()
return u
}
func ErrorHandle() (u string, err error) {
return u
}
`, pkg)
}

View file

@ -4,13 +4,11 @@ import (
"fmt" "fmt"
"io" "io"
"os" "os"
"regexp"
"strings" "strings"
) )
var ( var (
errCommandNotSpecified = TError("command not specified") errCommandNotSpecified = TError("command not specified")
envarTransformRegexp = regexp.MustCompile(`[^a-zA-Z_]+`)
) )
// An Application contains the definitions of flags, arguments and commands // An Application contains the definitions of flags, arguments and commands
@ -28,7 +26,9 @@ type Application struct {
errors io.Writer errors io.Writer
terminate func(status int) // See Terminate() terminate func(status int) // See Terminate()
noInterspersed bool // can flags be interspersed with args (or must they come first) noInterspersed bool // can flags be interspersed with args (or must they come first)
envarSeparator string
defaultEnvars bool defaultEnvars bool
resolvers []Resolver
completion bool completion bool
helpFlag *Clause helpFlag *Clause
helpCommand *CmdClause helpCommand *CmdClause
@ -38,11 +38,12 @@ type Application struct {
// New creates a new Kingpin application instance. // New creates a new Kingpin application instance.
func New(name, help string) *Application { func New(name, help string) *Application {
a := &Application{ a := &Application{
Name: name, Name: name,
Help: help, Help: help,
output: os.Stdout, output: os.Stdout,
errors: os.Stderr, errors: os.Stderr,
terminate: os.Exit, terminate: os.Exit,
envarSeparator: string(os.PathListSeparator),
defaultUsage: &UsageContext{ defaultUsage: &UsageContext{
Template: DefaultUsageTemplate, Template: DefaultUsageTemplate,
}, },
@ -128,6 +129,26 @@ func (a *Application) DefaultEnvars() *Application {
return a return a
} }
// EnvarSeparator sets the string that is used for separating values in environment variables.
//
// This defaults to the current OS's path list separator (typically : or ;).
func (a *Application) EnvarSeparator(sep string) *Application {
a.envarSeparator = sep
return a
}
// Resolver adds an ordered set of flag/argument resolvers.
//
// Resolvers provide default flag/argument values, from environment variables, configuration files, etc. Multiple
// resolvers may be added, and they are processed in order.
//
// The last Resolver to return a value always wins. Values returned from resolvers are not cumulative.
func (a *Application) Resolver(resolvers ...Resolver) *Application {
a.resolvers = append(a.resolvers, resolvers...)
return a
}
// Terminate specifies the termination handler. Defaults to os.Exit(status). // Terminate specifies the termination handler. Defaults to os.Exit(status).
// If nil is passed, a no-op function will be used. // If nil is passed, a no-op function will be used.
func (a *Application) Terminate(terminate func(int)) *Application { func (a *Application) Terminate(terminate func(int)) *Application {
@ -138,7 +159,7 @@ func (a *Application) Terminate(terminate func(int)) *Application {
return a return a
} }
// Writer specifies the writer to use for usage and errors. Defaults to os.Stderr. // Writers specifies the writers to use for usage and errors. Defaults to os.Stderr.
func (a *Application) Writers(out, err io.Writer) *Application { func (a *Application) Writers(out, err io.Writer) *Application {
a.output = out a.output = out
a.errors = err a.errors = err
@ -169,11 +190,27 @@ func (a *Application) parseContext(ignoreDefault bool, args []string) (*ParseCon
if err := a.init(); err != nil { if err := a.init(); err != nil {
return nil, err return nil, err
} }
context := tokenize(args, ignoreDefault) context := tokenize(args, ignoreDefault, a.buildResolvers())
err := parse(context, a) err := parse(context, a)
return context, err return context, err
} }
// Build resolvers to emulate the envar and defaults behaviour that was previously hard-coded.
func (a *Application) buildResolvers() []Resolver {
// .Default() has lowest priority...
resolvers := []Resolver{defaultsResolver()}
// Then custom resolvers...
resolvers = append(resolvers, a.resolvers...)
// Finally, envars are highest priority behind direct flag parsing.
if a.defaultEnvars {
resolvers = append(resolvers, PrefixedEnvarResolver(a.Name+"_", a.envarSeparator))
}
resolvers = append(resolvers, envarResolver(a.envarSeparator))
return resolvers
}
// Parse parses command-line arguments. It returns the selected command and an // Parse parses command-line arguments. It returns the selected command and an
// error. The selected command will be a space separated subcommand, if // error. The selected command will be a space separated subcommand, if
// subcommands have been configured. // subcommands have been configured.
@ -275,13 +312,6 @@ func (a *Application) Interspersed(interspersed bool) *Application {
return a return a
} }
func (a *Application) defaultEnvarPrefix() string {
if a.defaultEnvars {
return a.Name
}
return ""
}
func (a *Application) init() error { func (a *Application) init() error {
if a.initialized { if a.initialized {
return nil return nil
@ -308,7 +338,7 @@ func (a *Application) init() error {
a.commandOrder = append(a.commandOrder[l-1:l], a.commandOrder[:l-1]...) a.commandOrder = append(a.commandOrder[l-1:l], a.commandOrder[:l-1]...)
} }
if err := a.flagGroup.init(a.defaultEnvarPrefix()); err != nil { if err := a.flagGroup.init(); err != nil {
return err return err
} }
if err := a.cmdGroup.init(); err != nil { if err := a.cmdGroup.init(); err != nil {
@ -382,7 +412,7 @@ func (a *Application) setDefaults(context *ParseContext) error {
// Check required flags and set defaults. // Check required flags and set defaults.
for _, flag := range context.flags.long { for _, flag := range context.flags.long {
if flagElements[flag.name] == nil { if flagElements[flag.name] == nil {
if err := flag.setDefault(); err != nil { if err := flag.setDefault(context); err != nil {
return err return err
} }
} else { } else {
@ -392,7 +422,7 @@ func (a *Application) setDefaults(context *ParseContext) error {
for _, arg := range context.arguments.args { for _, arg := range context.arguments.args {
if argElements[arg.name] == nil { if argElements[arg.name] == nil {
if err := arg.setDefault(); err != nil { if err := arg.setDefault(context); err != nil {
return err return err
} }
} else { } else {
@ -411,7 +441,7 @@ func (a *Application) validateRequired(context *ParseContext) error {
for _, flag := range context.flags.long { for _, flag := range context.flags.long {
if flagElements[flag.name] == nil { if flagElements[flag.name] == nil {
// Check required flags were provided. // Check required flags were provided.
if flag.needsValue() { if flag.needsValue(context) {
return TError("required flag --{{.Arg0}} not provided", V{"Arg0": flag.name}) return TError("required flag --{{.Arg0}} not provided", V{"Arg0": flag.name})
} }
} }
@ -419,7 +449,7 @@ func (a *Application) validateRequired(context *ParseContext) error {
for _, arg := range context.arguments.args { for _, arg := range context.arguments.args {
if argElements[arg.name] == nil { if argElements[arg.name] == nil {
if arg.needsValue() { if arg.needsValue(context) {
return TError("required argument '{{.Arg0}}' not provided", V{"Arg0": arg.name}) return TError("required argument '{{.Arg0}}' not provided", V{"Arg0": arg.name})
} }
} }
@ -465,7 +495,10 @@ func (a *Application) setValues(context *ParseContext) (selected []string, err e
} }
} }
if lastCmd != nil && len(lastCmd.commands) > 0 { if lastCmd == nil || lastCmd.optionalSubcommands {
return
}
if len(lastCmd.commands) > 0 {
return nil, TError("must select a subcommand of '{{.Arg0}}'", V{"Arg0": lastCmd.FullCommand()}) return nil, TError("must select a subcommand of '{{.Arg0}}'", V{"Arg0": lastCmd.FullCommand()})
} }
@ -642,7 +675,3 @@ func (a *Application) applyActions(context *ParseContext) error {
} }
return nil return nil
} }
func envarTransform(name string) string {
return strings.ToUpper(envarTransformRegexp.ReplaceAllString(name, "_"))
}

View file

@ -2,39 +2,25 @@ package kingpin
import ( import (
"net/url" "net/url"
"os"
"regexp"
"github.com/alecthomas/units" "github.com/alecthomas/units"
) )
var (
envVarValuesSeparator = "\r?\n"
envVarValuesTrimmer = regexp.MustCompile(envVarValuesSeparator + "$")
envVarValuesSplitter = regexp.MustCompile(envVarValuesSeparator)
)
type Settings interface {
SetValue(value Value)
}
// A Clause represents a flag or an argument passed by the user. // A Clause represents a flag or an argument passed by the user.
type Clause struct { type Clause struct {
actionMixin actionMixin
completionsMixin completionsMixin
name string name string
shorthand rune shorthand rune
help string help string
placeholder string placeholder string
hidden bool hidden bool
defaultValues []string defaultValues []string
value Value value Value
required bool required bool
envar string envar string
noEnvar bool noEnvar bool
hintActions []HintAction
builtinHintActions []HintAction
} }
func NewClause(name, help string) *Clause { func NewClause(name, help string) *Clause {
@ -64,6 +50,11 @@ func (c *Clause) init() error {
return nil return nil
} }
func (c *Clause) Help(help string) *Clause {
c.help = help
return c
}
// UsageAction adds a PreAction() that will display the given UsageContext. // UsageAction adds a PreAction() that will display the given UsageContext.
func (c *Clause) UsageAction(context *UsageContext) *Clause { func (c *Clause) UsageAction(context *UsageContext) *Clause {
c.PreAction(func(a *Application, e *ParseElement, c *ParseContext) error { c.PreAction(func(a *Application, e *ParseElement, c *ParseContext) error {
@ -95,13 +86,21 @@ func (c *Clause) HintAction(action HintAction) *Clause {
return c return c
} }
func (c *Clause) addHintAction(action HintAction) { // Envar overrides the default value(s) for a flag from an environment variable,
c.hintActions = append(c.hintActions, action) // if it is set. Several default values can be provided by using new lines to
// separate them.
func (c *Clause) Envar(name string) *Clause {
c.envar = name
c.noEnvar = false
return c
} }
// Allow adding of HintActions which are added internally, ie, EnumVar // NoEnvar forces environment variable defaults to be disabled for this flag.
func (c *Clause) addHintActionBuiltin(action HintAction) { // Most useful in conjunction with PrefixedEnvarResolver.
c.builtinHintActions = append(c.builtinHintActions, action) func (c *Clause) NoEnvar() *Clause {
c.envar = ""
c.noEnvar = true
return c
} }
func (c *Clause) resolveCompletions() []string { func (c *Clause) resolveCompletions() []string {
@ -133,23 +132,6 @@ func (c *Clause) Default(values ...string) *Clause {
return c return c
} }
// Envar overrides the default value(s) for a flag from an environment variable,
// if it is set. Several default values can be provided by using new lines to
// separate them.
func (c *Clause) Envar(name string) *Clause {
c.envar = name
c.noEnvar = false
return c
}
// NoEnvar forces environment variable defaults to be disabled for this flag.
// Most useful in conjunction with app.DefaultEnvars().
func (c *Clause) NoEnvar() *Clause {
c.envar = ""
c.noEnvar = true
return c
}
// PlaceHolder sets the place-holder string used for flag values in the help. The // PlaceHolder sets the place-holder string used for flag values in the help. The
// default behaviour is to use the value provided by Default() if provided, // default behaviour is to use the value provided by Default() if provided,
// then fall back on the capitalized flag name. // then fall back on the capitalized flag name.
@ -176,9 +158,21 @@ func (c *Clause) Short(name rune) *Clause {
return c return c
} }
func (c *Clause) needsValue() bool { func (c *Clause) needsValue(context *ParseContext) bool {
haveDefault := len(c.defaultValues) > 0 return c.required && !c.canResolve(context)
return c.required && !(haveDefault || c.HasEnvarValue()) }
func (c *Clause) canResolve(context *ParseContext) bool {
for _, resolver := range context.resolvers {
rvalues, err := resolver.Resolve(c.name, context)
if err != nil {
return false
}
if rvalues != nil {
return true
}
}
return false
} }
func (c *Clause) reset() { func (c *Clause) reset() {
@ -187,60 +181,32 @@ func (c *Clause) reset() {
} }
} }
func (c *Clause) setDefault() error { func (c *Clause) setDefault(context *ParseContext) error {
if c.HasEnvarValue() { var values []string
c.reset() for _, resolver := range context.resolvers {
if v, ok := c.value.(cumulativeValue); !ok || !v.IsCumulative() { rvalues, err := resolver.Resolve(c.name, context)
// Use the value as-is if err != nil {
return c.value.Set(c.GetEnvarValue()) return err
} }
for _, value := range c.GetSplitEnvarValue() { if rvalues != nil {
values = rvalues
}
}
if values != nil {
c.reset()
for _, value := range values {
if err := c.value.Set(value); err != nil { if err := c.value.Set(value); err != nil {
return err return err
} }
} }
return nil return nil
} else if len(c.defaultValues) > 0 {
c.reset()
for _, defaultValue := range c.defaultValues {
if err := c.value.Set(defaultValue); err != nil {
return err
}
}
return nil
} }
return nil return nil
} }
func (c *Clause) HasEnvarValue() bool {
return c.GetEnvarValue() != ""
}
func (c *Clause) GetEnvarValue() string {
if c.noEnvar || c.envar == "" {
return ""
}
return os.Getenv(c.envar)
}
func (c *Clause) GetSplitEnvarValue() []string {
values := make([]string, 0)
envarValue := c.GetEnvarValue()
if envarValue == "" {
return values
}
// Split by new line to extract multiple values, if any.
trimmed := envVarValuesTrimmer.ReplaceAllString(envarValue, "")
values = append(values, envVarValuesSplitter.Split(trimmed, -1)...)
return values
}
func (c *Clause) SetValue(value Value) { func (c *Clause) SetValue(value Value) {
c.value = value c.value = value
c.setDefault()
} }
// StringMap provides key=value parsing into a map. // StringMap provides key=value parsing into a map.

View file

@ -130,16 +130,6 @@ func newCmdGroup(app *Application) *cmdGroup {
} }
} }
func (c *cmdGroup) flattenedCommands() (out []*CmdClause) {
for _, cmd := range c.commandOrder {
if len(cmd.commands) == 0 {
out = append(out, cmd)
}
out = append(out, cmd.flattenedCommands()...)
}
return
}
func (c *cmdGroup) addCommand(name, help string) *CmdClause { func (c *cmdGroup) addCommand(name, help string) *CmdClause {
cmd := newCommand(c.app, name, help) cmd := newCommand(c.app, name, help)
c.commands[name] = cmd c.commands[name] = cmd
@ -187,14 +177,15 @@ type CmdClauseValidator func(*CmdClause) error
// and either subcommands or positional arguments. // and either subcommands or positional arguments.
type CmdClause struct { type CmdClause struct {
cmdMixin cmdMixin
app *Application app *Application
name string name string
aliases []string aliases []string
help string help string
isDefault bool isDefault bool
validator CmdClauseValidator validator CmdClauseValidator
hidden bool hidden bool
completionAlts []string completionAlts []string
optionalSubcommands bool
} }
func newCommand(app *Application, name, help string) *CmdClause { func newCommand(app *Application, name, help string) *CmdClause {
@ -271,6 +262,12 @@ func (c *CmdClause) Command(name, help string) *CmdClause {
return cmd return cmd
} }
// OptionalSubcommands makes subcommands optional
func (c *CmdClause) OptionalSubcommands() *CmdClause {
c.optionalSubcommands = true
return c
}
// Default makes this command the default if commands don't match. // Default makes this command the default if commands don't match.
func (c *CmdClause) Default() *CmdClause { func (c *CmdClause) Default() *CmdClause {
c.isDefault = true c.isDefault = true
@ -302,7 +299,7 @@ func (c *cmdMixin) checkArgCommandMixing() error {
} }
func (c *CmdClause) init() error { func (c *CmdClause) init() error {
if err := c.flagGroup.init(c.app.defaultEnvarPrefix()); err != nil { if err := c.flagGroup.init(); err != nil {
return err return err
} }
if err := c.checkArgCommandMixing(); err != nil { if err := c.checkArgCommandMixing(); err != nil {

View file

@ -105,7 +105,7 @@ func main() {
if v.Format != "" { if v.Format != "" {
return v.Format return v.Format
} }
return "fmt.Sprintf(\"%v\", *f)" return "fmt.Sprintf(\"%v\", *f.v)"
}, },
"ValueName": func(v *Value) string { "ValueName": func(v *Value) string {
name := valueName(v) name := valueName(v)

View file

@ -31,14 +31,11 @@ func (f *flagGroup) Flag(name, help string) *Clause {
return flag return flag
} }
func (f *flagGroup) init(defaultEnvarPrefix string) error { func (f *flagGroup) init() error {
if err := f.checkDuplicates(); err != nil { if err := f.checkDuplicates(); err != nil {
return err return err
} }
for _, flag := range f.long { for _, flag := range f.long {
if defaultEnvarPrefix != "" && !flag.noEnvar && flag.envar == "" {
flag.envar = envarTransform(defaultEnvarPrefix + "_" + flag.name)
}
if err := flag.init(); err != nil { if err := flag.init(); err != nil {
return err return err
} }

View file

@ -13,6 +13,15 @@ type FlagGroupModel struct {
Flags []*ClauseModel Flags []*ClauseModel
} }
func (f *FlagGroupModel) FlagByName(name string) *ClauseModel {
for _, flag := range f.Flags {
if flag.Name == name {
return flag
}
}
return nil
}
func (f *FlagGroupModel) FlagSummary() string { func (f *FlagGroupModel) FlagSummary() string {
out := []string{} out := []string{}
count := 0 count := 0
@ -39,7 +48,6 @@ type ClauseModel struct {
Help string Help string
Short rune Short rune
Default []string Default []string
Envar string
PlaceHolder string PlaceHolder string
Required bool Required bool
Hidden bool Hidden bool
@ -106,6 +114,9 @@ type CmdGroupModel struct {
func (c *CmdGroupModel) FlattenedCommands() (out []*CmdModel) { func (c *CmdGroupModel) FlattenedCommands() (out []*CmdModel) {
for _, cmd := range c.Commands { for _, cmd := range c.Commands {
if cmd.OptionalSubcommands {
out = append(out, cmd)
}
if len(cmd.Commands) == 0 { if len(cmd.Commands) == 0 {
out = append(out, cmd) out = append(out, cmd)
} }
@ -115,13 +126,14 @@ func (c *CmdGroupModel) FlattenedCommands() (out []*CmdModel) {
} }
type CmdModel struct { type CmdModel struct {
Name string Name string
Aliases []string Aliases []string
Help string Help string
Depth int Depth int
Hidden bool Hidden bool
Default bool Default bool
Parent *CmdModel OptionalSubcommands bool
Parent *CmdModel
*FlagGroupModel *FlagGroupModel
*ArgGroupModel *ArgGroupModel
*CmdGroupModel *CmdGroupModel
@ -242,7 +254,6 @@ func (f *Clause) Model() *ClauseModel {
Help: f.help, Help: f.help,
Short: f.shorthand, Short: f.shorthand,
Default: f.defaultValues, Default: f.defaultValues,
Envar: f.envar,
PlaceHolder: f.placeholder, PlaceHolder: f.placeholder,
Required: f.required, Required: f.required,
Hidden: f.hidden, Hidden: f.hidden,
@ -265,15 +276,16 @@ func (c *CmdClause) Model(parent *CmdModel) *CmdModel {
depth++ depth++
} }
cmd := &CmdModel{ cmd := &CmdModel{
Name: c.name, Name: c.name,
Parent: parent, Parent: parent,
Aliases: c.aliases, Aliases: c.aliases,
Help: c.help, Help: c.help,
Depth: depth, Depth: depth,
Hidden: c.hidden, Hidden: c.hidden,
Default: c.isDefault, Default: c.isDefault,
FlagGroupModel: c.flagGroup.Model(), OptionalSubcommands: c.optionalSubcommands,
ArgGroupModel: c.argGroup.Model(), FlagGroupModel: c.flagGroup.Model(),
ArgGroupModel: c.argGroup.Model(),
} }
cmd.CmdGroupModel = c.cmdGroup.Model(cmd) cmd.CmdGroupModel = c.cmdGroup.Model(cmd)
return cmd return cmd

View file

@ -119,6 +119,7 @@ func (p ParseElements) ArgMap() map[string]*ParseElement {
// any). // any).
type ParseContext struct { type ParseContext struct {
SelectedCommand *CmdClause SelectedCommand *CmdClause
resolvers []Resolver
ignoreDefault bool ignoreDefault bool
argsOnly bool argsOnly bool
peek []*Token peek []*Token
@ -132,6 +133,34 @@ type ParseContext struct {
Elements ParseElements Elements ParseElements
} }
func (p *ParseContext) CombinedFlagsAndArgs() []*Clause {
return append(p.Args(), p.Flags()...)
}
func (p *ParseContext) Args() []*Clause {
return p.arguments.args
}
func (p *ParseContext) Flags() []*Clause {
return p.flags.flagOrder
}
// LastCmd returns true if the element is the last (sub)command being evaluated.
func (p *ParseContext) LastCmd(element *ParseElement) bool {
lastCmdIndex := -1
eIndex := -2
for i, e := range p.Elements {
if element == e {
eIndex = i
}
if e.OneOf.Cmd != nil {
lastCmdIndex = i
}
}
return lastCmdIndex == eIndex
}
func (p *ParseContext) nextArg() *Clause { func (p *ParseContext) nextArg() *Clause {
if p.argumenti >= len(p.arguments.args) { if p.argumenti >= len(p.arguments.args) {
return nil return nil
@ -154,13 +183,14 @@ func (p *ParseContext) HasTrailingArgs() bool {
return len(p.args) > 0 return len(p.args) > 0
} }
func tokenize(args []string, ignoreDefault bool) *ParseContext { func tokenize(args []string, ignoreDefault bool, resolvers []Resolver) *ParseContext {
return &ParseContext{ return &ParseContext{
ignoreDefault: ignoreDefault, ignoreDefault: ignoreDefault,
args: args, args: args,
rawArgs: args, rawArgs: args,
flags: newFlagGroup(), flags: newFlagGroup(),
arguments: newArgGroup(), arguments: newArgGroup(),
resolvers: resolvers,
} }
} }

View file

@ -0,0 +1,157 @@
package kingpin
import (
"encoding/json"
"fmt"
"os"
"regexp"
"strings"
)
var (
envarTransformRegexp = regexp.MustCompile(`[^a-zA-Z_]+`)
)
// A Resolver retrieves flag values from an external source, such as a configuration file or environment variables.
type Resolver interface {
// Resolve key in the given parse context.
//
// A nil slice should be returned if the key can not be resolved.
Resolve(key string, context *ParseContext) ([]string, error)
}
// ResolverFunc is a function that is also a Resolver.
type ResolverFunc func(key string, context *ParseContext) ([]string, error)
func (r ResolverFunc) Resolve(key string, context *ParseContext) ([]string, error) {
return r(key, context)
}
// A resolver that pulls values from the flag defaults. This resolver is always installed in the ParseContext.
func defaultsResolver() Resolver {
return ResolverFunc(func(key string, context *ParseContext) ([]string, error) {
for _, clause := range context.CombinedFlagsAndArgs() {
if clause.name == key {
return clause.defaultValues, nil
}
}
return nil, nil
})
}
func parseEnvar(envar, sep string) []string {
value, ok := os.LookupEnv(envar)
if !ok {
return nil
}
if sep == "" {
return []string{value}
}
return strings.Split(value, sep)
}
// Resolves a clause value from the envar configured on that clause, if any.
func envarResolver(sep string) Resolver {
return ResolverFunc(func(key string, context *ParseContext) ([]string, error) {
for _, clause := range context.CombinedFlagsAndArgs() {
if key == clause.name {
if clause.noEnvar || clause.envar == "" {
return nil, nil
}
return parseEnvar(clause.envar, sep), nil
}
}
return nil, nil
})
}
// MapResolver resolves values from a static map.
func MapResolver(values map[string][]string) Resolver {
return ResolverFunc(func(key string, context *ParseContext) ([]string, error) {
return values[key], nil
})
}
// JSONResolver returns a Resolver that retrieves values from a JSON source.
func JSONResolver(data []byte) (Resolver, error) {
values := map[string]interface{}{}
err := json.Unmarshal(data, &values)
if err != nil {
return nil, err
}
mapping := map[string][]string{}
for key, value := range values {
sub, err := jsonDecodeValue(value)
if err != nil {
return nil, err
}
mapping[key] = sub
}
return MapResolver(mapping), nil
}
func jsonDecodeValue(value interface{}) ([]string, error) {
switch v := value.(type) {
case []interface{}:
out := []string{}
for _, sv := range v {
next, err := jsonDecodeValue(sv)
if err != nil {
return nil, err
}
out = append(out, next...)
}
return out, nil
case string:
return []string{v}, nil
case float64:
return []string{fmt.Sprintf("%v", v)}, nil
case bool:
if v {
return []string{"true"}, nil
}
return []string{"false"}, nil
}
return nil, fmt.Errorf("unsupported JSON value %v (of type %T)", value, value)
}
// RenamingResolver creates a resolver for remapping names for a child resolver.
//
// This is useful if your configuration file uses a naming convention that does not map directly to
// flag names.
func RenamingResolver(resolver Resolver, rename func(string) string) Resolver {
return ResolverFunc(func(key string, context *ParseContext) ([]string, error) {
return resolver.Resolve(rename(key), context)
})
}
// PrefixedEnvarResolver resolves any flag/argument via environment variables.
//
// "prefix" is the common-prefix for the environment variables. "separator", is the character used to separate
// multiple values within a single envar (eg. ";")
//
// With a prefix of APP_, flags in the form --some-flag will be transformed to APP_SOME_FLAG.
func PrefixedEnvarResolver(prefix, separator string) Resolver {
return ResolverFunc(func(key string, context *ParseContext) ([]string, error) {
key = envarTransform(prefix + key)
return parseEnvar(key, separator), nil
})
}
// DontResolve returns a Resolver that will never return values for the given keys, even if provided.
func DontResolve(resolver Resolver, keys ...string) Resolver {
disabled := map[string]bool{}
for _, key := range keys {
disabled[key] = true
}
return ResolverFunc(func(key string, context *ParseContext) ([]string, error) {
if disabled[key] {
return nil, nil
}
return resolver.Resolve(key, context)
})
}
func envarTransform(name string) string {
return strings.ToUpper(envarTransformRegexp.ReplaceAllString(name, "_"))
}

View file

@ -174,7 +174,7 @@ func (c *cmdMixin) fromStruct(clause *CmdClause, v interface{}) error { // nolin
case reflect.Uint: case reflect.Uint:
clause.UintsVar(ptr.(*[]uint)) clause.UintsVar(ptr.(*[]uint))
case reflect.Uint8: case reflect.Uint8:
clause.Uint8ListVar(ptr.(*[]uint8)) clause.HexBytesVar(ptr.(*[]byte))
case reflect.Uint16: case reflect.Uint16:
clause.Uint16ListVar(ptr.(*[]uint16)) clause.Uint16ListVar(ptr.(*[]uint16))
case reflect.Uint32: case reflect.Uint32:

View file

@ -27,7 +27,7 @@ func (f *boolValue) Set(s string) error {
func (f *boolValue) Get() interface{} { return (bool)(*f.v) } func (f *boolValue) Get() interface{} { return (bool)(*f.v) }
func (f *boolValue) String() string { return fmt.Sprintf("%v", *f) } func (f *boolValue) String() string { return fmt.Sprintf("%v", *f.v) }
// Bool parses the next command-line value as bool. // Bool parses the next command-line value as bool.
func (p *Clause) Bool() (target *bool) { func (p *Clause) Bool() (target *bool) {
@ -113,7 +113,7 @@ func (f *uintValue) Set(s string) error {
func (f *uintValue) Get() interface{} { return (uint)(*f.v) } func (f *uintValue) Get() interface{} { return (uint)(*f.v) }
func (f *uintValue) String() string { return fmt.Sprintf("%v", *f) } func (f *uintValue) String() string { return fmt.Sprintf("%v", *f.v) }
// Uint parses the next command-line value as uint. // Uint parses the next command-line value as uint.
func (p *Clause) Uint() (target *uint) { func (p *Clause) Uint() (target *uint) {
@ -156,7 +156,7 @@ func (f *uint8Value) Set(s string) error {
func (f *uint8Value) Get() interface{} { return (uint8)(*f.v) } func (f *uint8Value) Get() interface{} { return (uint8)(*f.v) }
func (f *uint8Value) String() string { return fmt.Sprintf("%v", *f) } func (f *uint8Value) String() string { return fmt.Sprintf("%v", *f.v) }
// Uint8 parses the next command-line value as uint8. // Uint8 parses the next command-line value as uint8.
func (p *Clause) Uint8() (target *uint8) { func (p *Clause) Uint8() (target *uint8) {
@ -199,7 +199,7 @@ func (f *uint16Value) Set(s string) error {
func (f *uint16Value) Get() interface{} { return (uint16)(*f.v) } func (f *uint16Value) Get() interface{} { return (uint16)(*f.v) }
func (f *uint16Value) String() string { return fmt.Sprintf("%v", *f) } func (f *uint16Value) String() string { return fmt.Sprintf("%v", *f.v) }
// Uint16 parses the next command-line value as uint16. // Uint16 parses the next command-line value as uint16.
func (p *Clause) Uint16() (target *uint16) { func (p *Clause) Uint16() (target *uint16) {
@ -242,7 +242,7 @@ func (f *uint32Value) Set(s string) error {
func (f *uint32Value) Get() interface{} { return (uint32)(*f.v) } func (f *uint32Value) Get() interface{} { return (uint32)(*f.v) }
func (f *uint32Value) String() string { return fmt.Sprintf("%v", *f) } func (f *uint32Value) String() string { return fmt.Sprintf("%v", *f.v) }
// Uint32 parses the next command-line value as uint32. // Uint32 parses the next command-line value as uint32.
func (p *Clause) Uint32() (target *uint32) { func (p *Clause) Uint32() (target *uint32) {
@ -285,7 +285,7 @@ func (f *uint64Value) Set(s string) error {
func (f *uint64Value) Get() interface{} { return (uint64)(*f.v) } func (f *uint64Value) Get() interface{} { return (uint64)(*f.v) }
func (f *uint64Value) String() string { return fmt.Sprintf("%v", *f) } func (f *uint64Value) String() string { return fmt.Sprintf("%v", *f.v) }
// Uint64 parses the next command-line value as uint64. // Uint64 parses the next command-line value as uint64.
func (p *Clause) Uint64() (target *uint64) { func (p *Clause) Uint64() (target *uint64) {
@ -328,7 +328,7 @@ func (f *intValue) Set(s string) error {
func (f *intValue) Get() interface{} { return (int)(*f.v) } func (f *intValue) Get() interface{} { return (int)(*f.v) }
func (f *intValue) String() string { return fmt.Sprintf("%v", *f) } func (f *intValue) String() string { return fmt.Sprintf("%v", *f.v) }
// Int parses the next command-line value as int. // Int parses the next command-line value as int.
func (p *Clause) Int() (target *int) { func (p *Clause) Int() (target *int) {
@ -371,7 +371,7 @@ func (f *int8Value) Set(s string) error {
func (f *int8Value) Get() interface{} { return (int8)(*f.v) } func (f *int8Value) Get() interface{} { return (int8)(*f.v) }
func (f *int8Value) String() string { return fmt.Sprintf("%v", *f) } func (f *int8Value) String() string { return fmt.Sprintf("%v", *f.v) }
// Int8 parses the next command-line value as int8. // Int8 parses the next command-line value as int8.
func (p *Clause) Int8() (target *int8) { func (p *Clause) Int8() (target *int8) {
@ -414,7 +414,7 @@ func (f *int16Value) Set(s string) error {
func (f *int16Value) Get() interface{} { return (int16)(*f.v) } func (f *int16Value) Get() interface{} { return (int16)(*f.v) }
func (f *int16Value) String() string { return fmt.Sprintf("%v", *f) } func (f *int16Value) String() string { return fmt.Sprintf("%v", *f.v) }
// Int16 parses the next command-line value as int16. // Int16 parses the next command-line value as int16.
func (p *Clause) Int16() (target *int16) { func (p *Clause) Int16() (target *int16) {
@ -457,7 +457,7 @@ func (f *int32Value) Set(s string) error {
func (f *int32Value) Get() interface{} { return (int32)(*f.v) } func (f *int32Value) Get() interface{} { return (int32)(*f.v) }
func (f *int32Value) String() string { return fmt.Sprintf("%v", *f) } func (f *int32Value) String() string { return fmt.Sprintf("%v", *f.v) }
// Int32 parses the next command-line value as int32. // Int32 parses the next command-line value as int32.
func (p *Clause) Int32() (target *int32) { func (p *Clause) Int32() (target *int32) {
@ -500,7 +500,7 @@ func (f *int64Value) Set(s string) error {
func (f *int64Value) Get() interface{} { return (int64)(*f.v) } func (f *int64Value) Get() interface{} { return (int64)(*f.v) }
func (f *int64Value) String() string { return fmt.Sprintf("%v", *f) } func (f *int64Value) String() string { return fmt.Sprintf("%v", *f.v) }
// Int64 parses the next command-line value as int64. // Int64 parses the next command-line value as int64.
func (p *Clause) Int64() (target *int64) { func (p *Clause) Int64() (target *int64) {
@ -543,7 +543,7 @@ func (f *float64Value) Set(s string) error {
func (f *float64Value) Get() interface{} { return (float64)(*f.v) } func (f *float64Value) Get() interface{} { return (float64)(*f.v) }
func (f *float64Value) String() string { return fmt.Sprintf("%v", *f) } func (f *float64Value) String() string { return fmt.Sprintf("%v", *f.v) }
// Float64 parses the next command-line value as float64. // Float64 parses the next command-line value as float64.
func (p *Clause) Float64() (target *float64) { func (p *Clause) Float64() (target *float64) {
@ -586,7 +586,7 @@ func (f *float32Value) Set(s string) error {
func (f *float32Value) Get() interface{} { return (float32)(*f.v) } func (f *float32Value) Get() interface{} { return (float32)(*f.v) }
func (f *float32Value) String() string { return fmt.Sprintf("%v", *f) } func (f *float32Value) String() string { return fmt.Sprintf("%v", *f.v) }
// Float32 parses the next command-line value as float32. // Float32 parses the next command-line value as float32.
func (p *Clause) Float32() (target *float32) { func (p *Clause) Float32() (target *float32) {
@ -668,7 +668,7 @@ func (f *regexpValue) Set(s string) error {
func (f *regexpValue) Get() interface{} { return (*regexp.Regexp)(*f.v) } func (f *regexpValue) Get() interface{} { return (*regexp.Regexp)(*f.v) }
func (f *regexpValue) String() string { return fmt.Sprintf("%v", *f) } func (f *regexpValue) String() string { return fmt.Sprintf("%v", *f.v) }
// Regexp parses the next command-line value as *regexp.Regexp. // Regexp parses the next command-line value as *regexp.Regexp.
func (p *Clause) Regexp() (target **regexp.Regexp) { func (p *Clause) Regexp() (target **regexp.Regexp) {
@ -711,7 +711,7 @@ func (f *hexBytesValue) Set(s string) error {
func (f *hexBytesValue) Get() interface{} { return ([]byte)(*f.v) } func (f *hexBytesValue) Get() interface{} { return ([]byte)(*f.v) }
func (f *hexBytesValue) String() string { return fmt.Sprintf("%v", *f) } func (f *hexBytesValue) String() string { return fmt.Sprintf("%v", *f.v) }
// Bytes as a hex string. // Bytes as a hex string.
func (p *Clause) HexBytes() (target *[]byte) { func (p *Clause) HexBytes() (target *[]byte) {
@ -754,7 +754,7 @@ func (f *durationValue) Set(s string) error {
func (f *durationValue) Get() interface{} { return (time.Duration)(*f.v) } func (f *durationValue) Get() interface{} { return (time.Duration)(*f.v) }
func (f *durationValue) String() string { return fmt.Sprintf("%v", *f) } func (f *durationValue) String() string { return fmt.Sprintf("%v", *f.v) }
// Time duration. // Time duration.
func (p *Clause) Duration() (target *time.Duration) { func (p *Clause) Duration() (target *time.Duration) {

View file

@ -30,7 +30,7 @@
"importpath": "gopkg.in/alecthomas/kingpin.v3-unstable", "importpath": "gopkg.in/alecthomas/kingpin.v3-unstable",
"repository": "https://gopkg.in/alecthomas/kingpin.v3-unstable", "repository": "https://gopkg.in/alecthomas/kingpin.v3-unstable",
"vcs": "git", "vcs": "git",
"revision": "9670b87a702e049784340892f0d31a46473041c7", "revision": "63abe20a23e29e80bbef8089bd3dee3ac25e5306",
"branch": "v3-unstable", "branch": "v3-unstable",
"notests": true "notests": true
}, },

View file

@ -53,6 +53,12 @@ func (b64 Base64String) MarshalJSON() ([]byte, error) {
return json.Marshal(b64.Encode()) return json.Marshal(b64.Encode())
} }
// MarshalYAML implements yaml.Marshaller
// It just encodes the bytes as base64, which is a valid YAML string
func (b64 Base64String) MarshalYAML() (interface{}, error) {
return b64.Encode(), nil
}
// UnmarshalJSON decodes a JSON string and then decodes the resulting base64. // UnmarshalJSON decodes a JSON string and then decodes the resulting base64.
// This takes a pointer receiver because it needs to write the result of decoding. // This takes a pointer receiver because it needs to write the result of decoding.
func (b64 *Base64String) UnmarshalJSON(raw []byte) (err error) { func (b64 *Base64String) UnmarshalJSON(raw []byte) (err error) {
@ -65,3 +71,14 @@ func (b64 *Base64String) UnmarshalJSON(raw []byte) (err error) {
err = b64.Decode(str) err = b64.Decode(str)
return return
} }
// UnmarshalYAML implements yaml.Unmarshaller
// it unmarshals the input as a yaml string and then base64-decodes the result
func (b64 *Base64String) UnmarshalYAML(unmarshal func(interface{}) error) (err error) {
var str string
if err = unmarshal(&str); err != nil {
return
}
err = b64.Decode(str)
return
}

View file

@ -18,6 +18,8 @@ package gomatrixserverlib
import ( import (
"encoding/json" "encoding/json"
"testing" "testing"
"gopkg.in/yaml.v2"
) )
func TestMarshalBase64(t *testing.T) { func TestMarshalBase64(t *testing.T) {
@ -93,3 +95,58 @@ func TestMarshalBase64Slice(t *testing.T) {
t.Fatalf("json.Marshal(%v): wanted %q got %q", input, want, string(got)) t.Fatalf("json.Marshal(%v): wanted %q got %q", input, want, string(got))
} }
} }
func TestMarshalYAMLBase64(t *testing.T) {
input := Base64String("this\xffis\xffa\xfftest")
want := "dGhpc/9pc/9h/3Rlc3Q\n"
got, err := yaml.Marshal(input)
if err != nil {
t.Fatal(err)
}
if string(got) != want {
t.Fatalf("yaml.Marshal(%v): wanted %q got %q", input, want, string(got))
}
}
func TestMarshalYAMLBase64Struct(t *testing.T) {
input := struct{ Value Base64String }{Base64String("this\xffis\xffa\xfftest")}
want := "value: dGhpc/9pc/9h/3Rlc3Q\n"
got, err := yaml.Marshal(input)
if err != nil {
t.Fatal(err)
}
if string(got) != want {
t.Fatalf("yaml.Marshal(%v): wanted %q got %q", input, want, string(got))
}
}
func TestUnmarshalYAMLBase64(t *testing.T) {
input := []byte("dGhpc/9pc/9h/3Rlc3Q")
want := Base64String("this\xffis\xffa\xfftest")
var got Base64String
err := yaml.Unmarshal(input, &got)
if err != nil {
t.Fatal(err)
}
if string(got) != string(want) {
t.Fatalf("yaml.Unmarshal(%q): wanted %q got %q", string(input), want, string(got))
}
}
func TestUnmarshalYAMLBase64Struct(t *testing.T) {
// var u yaml.Unmarshaler
u := Base64String("this\xffis\xffa\xfftest")
input := []byte(`value: dGhpc/9pc/9h/3Rlc3Q`)
want := struct{ Value Base64String }{u}
result := struct {
Value Base64String `yaml:"value"`
}{}
err := yaml.Unmarshal(input, &result)
if err != nil {
t.Fatal(err)
}
if string(result.Value) != string(want.Value) {
t.Fatalf("yaml.Unmarshal(%v): wanted %q got %q", input, want, result)
}
}

View file

@ -80,7 +80,7 @@ func newFederationTripper() *federationTripper {
ServerName: "", ServerName: "",
// TODO: We should be checking that the TLS certificate we see here matches // TODO: We should be checking that the TLS certificate we see here matches
// one of the allowed SHA-256 fingerprints for the server. // one of the allowed SHA-256 fingerprints for the server.
InsecureSkipVerify: true, InsecureSkipVerify: true, // nolint: gas
}) })
if err := conn.Handshake(); err != nil { if err := conn.Handshake(); err != nil {
return nil, err return nil, err

View file

@ -22,6 +22,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson" "github.com/tidwall/sjson"
"golang.org/x/crypto/ed25519" "golang.org/x/crypto/ed25519"
) )
@ -306,6 +307,33 @@ func (e Event) SetUnsigned(unsigned interface{}) (Event, error) {
return result, nil return result, nil
} }
// SetUnsignedField takes a path and value to insert into the unsigned dict of
// the event.
// path is a dot separated path into the unsigned dict (see gjson package
// for details on format). In particular some characters like '.' and '*' must
// be escaped.
func (e *Event) SetUnsignedField(path string, value interface{}) error {
// The safest way is to change the unsigned json and then reparse the
// event fully. But since we are only changing the unsigned section,
// which doesn't affect the signatures or hashes, we can cheat and
// just fiddle those bits directly.
path = "unsigned." + path
eventJSON, err := sjson.SetBytes(e.eventJSON, path, value)
if err != nil {
return err
}
eventJSON = CanonicalJSONAssumeValid(eventJSON)
res := gjson.GetBytes(eventJSON, "unsigned")
unsigned := rawJSONFromResult(res, eventJSON)
e.eventJSON = eventJSON
e.fields.Unsigned = unsigned
return nil
}
// EventReference returns an EventReference for the event. // EventReference returns an EventReference for the event.
// The reference can be used to refer to this event from other events. // The reference can be used to refer to this event from other events.
func (e Event) EventReference() EventReference { func (e Event) EventReference() EventReference {

View file

@ -50,3 +50,32 @@ func BenchmarkParseSmallerEventFailedHash(b *testing.B) {
func BenchmarkParseSmallerEventRedacted(b *testing.B) { func BenchmarkParseSmallerEventRedacted(b *testing.B) {
benchmarkParse(b, `{"event_id":"$yvN1b43rlmcOs5fY:localhost","sender":"@test:localhost","room_id":"!19Mp0U9hjajeIiw1:localhost","hashes":{"sha256":"Oh1mwI1jEqZ3tgJ+V1Dmu5nOEGpCE4RFUqyJv2gQXKs"},"signatures":{"localhost":{"ed25519:u9kP":"5IzSuRXkxvbTp0vZhhXYZeOe+619iG3AybJXr7zfNn/4vHz4TH7qSJVQXSaHHvcTcDodAKHnTG1WDulgO5okAQ"}},"content":{},"type":"m.room.name","state_key":"","depth":7,"prev_events":[["$FqI6TVvWpcbcnJ97:localhost",{"sha256":"upCsBqUhNUgT2/+zkzg8TbqdQpWWKQnZpGJc6KcbUC4"}]],"prev_state":[],"auth_events":[["$oXL79cT7fFxR7dPH:localhost",{"sha256":"abjkiDSg1RkuZrbj2jZoGMlQaaj1Ue3Jhi7I7NlKfXY"}],["$IVUsaSkm1LBAZYYh:localhost",{"sha256":"X7RUj46hM/8sUHNBIFkStbOauPvbDzjSdH4NibYWnko"}],["$VS2QT0EeArZYi8wf:localhost",{"sha256":"k9eM6utkCH8vhLW9/oRsH74jOBS/6RVK42iGDFbylno"}]],"origin":"localhost","origin_server_ts":1510854416361}`) benchmarkParse(b, `{"event_id":"$yvN1b43rlmcOs5fY:localhost","sender":"@test:localhost","room_id":"!19Mp0U9hjajeIiw1:localhost","hashes":{"sha256":"Oh1mwI1jEqZ3tgJ+V1Dmu5nOEGpCE4RFUqyJv2gQXKs"},"signatures":{"localhost":{"ed25519:u9kP":"5IzSuRXkxvbTp0vZhhXYZeOe+619iG3AybJXr7zfNn/4vHz4TH7qSJVQXSaHHvcTcDodAKHnTG1WDulgO5okAQ"}},"content":{},"type":"m.room.name","state_key":"","depth":7,"prev_events":[["$FqI6TVvWpcbcnJ97:localhost",{"sha256":"upCsBqUhNUgT2/+zkzg8TbqdQpWWKQnZpGJc6KcbUC4"}]],"prev_state":[],"auth_events":[["$oXL79cT7fFxR7dPH:localhost",{"sha256":"abjkiDSg1RkuZrbj2jZoGMlQaaj1Ue3Jhi7I7NlKfXY"}],["$IVUsaSkm1LBAZYYh:localhost",{"sha256":"X7RUj46hM/8sUHNBIFkStbOauPvbDzjSdH4NibYWnko"}],["$VS2QT0EeArZYi8wf:localhost",{"sha256":"k9eM6utkCH8vhLW9/oRsH74jOBS/6RVK42iGDFbylno"}]],"origin":"localhost","origin_server_ts":1510854416361}`)
} }
func TestAddUnsignedField(t *testing.T) {
initialEventJSON := `{"auth_events":[["$oXL79cT7fFxR7dPH:localhost",{"sha256":"abjkiDSg1RkuZrbj2jZoGMlQaaj1Ue3Jhi7I7NlKfXY"}],["$IVUsaSkm1LBAZYYh:localhost",{"sha256":"X7RUj46hM/8sUHNBIFkStbOauPvbDzjSdH4NibYWnko"}],["$VS2QT0EeArZYi8wf:localhost",{"sha256":"k9eM6utkCH8vhLW9/oRsH74jOBS/6RVK42iGDFbylno"}]],"content":{"name":"test3"},"depth":7,"event_id":"$yvN1b43rlmcOs5fY:localhost","hashes":{"sha256":"Oh1mwI1jEqZ3tgJ+V1Dmu5nOEGpCE4RFUqyJv2gQXKs"},"origin":"localhost","origin_server_ts":1510854416361,"prev_events":[["$FqI6TVvWpcbcnJ97:localhost",{"sha256":"upCsBqUhNUgT2/+zkzg8TbqdQpWWKQnZpGJc6KcbUC4"}]],"prev_state":[],"room_id":"!19Mp0U9hjajeIiw1:localhost","sender":"@test:localhost","signatures":{"localhost":{"ed25519:u9kP":"5IzSuRXkxvbTp0vZhhXYZeOe+619iG3AybJXr7zfNn/4vHz4TH7qSJVQXSaHHvcTcDodAKHnTG1WDulgO5okAQ"}},"state_key":"","type":"m.room.name"}`
expectedEventJSON := `{"auth_events":[["$oXL79cT7fFxR7dPH:localhost",{"sha256":"abjkiDSg1RkuZrbj2jZoGMlQaaj1Ue3Jhi7I7NlKfXY"}],["$IVUsaSkm1LBAZYYh:localhost",{"sha256":"X7RUj46hM/8sUHNBIFkStbOauPvbDzjSdH4NibYWnko"}],["$VS2QT0EeArZYi8wf:localhost",{"sha256":"k9eM6utkCH8vhLW9/oRsH74jOBS/6RVK42iGDFbylno"}]],"content":{"name":"test3"},"depth":7,"event_id":"$yvN1b43rlmcOs5fY:localhost","hashes":{"sha256":"Oh1mwI1jEqZ3tgJ+V1Dmu5nOEGpCE4RFUqyJv2gQXKs"},"origin":"localhost","origin_server_ts":1510854416361,"prev_events":[["$FqI6TVvWpcbcnJ97:localhost",{"sha256":"upCsBqUhNUgT2/+zkzg8TbqdQpWWKQnZpGJc6KcbUC4"}]],"prev_state":[],"room_id":"!19Mp0U9hjajeIiw1:localhost","sender":"@test:localhost","signatures":{"localhost":{"ed25519:u9kP":"5IzSuRXkxvbTp0vZhhXYZeOe+619iG3AybJXr7zfNn/4vHz4TH7qSJVQXSaHHvcTcDodAKHnTG1WDulgO5okAQ"}},"state_key":"","type":"m.room.name","unsigned":{"foo":"bar","x":1}}`
var event Event
if err := json.Unmarshal([]byte(initialEventJSON), &event); err != nil {
t.Error("Failed to parse event")
}
err := event.SetUnsignedField("foo", "bar")
if err != nil {
t.Error("Failed to insert foo")
}
err = event.SetUnsignedField("x", 1)
if err != nil {
t.Error("Failed to insert x")
}
bytes, err := json.Marshal(event)
if err != nil {
t.Error("Failed to marshal x")
}
if expectedEventJSON != string(bytes) {
t.Fatalf("Serialized event does not match expected: %s != %s", string(bytes), initialEventJSON)
}
}

View file

@ -179,12 +179,21 @@ func verifyEventSignature(signingName string, keyID KeyID, publicKey ed25519.Pub
// VerifyEventSignatures checks that each event in a list of events has valid // VerifyEventSignatures checks that each event in a list of events has valid
// signatures from the server that sent it. // signatures from the server that sent it.
func VerifyEventSignatures(ctx context.Context, events []Event, keyRing JSONVerifier) error { // nolint: gocyclo //
var toVerify []VerifyJSONRequest // returns an array with either an error or nil for each event.
for _, event := range events { func VerifyEventSignatures(ctx context.Context, events []Event, keyRing JSONVerifier) ([]error, error) { // nolint: gocyclo
// we will end up doing at least as many verifications as we have events.
// some events require multiple verifications, as they are signed by multiple
// servers.
toVerify := make([]VerifyJSONRequest, 0, len(events))
// for each entry in 'events', a list of corresponding indexes in toVerify
verificationMap := make([][]int, len(events))
for evtIdx, event := range events {
redactedJSON, err := redactEvent(event.eventJSON) redactedJSON, err := redactEvent(event.eventJSON)
if err != nil { if err != nil {
return err return nil, err
} }
domains := make(map[ServerName]bool) domains := make(map[ServerName]bool)
@ -203,7 +212,7 @@ func VerifyEventSignatures(ctx context.Context, events []Event, keyRing JSONVeri
// //
senderDomain, err := domainFromID(event.Sender()) senderDomain, err := domainFromID(event.Sender())
if err != nil { if err != nil {
return err return nil, err
} }
domains[ServerName(senderDomain)] = true domains[ServerName(senderDomain)] = true
@ -212,12 +221,12 @@ func VerifyEventSignatures(ctx context.Context, events []Event, keyRing JSONVeri
if event.Type() == MRoomMember && event.StateKey() != nil { if event.Type() == MRoomMember && event.StateKey() != nil {
targetDomain, err := domainFromID(*event.StateKey()) targetDomain, err := domainFromID(*event.StateKey())
if err != nil { if err != nil {
return err return nil, err
} }
if ServerName(targetDomain) != event.Origin() { if ServerName(targetDomain) != event.Origin() {
c, err := newMemberContentFromEvent(event) c, err := newMemberContentFromEvent(event)
if err != nil { if err != nil {
return err return nil, err
} }
if c.Membership == invite { if c.Membership == invite {
domains[ServerName(targetDomain)] = true domains[ServerName(targetDomain)] = true
@ -231,22 +240,45 @@ func VerifyEventSignatures(ctx context.Context, events []Event, keyRing JSONVeri
AtTS: event.OriginServerTS(), AtTS: event.OriginServerTS(),
ServerName: domain, ServerName: domain,
} }
verificationMap[evtIdx] = append(verificationMap[evtIdx], len(toVerify))
toVerify = append(toVerify, v) toVerify = append(toVerify, v)
} }
} }
results, err := keyRing.VerifyJSONs(ctx, toVerify) results, err := keyRing.VerifyJSONs(ctx, toVerify)
if err != nil { if err != nil {
return err return nil, err
} }
// Check that all the event JSON was correctly signed. // Check that all the event JSON was correctly signed
for _, result := range results { verificationErrors := make([]error, len(events))
if result.Error != nil { for evtIdx := range events {
return result.Error for _, verificationIdx := range verificationMap[evtIdx] {
result := results[verificationIdx]
if result.Error != nil {
verificationErrors[evtIdx] = result.Error
break // break inner loop; continue with outer
}
} }
} }
// Everything was okay. return verificationErrors, nil
}
// VerifyAllEventSignatures checks that each event in a list of events has valid
// signatures from the server that sent it.
//
// returns an error if any event fails verifications
func VerifyAllEventSignatures(ctx context.Context, events []Event, keyRing JSONVerifier) error {
verificationErrors, err := VerifyEventSignatures(ctx, events, keyRing)
if err != nil {
return err
}
for idx := range events {
ve := verificationErrors[idx]
if ve != nil {
return ve
}
}
return nil return nil
} }

View file

@ -272,8 +272,10 @@ func (v *StubVerifier) VerifyJSONs(ctx context.Context, requests []VerifyJSONReq
return v.results, nil return v.results, nil
} }
func TestVerifyEventSignatures(t *testing.T) { func TestVerifyAllEventSignatures(t *testing.T) {
verifier := StubVerifier{} verifier := StubVerifier{
results: make([]VerifyJSONResult, 2),
}
eventJSON := []byte(`{ eventJSON := []byte(`{
"type": "m.room.name", "type": "m.room.name",
@ -295,7 +297,7 @@ func TestVerifyEventSignatures(t *testing.T) {
event.eventJSON = eventJSON event.eventJSON = eventJSON
events := []Event{event} events := []Event{event}
if err := VerifyEventSignatures(context.Background(), events, &verifier); err != nil { if err := VerifyAllEventSignatures(context.Background(), events, &verifier); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -329,8 +331,10 @@ func TestVerifyEventSignatures(t *testing.T) {
} }
} }
func TestVerifyEventSignaturesForInvite(t *testing.T) { func TestVerifyAllEventSignaturesForInvite(t *testing.T) {
verifier := StubVerifier{} verifier := StubVerifier{
results: make([]VerifyJSONResult, 2),
}
eventJSON := []byte(`{ eventJSON := []byte(`{
"type": "m.room.member", "type": "m.room.member",
@ -352,7 +356,7 @@ func TestVerifyEventSignaturesForInvite(t *testing.T) {
event.eventJSON = eventJSON event.eventJSON = eventJSON
events := []Event{event} events := []Event{event}
if err := VerifyEventSignatures(context.Background(), events, &verifier); err != nil { if err := VerifyAllEventSignatures(context.Background(), events, &verifier); err != nil {
t.Fatal(err) t.Fatal(err)
} }

View file

@ -2,7 +2,6 @@ package gomatrixserverlib
import ( import (
"context" "context"
"net/http"
"net/url" "net/url"
"golang.org/x/crypto/ed25519" "golang.org/x/crypto/ed25519"
@ -22,7 +21,7 @@ func NewFederationClient(
serverName ServerName, keyID KeyID, privateKey ed25519.PrivateKey, serverName ServerName, keyID KeyID, privateKey ed25519.PrivateKey,
) *FederationClient { ) *FederationClient {
return &FederationClient{ return &FederationClient{
Client: Client{client: http.Client{Transport: newFederationTripper()}}, Client: *NewClient(),
serverName: serverName, serverName: serverName,
serverKeyID: keyID, serverKeyID: keyID,
serverPrivateKey: privateKey, serverPrivateKey: privateKey,

View file

@ -138,7 +138,7 @@ func (r RespState) Check(ctx context.Context, keyRing JSONVerifier) error {
// Check if the events pass signature checks. // Check if the events pass signature checks.
logger.Infof("Checking event signatures for %d events of room state", len(allEvents)) logger.Infof("Checking event signatures for %d events of room state", len(allEvents))
if err := VerifyEventSignatures(ctx, allEvents, keyRing); err != nil { if err := VerifyAllEventSignatures(ctx, allEvents, keyRing); err != nil {
return err return err
} }

View file

@ -110,8 +110,10 @@ func FetchKeysDirect(serverName ServerName, addr, sni string) (*ServerKeys, *tls
} }
defer tcpconn.Close() // nolint: errcheck defer tcpconn.Close() // nolint: errcheck
tlsconn := tls.Client(tcpconn, &tls.Config{ tlsconn := tls.Client(tcpconn, &tls.Config{
ServerName: sni, ServerName: sni,
InsecureSkipVerify: true, // This must be specified even though the TLS library will ignore it.
// This must be specified even though the TLS library will ignore it.
InsecureSkipVerify: true, // nolint: gas
}) })
if err = tlsconn.Handshake(); err != nil { if err = tlsconn.Handshake(); err != nil {
return nil, nil, err return nil, nil, err

View file

@ -0,0 +1,20 @@
#!/bin/sh
set -eux
cd `dirname $0`
# -u so that if this is run on a dev box, we get the latest deps, as
# we do on travis.
go get -u \
github.com/alecthomas/gometalinter \
golang.org/x/crypto/ed25519 \
github.com/matrix-org/util \
github.com/matrix-org/gomatrix \
github.com/tidwall/gjson \
github.com/tidwall/sjson \
github.com/pkg/errors \
gopkg.in/yaml.v2 \
./hooks/pre-commit