Merge pull request #396 from technosophos/fix/use-decoder

fix(manager): use decoder
pull/403/head
Matt Butcher 9 years ago
commit 3efce30ef9

@ -17,7 +17,6 @@ limitations under the License.
package main package main
import ( import (
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@ -298,19 +297,11 @@ func getPathVariable(w http.ResponseWriter, r *http.Request, variable, handler s
func getTemplate(w http.ResponseWriter, r *http.Request, handler string) *common.Template { func getTemplate(w http.ResponseWriter, r *http.Request, handler string) *common.Template {
util.LogHandlerEntry(handler, r) util.LogHandlerEntry(handler, r)
j, err := getJSONFromRequest(w, r, handler)
if err != nil {
return nil
}
t := &common.Template{} t := &common.Template{}
if err := json.Unmarshal(j, t); err != nil { if err := httputil.Decode(w, r, t); err != nil {
e := fmt.Errorf("%v\n%v", err, string(j)) httputil.BadRequest(w, r, err)
util.LogAndReturnError(handler, http.StatusBadRequest, e, w)
return nil return nil
} }
return t return t
} }
@ -479,18 +470,12 @@ func getRegistryHandlerFunc(w http.ResponseWriter, r *http.Request, c *router.Co
func getRegistry(w http.ResponseWriter, r *http.Request, handler string) *common.Registry { func getRegistry(w http.ResponseWriter, r *http.Request, handler string) *common.Registry {
util.LogHandlerEntry(handler, r) util.LogHandlerEntry(handler, r)
j, err := getJSONFromRequest(w, r, handler)
if err != nil {
return nil
}
t := &common.Registry{} t := &common.Registry{}
if err := json.Unmarshal(j, t); err != nil { if err := httputil.Decode(w, r, t); err != nil {
e := fmt.Errorf("%v\n%v", err, string(j)) httputil.BadRequest(w, r, err)
util.LogAndReturnError(handler, http.StatusBadRequest, e, w)
return nil return nil
} }
return t return t
} }
@ -611,18 +596,11 @@ func getFileHandlerFunc(w http.ResponseWriter, r *http.Request, c *router.Contex
func getCredential(w http.ResponseWriter, r *http.Request, handler string) *common.RegistryCredential { func getCredential(w http.ResponseWriter, r *http.Request, handler string) *common.RegistryCredential {
util.LogHandlerEntry(handler, r) util.LogHandlerEntry(handler, r)
j, err := getJSONFromRequest(w, r, handler)
if err != nil {
return nil
}
t := &common.RegistryCredential{} t := &common.RegistryCredential{}
if err := json.Unmarshal(j, t); err != nil { if err := httputil.Decode(w, r, t); err != nil {
e := fmt.Errorf("%v\n%v", err, string(j)) httputil.BadRequest(w, r, err)
util.LogAndReturnError(handler, http.StatusBadRequest, e, w)
return nil return nil
} }
return t return t
} }

@ -17,14 +17,15 @@ limitations under the License.
package main package main
import ( import (
"github.com/kubernetes/helm/cmd/manager/router"
"github.com/kubernetes/helm/pkg/version"
"flag" "flag"
"fmt" "fmt"
"log" "log"
"net/http" "net/http"
"os" "os"
"github.com/kubernetes/helm/cmd/manager/router"
"github.com/kubernetes/helm/pkg/httputil"
"github.com/kubernetes/helm/pkg/version"
) )
var ( var (
@ -52,6 +53,8 @@ func main() {
os.Exit(1) os.Exit(1)
} }
httputil.DefaultEncoder.MaxReadLen = c.Config.MaxTemplateLength
// Set up routes // Set up routes
handler := router.NewHandler(c) handler := router.NewHandler(c)
registerDeploymentRoutes(c, handler) registerDeploymentRoutes(c, handler)

@ -20,6 +20,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io/ioutil"
"mime" "mime"
"net/http" "net/http"
"reflect" "reflect"
@ -29,7 +30,10 @@ import (
) )
// DefaultEncoder is an *AcceptEncoder with the default application/json encoding. // DefaultEncoder is an *AcceptEncoder with the default application/json encoding.
var DefaultEncoder = &AcceptEncoder{DefaultEncoding: "application/json"} var DefaultEncoder = &AcceptEncoder{DefaultEncoding: "application/json", MaxReadLen: DefaultMaxReadLen}
// DefaultMaxReadLen is the default maximum length to accept in an HTTP request body.
var DefaultMaxReadLen int64 = 1024 * 1024
// Encoder takes input and translate it to an expected encoded output. // Encoder takes input and translate it to an expected encoded output.
// //
@ -46,6 +50,19 @@ type Encoder interface {
// //
// The integer must be a valid http.Status* status code. // The integer must be a valid http.Status* status code.
Encode(http.ResponseWriter, *http.Request, interface{}, int) Encode(http.ResponseWriter, *http.Request, interface{}, int)
// Decode reads and decodes a request body.
Decode(http.ResponseWriter, *http.Request, interface{}) error
}
// Decode decodes a request body using the DefaultEncoder.
func Decode(w http.ResponseWriter, r *http.Request, v interface{}) error {
return DefaultEncoder.Decode(w, r, v)
}
// Encode encodes a request body using the DefaultEncoder.
func Encode(w http.ResponseWriter, r *http.Request, v interface{}, statusCode int) {
DefaultEncoder.Encode(w, r, v, statusCode)
} }
// AcceptEncoder uses the accept headers on a request to determine the response type. // AcceptEncoder uses the accept headers on a request to determine the response type.
@ -58,6 +75,7 @@ type Encoder interface {
// It treats `application/x-yaml` as `text/yaml`. // It treats `application/x-yaml` as `text/yaml`.
type AcceptEncoder struct { type AcceptEncoder struct {
DefaultEncoding string DefaultEncoding string
MaxReadLen int64
} }
// Encode encodeds the given interface to the first available type in the Accept header. // Encode encodeds the given interface to the first available type in the Accept header.
@ -79,6 +97,35 @@ func (e *AcceptEncoder) Encode(w http.ResponseWriter, r *http.Request, out inter
w.Write(data) w.Write(data)
} }
// Decode decodes the given request into the given interface.
//
// It selects the marshal based on the value of the Content-Type header. If no
// viable decoder is found, it attempts to use the DefaultEncoder.
func (e *AcceptEncoder) Decode(w http.ResponseWriter, r *http.Request, v interface{}) error {
if e.MaxReadLen > 0 && r.ContentLength > int64(e.MaxReadLen) {
RequestEntityTooLarge(w, r, fmt.Sprintf("Max len is %d, submitted len is %d.", e.MaxReadLen, r.ContentLength))
}
data, err := ioutil.ReadAll(r.Body)
r.Body.Close()
if err != nil {
return err
}
ct := r.Header.Get("content-type")
mt, _, err := mime.ParseMediaType(ct)
if err != nil {
mt = "application/x-octet-stream"
}
for n, fn := range decoders {
if n == mt {
return fn(data, v)
}
}
return decoders[e.DefaultEncoding](data, v)
}
// parseAccept parses the value of an Accept: header and returns the best match. // parseAccept parses the value of an Accept: header and returns the best match.
// //
// This returns the matched MIME type and the Marshal function. // This returns the matched MIME type and the Marshal function.
@ -100,6 +147,9 @@ func (e *AcceptEncoder) parseAccept(h string) (string, Marshaler) {
// Marshaler marshals an interface{} into a []byte. // Marshaler marshals an interface{} into a []byte.
type Marshaler func(interface{}) ([]byte, error) type Marshaler func(interface{}) ([]byte, error)
// Unmarshaler unmarshals []byte to an interface{}.
type Unmarshaler func([]byte, interface{}) error
var encoders = map[string]Marshaler{ var encoders = map[string]Marshaler{
"application/json": json.Marshal, "application/json": json.Marshal,
"text/yaml": yaml.Marshal, "text/yaml": yaml.Marshal,
@ -107,6 +157,12 @@ var encoders = map[string]Marshaler{
"text/plain": textMarshal, "text/plain": textMarshal,
} }
var decoders = map[string]Unmarshaler{
"application/json": json.Unmarshal,
"text/yaml": yaml.Unmarshal,
"application/x-yaml": yaml.Unmarshal,
}
// ErrUnsupportedKind indicates that the marshal cannot marshal a particular Go Kind (e.g. struct or chan). // ErrUnsupportedKind indicates that the marshal cannot marshal a particular Go Kind (e.g. struct or chan).
var ErrUnsupportedKind = errors.New("unsupported kind") var ErrUnsupportedKind = errors.New("unsupported kind")

@ -17,6 +17,7 @@ limitations under the License.
package httputil package httputil
import ( import (
"bytes"
"encoding/json" "encoding/json"
"errors" "errors"
"io/ioutil" "io/ioutil"
@ -72,7 +73,58 @@ func TestTextMarshal(t *testing.T) {
} }
} }
func TestAcceptEncoder(t *testing.T) { type encDec struct {
Name string
}
func TestDefaultEncoder(t *testing.T) {
in := &encDec{Name: "Foo"}
var out, out2 encDec
fn := func(w http.ResponseWriter, r *http.Request) {
if err := Decode(w, r, &out); err != nil {
t.Fatalf("Failed to decode data: %s", err)
}
if out.Name != in.Name {
t.Fatalf("Expected %q, got %q", in.Name, out.Name)
}
Encode(w, r, out, http.StatusOK)
}
s := httptest.NewServer(http.HandlerFunc(fn))
defer s.Close()
data, err := json.Marshal(in)
if err != nil {
t.Fatalf("Failed to marshal JSON: %s", err)
}
req, err := http.NewRequest("GET", s.URL, bytes.NewBuffer(data))
if err != nil {
t.Fatal(err)
}
req.Header.Set("content-type", "application/json")
res, err := http.DefaultClient.Do(req)
if err != nil {
t.Errorf("Failed request: %s", err)
}
if res.StatusCode != http.StatusOK {
t.Errorf("Expected 200, got %d", res.StatusCode)
}
data, err = ioutil.ReadAll(res.Body)
res.Body.Close()
if err != nil {
t.Fatal(err)
}
if err := json.Unmarshal(data, &out2); err != nil {
t.Fatal(err)
}
if out2.Name != in.Name {
t.Errorf("Expected final output to have name %q, got %q", in.Name, out2.Name)
}
}
func TestAcceptEncoderEncoder(t *testing.T) {
enc := &AcceptEncoder{ enc := &AcceptEncoder{
DefaultEncoding: "application/json", DefaultEncoding: "application/json",
} }

@ -38,8 +38,8 @@ const (
// For example, and error can be serialized to JSON or YAML. Likewise, the // For example, and error can be serialized to JSON or YAML. Likewise, the
// string marshal can convert it to a string. // string marshal can convert it to a string.
type Error struct { type Error struct {
Msg string `json:"message, omitempty"`
Status string `json:"status"` Status string `json:"status"`
Msg string `json:"message, omitempty"`
} }
// Error implements the error interface. // Error implements the error interface.
@ -54,6 +54,12 @@ func NotFound(w http.ResponseWriter, r *http.Request) {
writeErr(w, r, msg, http.StatusNotFound) writeErr(w, r, msg, http.StatusNotFound)
} }
// RequestEntityTooLarge writes a 413 to the client and logs an error.
func RequestEntityTooLarge(w http.ResponseWriter, r *http.Request, msg string) {
log.Println(msg)
writeErr(w, r, msg, http.StatusRequestEntityTooLarge)
}
// BadRequest writes an HTTP 400. // BadRequest writes an HTTP 400.
func BadRequest(w http.ResponseWriter, r *http.Request, err error) { func BadRequest(w http.ResponseWriter, r *http.Request, err error) {
log.Printf(LogBadRequest, r.Method, r.URL, err) log.Printf(LogBadRequest, r.Method, r.URL, err)

Loading…
Cancel
Save