dendrite/vendor/src/github.com/jaegertracing/jaeger-client-go/crossdock/server/server.go
2018-06-18 05:44:25 -07:00

164 lines
5 KiB
Go

// Copyright (c) 2017 Uber Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package server
import (
"encoding/json"
"fmt"
"io/ioutil"
"net"
"net/http"
"strings"
"sync"
"github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/ext"
"github.com/uber/tchannel-go"
"golang.org/x/net/context"
"github.com/uber/jaeger-client-go/crossdock/common"
"github.com/uber/jaeger-client-go/crossdock/endtoend"
"github.com/uber/jaeger-client-go/crossdock/log"
"github.com/uber/jaeger-client-go/crossdock/thrift/tracetest"
)
// Server implements S1-S3 servers.
//
// It bundles an HTTP endpoint and a TChannel endpoint that share one Tracer.
// Populate the exported fields, then call Start.
type Server struct {
// HostPortHTTP is the address the HTTP server listens on. When empty, Start
// defaults it to ":" + common.DefaultServerPortHTTP; after a successful
// Start it is overwritten with the listener's actual address.
HostPortHTTP string
// HostPortTChannel is the address for the TChannel server. When empty,
// Start defaults it to ":" + common.DefaultServerPortTChannel.
HostPortTChannel string
// Tracer is handed to startTChannelServer and is used by handleJSON to
// extract incoming span contexts and to start server-side spans.
Tracer opentracing.Tracer
// listener is the HTTP listener opened by Start and closed by Close.
listener net.Listener
// channel is not assigned anywhere in this file.
// NOTE(review): presumably set by startTChannelServer — confirm.
channel *tchannel.Channel
// eHandler is created in Start; its GenerateTraces method serves
// the /create_traces endpoint.
eHandler *endtoend.Handler
}
// Start starts the test server called by the Client and other upstream servers.
//
// Empty HostPortHTTP / HostPortTChannel values are replaced with the default
// ports first. The TChannel server is started before the HTTP one; if either
// fails to come up the error is returned as-is. On success HostPortHTTP holds
// the address the HTTP listener is actually bound to.
func (s *Server) Start() error {
	// Fall back to the default ports for any address left unset.
	if s.HostPortHTTP == "" {
		s.HostPortHTTP = ":" + common.DefaultServerPortHTTP
	}
	if s.HostPortTChannel == "" {
		s.HostPortTChannel = ":" + common.DefaultServerPortTChannel
	}

	if err := s.startTChannelServer(s.Tracer); err != nil {
		return err
	}
	s.eHandler = endtoend.NewHandler()

	// Request factories and executors plugged into handleJSON below.
	newStartTraceReq := func() interface{} {
		return tracetest.NewStartTraceRequest()
	}
	runStartTrace := func(_ context.Context, req interface{}) (interface{}, error) {
		return s.doStartTrace(req.(*tracetest.StartTraceRequest))
	}
	newJoinTraceReq := func() interface{} {
		return tracetest.NewJoinTraceRequest()
	}
	runJoinTrace := func(ctx context.Context, req interface{}) (interface{}, error) {
		return s.doJoinTrace(ctx, req.(*tracetest.JoinTraceRequest))
	}

	mux := http.NewServeMux()
	mux.HandleFunc("/", func(http.ResponseWriter, *http.Request) {}) // health check
	mux.HandleFunc("/start_trace", func(w http.ResponseWriter, r *http.Request) {
		s.handleJSON(w, r, newStartTraceReq, runStartTrace)
	})
	mux.HandleFunc("/join_trace", func(w http.ResponseWriter, r *http.Request) {
		s.handleJSON(w, r, newJoinTraceReq, runJoinTrace)
	})
	mux.HandleFunc("/create_traces", s.eHandler.GenerateTraces)

	ln, err := net.Listen("tcp", s.HostPortHTTP)
	if err != nil {
		return err
	}
	s.listener = ln
	// Record the bound address; it differs from the requested one when an
	// ephemeral port (":0") was asked for.
	s.HostPortHTTP = ln.Addr().String()

	// Block until the serving goroutine has been scheduled before returning.
	var running sync.WaitGroup
	running.Add(1)
	go func() {
		running.Done()
		http.Serve(ln, mux)
	}()
	running.Wait()

	log.Printf("Started http server at %s\n", s.HostPortHTTP)
	return nil
}
// URL returns the base URL of the HTTP server, built from HostPortHTTP with
// an "http://" scheme and a trailing slash.
func (s *Server) URL() string {
	const layout = "http://%s/"
	return fmt.Sprintf(layout, s.HostPortHTTP)
}
// Close stops the server by closing the HTTP listener opened in Start and
// returns whatever error the listener reports.
func (s *Server) Close() error {
	ln := s.listener
	return ln.Close()
}
// GetPortHTTP returns the network port the server listens to, i.e. the part
// of HostPortHTTP after its last ':'. If HostPortHTTP contains no ':' the
// whole value is returned.
func (s *Server) GetPortHTTP() string {
	addr := s.HostPortHTTP
	// LastIndex yields -1 when there is no colon, so the slice starts at 0.
	return addr[strings.LastIndex(addr, ":")+1:]
}
// GetPortTChannel returns the port portion of HostPortTChannel, i.e. the
// text after its last ':'. If HostPortTChannel contains no ':' the whole
// value is returned.
func (s *Server) GetPortTChannel() string {
	addr := s.HostPortTChannel
	// LastIndex yields -1 when there is no colon, so the slice starts at 0.
	return addr[strings.LastIndex(addr, ":")+1:]
}
// handleJSON is the shared plumbing behind the JSON-over-HTTP endpoints.
//
// It extracts a span context from the request headers (a missing context is
// not an error), starts a server-side "post" span that is finished when the
// function returns, reads the request body, decodes it into the value
// produced by newReq, and passes that value to handle together with a context
// carrying the span. The value returned by handle is JSON-encoded and written
// back with Content-Type application/json.
//
// Failures are reported to the client through http.Error:
//   - 400 when the span context cannot be extracted or the body is not
//     valid JSON for the request type;
//   - 500 when the body cannot be read, handle returns an error, or the
//     response cannot be marshalled.
//
// newReq must return a pointer that json.Unmarshal can decode into.
func (s *Server) handleJSON(
	w http.ResponseWriter,
	r *http.Request,
	newReq func() interface{},
	handle func(ctx context.Context, req interface{}) (interface{}, error),
) {
	spanCtx, err := s.Tracer.Extract(opentracing.HTTPHeaders, opentracing.HTTPHeadersCarrier(r.Header))
	if err != nil && err != opentracing.ErrSpanContextNotFound {
		// Report what actually failed; the body has not been read yet.
		http.Error(w, fmt.Sprintf("Cannot extract span context from request headers: %+v", err), http.StatusBadRequest)
		return
	}
	span := s.Tracer.StartSpan("post", ext.RPCServerOption(spanCtx))
	ctx := opentracing.ContextWithSpan(context.Background(), span)
	defer span.Finish()

	body, err := ioutil.ReadAll(r.Body)
	if err != nil {
		http.Error(w, fmt.Sprintf("Cannot read request body: %+v", err), http.StatusInternalServerError)
		return
	}
	log.Printf("Server request: %s", string(body))

	req := newReq()
	if err := json.Unmarshal(body, req); err != nil {
		http.Error(w, fmt.Sprintf("Cannot parse request JSON: %+v. body=[%s]", err, string(body)), http.StatusBadRequest)
		return
	}

	resp, err := handle(ctx, req)
	if err != nil {
		log.Printf("Handle error: %s", err.Error())
		http.Error(w, fmt.Sprintf("Execution error: %+v", err), http.StatusInternalServerError)
		return
	}

	// Named payload rather than json so the encoding/json package is not shadowed.
	payload, err := json.Marshal(resp)
	if err != nil {
		http.Error(w, fmt.Sprintf("Cannot marshall response to JSON: %+v", err), http.StatusInternalServerError)
		return
	}
	log.Printf("Server response: %s", string(payload))

	w.Header().Add("Content-Type", "application/json")
	if _, err := w.Write(payload); err != nil {
		// The status line is already committed, so the failure can only be logged.
		log.Printf("Cannot write response: %s", err.Error())
	}
}