diff --git a/pkg/httputil/encoder.go b/pkg/httputil/encoder.go index df5cfb428..c491a4449 100644 --- a/pkg/httputil/encoder.go +++ b/pkg/httputil/encoder.go @@ -20,6 +20,7 @@ import ( "encoding/json" "errors" "fmt" + "io/ioutil" "mime" "net/http" "reflect" @@ -29,7 +30,10 @@ import ( ) // DefaultEncoder is an *AcceptEncoder with the default application/json encoding. -var DefaultEncoder = &AcceptEncoder{DefaultEncoding: "application/json"} +var DefaultEncoder = &AcceptEncoder{DefaultEncoding: "application/json", MaxReadLen: DefaultMaxReadLen} + +// DefaultMaxReadLen is the default maximum length to accept in an HTTP request body. +var DefaultMaxReadLen int64 = 1024 * 1024 // Encoder takes input and translate it to an expected encoded output. // @@ -46,6 +50,19 @@ type Encoder interface { // // The integer must be a valid http.Status* status code. Encode(http.ResponseWriter, *http.Request, interface{}, int) + + // Decode reads and decodes a request body. + Decode(http.ResponseWriter, *http.Request, interface{}) error +} + +// Decode decodes a request body using the DefaultEncoder. +func Decode(w http.ResponseWriter, r *http.Request, v interface{}) error { + return DefaultEncoder.Decode(w, r, v) +} + +// Encode encodes a request body using the DefaultEncoder. +func Encode(w http.ResponseWriter, r *http.Request, v interface{}, statusCode int) { + DefaultEncoder.Encode(w, r, v, statusCode) } // AcceptEncoder uses the accept headers on a request to determine the response type. @@ -58,6 +75,7 @@ type Encoder interface { // It treats `application/x-yaml` as `text/yaml`. type AcceptEncoder struct { DefaultEncoding string + MaxReadLen int64 } // Encode encodeds the given interface to the first available type in the Accept header. @@ -79,6 +97,35 @@ func (e *AcceptEncoder) Encode(w http.ResponseWriter, r *http.Request, out inter w.Write(data) } +// Decode decodes the given request into the given interface. +// +// It selects the marshal based on the value of the Content-Type header. If no +// viable decoder is found, it attempts to use the DefaultEncoder. +func (e *AcceptEncoder) Decode(w http.ResponseWriter, r *http.Request, v interface{}) error { + if e.MaxReadLen > 0 && r.ContentLength > int64(e.MaxReadLen) { + RequestEntityTooLarge(w, r, fmt.Sprintf("Max len is %d, submitted len is %d.", e.MaxReadLen, r.ContentLength)) + } + data, err := ioutil.ReadAll(r.Body) + r.Body.Close() + if err != nil { + return err + } + + ct := r.Header.Get("content-type") + mt, _, err := mime.ParseMediaType(ct) + if err != nil { + mt = "application/x-octet-stream" + } + + for n, fn := range decoders { + if n == mt { + return fn(data, v) + } + } + + return decoders[e.DefaultEncoding](data, v) +} + // parseAccept parses the value of an Accept: header and returns the best match. // // This returns the matched MIME type and the Marshal function. @@ -100,6 +147,9 @@ func (e *AcceptEncoder) parseAccept(h string) (string, Marshaler) { // Marshaler marshals an interface{} into a []byte. type Marshaler func(interface{}) ([]byte, error) +// Unmarshaler unmarshals []byte to an interface{}. +type Unmarshaler func([]byte, interface{}) error + var encoders = map[string]Marshaler{ "application/json": json.Marshal, "text/yaml": yaml.Marshal, @@ -107,6 +157,12 @@ var encoders = map[string]Marshaler{ "text/plain": textMarshal, } +var decoders = map[string]Unmarshaler{ + "application/json": json.Unmarshal, + "text/yaml": yaml.Unmarshal, + "application/x-yaml": yaml.Unmarshal, +} + // ErrUnsupportedKind indicates that the marshal cannot marshal a particular Go Kind (e.g. struct or chan). var ErrUnsupportedKind = errors.New("unsupported kind") diff --git a/pkg/httputil/httperrors.go b/pkg/httputil/httperrors.go index 2a2577b39..551dc8f54 100644 --- a/pkg/httputil/httperrors.go +++ b/pkg/httputil/httperrors.go @@ -54,6 +54,12 @@ func NotFound(w http.ResponseWriter, r *http.Request) { writeErr(w, r, msg, http.StatusNotFound) } +// RequestEntityTooLarge writes a 413 to the client and logs an error. +func RequestEntityTooLarge(w http.ResponseWriter, r *http.Request, msg string) { + log.Println(msg) + writeErr(w, r, msg, http.StatusRequestEntityTooLarge) +} + // BadRequest writes an HTTP 400. func BadRequest(w http.ResponseWriter, r *http.Request, err error) { log.Printf(LogBadRequest, r.Method, r.URL, err)