diff --git a/dm/client.go b/dm/client.go index b32bb828e..303f17a40 100644 --- a/dm/client.go +++ b/dm/client.go @@ -22,6 +22,8 @@ type Client struct { Host string // The protocol. Currently only http and https are supported. Protocol string + // Transport + Transport http.RoundTripper } // NewClient creates a new DM client. Host name is required. @@ -30,6 +32,7 @@ func NewClient(host string) *Client { HTTPTimeout: DefaultHTTPTimeout, Protocol: "https", Host: host, + Transport: NewDebugTransport(nil), } } @@ -70,7 +73,8 @@ func (c *Client) callHttp(path, method, action string, reader io.ReadCloser) (st request.Header.Add("Content-Type", "application/json") client := http.Client{ - Timeout: time.Duration(time.Duration(DefaultHTTPTimeout) * time.Second), + Timeout: time.Duration(time.Duration(DefaultHTTPTimeout) * time.Second), + Transport: c.Transport, } response, err := client.Do(request) diff --git a/dm/transport.go b/dm/transport.go new file mode 100644 index 000000000..588f7f96d --- /dev/null +++ b/dm/transport.go @@ -0,0 +1,65 @@ +package dm + +import ( + "fmt" + "io" + "net/http" + "net/http/httputil" + "os" +) + +type debugTransport struct { + // Writer is the logging destination + Writer io.Writer + + http.RoundTripper +} + +func NewDebugTransport(rt http.RoundTripper) http.RoundTripper { + return debugTransport{ + RoundTripper: rt, + Writer: os.Stderr, + } +} + +func (tr debugTransport) CancelRequest(req *http.Request) { + type canceler interface { + CancelRequest(*http.Request) + } + if cr, ok := tr.transport().(canceler); ok { + cr.CancelRequest(req) + } +} + +func (tr debugTransport) RoundTrip(req *http.Request) (*http.Response, error) { + tr.logRequest(req) + resp, err := tr.transport().RoundTrip(req) + if err != nil { + return nil, err + } + tr.logResponse(resp) + return resp, err +} + +func (tr debugTransport) transport() http.RoundTripper { + if tr.RoundTripper != nil { + return tr.RoundTripper + } + return http.DefaultTransport +} + +func (tr debugTransport) logRequest(req *http.Request) { + dump, err := httputil.DumpRequestOut(req, true) + if err != nil { + fmt.Fprintf(tr.Writer, "%s: %s\n", "could not dump request", err) + } + fmt.Fprint(tr.Writer, string(dump)) +} + +func (tr debugTransport) logResponse(resp *http.Response) { + dump, err := httputil.DumpResponse(resp, true) + if err != nil { + fmt.Fprintf(tr.Writer, "%s: %s\n", "could not dump response", err) + } + fmt.Fprint(tr.Writer, string(dump)) +} diff --git a/dm/transport_test.go b/dm/transport_test.go new file mode 100644 index 000000000..d563d3eab --- /dev/null +++ b/dm/transport_test.go @@ -0,0 +1,49 @@ +package dm + +import ( + "bytes" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestDebugTransport(t *testing.T) { + handler := func(w http.ResponseWriter, req *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"status":"awesome"}`)) + } + + server := httptest.NewServer(http.HandlerFunc(handler)) + defer server.Close() + + var output bytes.Buffer + + client := &http.Client{ + Transport: debugTransport{ + Writer: &output, + }, + } + + _, err := client.Get(server.URL) + if err != nil { + t.Fatal(err.Error()) + } + + expected := []string{ + "GET / HTTP/1.1", + "Accept-Encoding: gzip", + "HTTP/1.1 200 OK", + "Content-Length: 20", + "Content-Type: application/json", + `{"status":"awesome"}`, + } + actual := output.String() + + for _, match := range expected { + if !strings.Contains(actual, match) { + t.Errorf("Expected %s to contain %s", actual, match) + } + } +}