diff --git a/pkg/tiller/release_history.go b/pkg/tiller/release_history.go index 341079575..63f4275e9 100644 --- a/pkg/tiller/release_history.go +++ b/pkg/tiller/release_history.go @@ -25,10 +25,6 @@ import ( // GetHistory gets the history for a given release. func (s *ReleaseServer) GetHistory(ctx context.Context, req *tpb.GetHistoryRequest) (*tpb.GetHistoryResponse, error) { - if !checkClientVersion(ctx) { - return nil, errIncompatibleVersion - } - h, err := s.env.Releases.History(req.Name) if err != nil { return nil, err diff --git a/pkg/tiller/release_server.go b/pkg/tiller/release_server.go index 1c3c65efb..30bebe636 100644 --- a/pkg/tiller/release_server.go +++ b/pkg/tiller/release_server.go @@ -27,8 +27,6 @@ import ( "github.com/technosophos/moniker" ctx "golang.org/x/net/context" - "google.golang.org/grpc" - "google.golang.org/grpc/metadata" "k8s.io/kubernetes/pkg/api/unversioned" "k8s.io/kubernetes/pkg/client/clientset_generated/internalclientset" "k8s.io/kubernetes/pkg/client/typed/discovery" @@ -65,8 +63,6 @@ var ( errMissingRelease = errors.New("no release provided") // errInvalidRevision indicates that an invalid release revision number was provided. errInvalidRevision = errors.New("invalid release revision") - // errIncompatibleVersion indicates incompatible client/server versions. - errIncompatibleVersion = errors.New("client version is incompatible") ) // ListDefaultLimit is the default limit for number of items returned in a list. @@ -83,17 +79,6 @@ var ListDefaultLimit int64 = 512 // prevents an empty string from matching. var ValidName = regexp.MustCompile("^(([A-Za-z0-9][-A-Za-z0-9_.]*)?[A-Za-z0-9])+$") -// maxMsgSize use 10MB as the default message size limit. -// grpc library default is 4MB -var maxMsgSize = 1024 * 1024 * 10 - -// NewServer creates a new grpc server. -func NewServer() *grpc.Server { - return grpc.NewServer( - grpc.MaxMsgSize(maxMsgSize), - ) -} - // ReleaseServer implements the server-side gRPC endpoint for the HAPI services. type ReleaseServer struct { env *environment.Environment @@ -108,21 +93,8 @@ func NewReleaseServer(env *environment.Environment, clientset internalclientset. } } -func getVersion(c ctx.Context) string { - if md, ok := metadata.FromContext(c); ok { - if v, ok := md["x-helm-api-client"]; ok { - return v[0] - } - } - return "" -} - // ListReleases lists the releases found by the server. func (s *ReleaseServer) ListReleases(req *services.ListReleasesRequest, stream services.ReleaseService_ListReleasesServer) error { - if !checkClientVersion(stream.Context()) { - return errIncompatibleVersion - } - if len(req.StatusCodes) == 0 { req.StatusCodes = []release.Status_Code{release.Status_DEPLOYED} } @@ -226,17 +198,8 @@ func (s *ReleaseServer) GetVersion(c ctx.Context, req *services.GetVersionReques return &services.GetVersionResponse{Version: v}, nil } -func checkClientVersion(c ctx.Context) bool { - v := getVersion(c) - return version.IsCompatible(v, version.Version) -} - // GetReleaseStatus gets the status information for a named release. func (s *ReleaseServer) GetReleaseStatus(c ctx.Context, req *services.GetReleaseStatusRequest) (*services.GetReleaseStatusResponse, error) { - if !checkClientVersion(c) { - return nil, errIncompatibleVersion - } - if !ValidName.MatchString(req.Name) { return nil, errMissingRelease } @@ -283,10 +246,6 @@ func (s *ReleaseServer) GetReleaseStatus(c ctx.Context, req *services.GetRelease // GetReleaseContent gets all of the stored information for the given release. func (s *ReleaseServer) GetReleaseContent(c ctx.Context, req *services.GetReleaseContentRequest) (*services.GetReleaseContentResponse, error) { - if !checkClientVersion(c) { - return nil, errIncompatibleVersion - } - if !ValidName.MatchString(req.Name) { return nil, errMissingRelease } @@ -302,10 +261,6 @@ func (s *ReleaseServer) GetReleaseContent(c ctx.Context, req *services.GetReleas // UpdateRelease takes an existing release and new information, and upgrades the release. func (s *ReleaseServer) UpdateRelease(c ctx.Context, req *services.UpdateReleaseRequest) (*services.UpdateReleaseResponse, error) { - if !checkClientVersion(c) { - return nil, errIncompatibleVersion - } - currentRelease, updatedRelease, err := s.prepareUpdate(req) if err != nil { return nil, err @@ -441,10 +396,6 @@ func (s *ReleaseServer) prepareUpdate(req *services.UpdateReleaseRequest) (*rele // RollbackRelease rolls back to a previous version of the given release. func (s *ReleaseServer) RollbackRelease(c ctx.Context, req *services.RollbackReleaseRequest) (*services.RollbackReleaseResponse, error) { - if !checkClientVersion(c) { - return nil, errIncompatibleVersion - } - currentRelease, targetRelease, err := s.prepareRollback(req) if err != nil { return nil, err @@ -618,10 +569,6 @@ func (s *ReleaseServer) engine(ch *chart.Chart) environment.Engine { // InstallRelease installs a release and stores the release record. func (s *ReleaseServer) InstallRelease(c ctx.Context, req *services.InstallReleaseRequest) (*services.InstallReleaseResponse, error) { - if !checkClientVersion(c) { - return nil, errIncompatibleVersion - } - rel, err := s.prepareRelease(req) if err != nil { log.Printf("Failed install prepare step: %s", err) @@ -927,10 +874,6 @@ func (s *ReleaseServer) purgeReleases(rels ...*release.Release) error { // UninstallRelease deletes all of the resources associated with this release, and marks the release DELETED. func (s *ReleaseServer) UninstallRelease(c ctx.Context, req *services.UninstallReleaseRequest) (*services.UninstallReleaseResponse, error) { - if !checkClientVersion(c) { - return nil, errIncompatibleVersion - } - if !ValidName.MatchString(req.Name) { log.Printf("uninstall: Release not found: %s", req.Name) return nil, errMissingRelease diff --git a/pkg/tiller/server.go b/pkg/tiller/server.go new file mode 100644 index 000000000..6cecda70f --- /dev/null +++ b/pkg/tiller/server.go @@ -0,0 +1,89 @@ +/* +Copyright 2016 The Kubernetes Authors All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package tiller + +import ( + "fmt" + "log" + "strings" + + "golang.org/x/net/context" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" + + "k8s.io/helm/pkg/version" +) + +// maxMsgSize use 10MB as the default message size limit. +// grpc library default is 4MB +var maxMsgSize = 1024 * 1024 * 10 + +// NewServer creates a new grpc server. +func NewServer() *grpc.Server { + return grpc.NewServer( + grpc.MaxMsgSize(maxMsgSize), + grpc.UnaryInterceptor(newUnaryInterceptor()), + grpc.StreamInterceptor(newStreamInterceptor()), + ) +} + +func newUnaryInterceptor() grpc.UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { + if err := checkClientVersion(ctx); err != nil { + // whitelist GetVersion() from the version check + if _, m := splitMethod(info.FullMethod); m != "GetVersion" { + log.Println(err) + return nil, err + } + } + return handler(ctx, req) + } +} + +func newStreamInterceptor() grpc.StreamServerInterceptor { + return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + if err := checkClientVersion(ss.Context()); err != nil { + log.Println(err) + return err + } + return handler(srv, ss) + } +} + +func splitMethod(fullMethod string) (string, string) { + if frags := strings.Split(fullMethod, "/"); len(frags) == 3 { + return frags[1], frags[2] + } + return "unknown", "unknown" +} + +func versionFromContext(ctx context.Context) string { + if md, ok := metadata.FromContext(ctx); ok { + if v, ok := md["x-helm-api-client"]; ok && len(v) > 0 { + return v[0] + } + } + return "" +} + +func checkClientVersion(ctx context.Context) error { + clientVersion := versionFromContext(ctx) + if !version.IsCompatible(clientVersion, version.Version) { + return fmt.Errorf("incompatible versions client: %s server: %s", clientVersion, version.Version) + } + return nil +}