Changes for two-way tls client validation

pull/11034/head
suryatech27-cloud 3 years ago
parent 657850e44b
commit a460b9ac6e

@ -61,6 +61,7 @@ func addChartPathOptionsFlags(f *pflag.FlagSet, c *action.ChartPathOptions) {
f.BoolVar(&c.InsecureSkipTLSverify, "insecure-skip-tls-verify", false, "skip tls certificate checks for the chart download") f.BoolVar(&c.InsecureSkipTLSverify, "insecure-skip-tls-verify", false, "skip tls certificate checks for the chart download")
f.StringVar(&c.CaFile, "ca-file", "", "verify certificates of HTTPS-enabled servers using this CA bundle") f.StringVar(&c.CaFile, "ca-file", "", "verify certificates of HTTPS-enabled servers using this CA bundle")
f.BoolVar(&c.PassCredentialsAll, "pass-credentials", false, "pass credentials to all domains") f.BoolVar(&c.PassCredentialsAll, "pass-credentials", false, "pass credentials to all domains")
f.BoolVar(&c.TlsEnabled, "tls-enabled", false, "if two-way tls authentication enabled then trying to send client certificate")
} }
// bindOutputFlag will add the output flag to the given command and bind the // bindOutputFlag will add the output flag to the given command and bind the

@ -68,6 +68,7 @@ Environment variables:
| $HELM_KUBECONTEXT | set the name of the kubeconfig context. | | $HELM_KUBECONTEXT | set the name of the kubeconfig context. |
| $HELM_KUBETOKEN | set the Bearer KubeToken used for authentication. | | $HELM_KUBETOKEN | set the Bearer KubeToken used for authentication. |
| $HELM_BURST_LIMIT | set the default burst limit in the case the server contains many CRDs (default 100, -1 to disable)| | $HELM_BURST_LIMIT | set the default burst limit in the case the server contains many CRDs (default 100, -1 to disable)|
| $HELM_CLIENT_TLS_CERT_DIR | set the certificate directory for 2-way tls support for oci pull. |
Helm stores cache, configuration, and data based on the following configuration order: Helm stores cache, configuration, and data based on the following configuration order:

@ -19,7 +19,14 @@ package tlsutil
import ( import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"fmt"
"io/ioutil"
"log"
"os" "os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
@ -56,3 +63,82 @@ func ClientConfig(opts Options) (cfg *tls.Config, err error) {
cfg = &tls.Config{InsecureSkipVerify: opts.InsecureSkipVerify, Certificates: []tls.Certificate{*cert}, RootCAs: pool} cfg = &tls.Config{InsecureSkipVerify: opts.InsecureSkipVerify, Certificates: []tls.Certificate{*cert}, RootCAs: pool}
return cfg, nil return cfg, nil
} }
func ReadCertFromSecDir(host string) (opts Options, err error) {
//fmt.Println("Final Host Name : ", host)
if runtime.GOOS == "windows" || runtime.GOOS == "unix" {
log.Fatalf("%v OS not supported for this oci pull.", runtime.GOOS)
os.Exit(1)
} else {
cmd, err := exec.Command("helm", "env", "HELM_CLIENT_TLS_CERT_DIR").Output()
if err != nil {
log.Fatalf("Error : %s", err)
os.Exit(1)
}
clientCertDir := strings.TrimSuffix(string(cmd), "\n")
if clientCertDir == "" {
log.Fatalf("Please configure client certificate directory for tls connection set/export HELM_CLIENT_TLS_CERT_DIR='/etc/docker/certs.d/'\n")
os.Exit(1)
}
if clientCertDir[len(clientCertDir)-1] != '/' {
clientCertDir = fmt.Sprintf("%s/%s", clientCertDir, host)
//fmt.Println("clientCertDir1", clientCertDir)
} else {
clientCertDir = fmt.Sprintf("%s%s", clientCertDir, host)
//fmt.Println("clientCertDir2", clientCertDir)
}
if _, err := os.Stat(clientCertDir); err != nil {
if os.IsNotExist(err) {
return opts, errors.Wrapf(err, clientCertDir, "%v\nPlease Create a directory same as hostname [%v] .")
}
} else {
if files, err := ioutil.ReadDir(clientCertDir); err == nil {
for _, file := range files {
if filepath.Ext(file.Name()) == ".pem" {
opts.CaCertFile = fmt.Sprintf("%s/%s", clientCertDir, file.Name())
//fmt.Println("Root ca file : ", opts.CaCertFile)
} else if filepath.Ext(file.Name()) == ".cert" {
opts.CertFile = fmt.Sprintf("%s/%s", clientCertDir, file.Name())
//fmt.Println("client cert file : ", opts.CertFile)
} else if filepath.Ext(file.Name()) == ".key" {
opts.KeyFile = fmt.Sprintf("%s/%s", clientCertDir, file.Name())
//fmt.Println("client key file", opts.KeyFile)
}
}
} else {
log.Fatalf(" Certificate not found in current directory - %v\n ", err)
os.Exit(1)
}
switch {
case opts.CaCertFile == "" && opts.CertFile == "" && opts.KeyFile == "":
fmt.Printf("Error : Missing certificate (cacerts.crt,client.pem,client.key) required !!\n")
os.Exit(1)
case opts.CaCertFile == "" && opts.CertFile == "":
fmt.Printf("Error : Missing certificate : Root-CA and client certificate (cacerts.crt,client.pem) required !!\n")
os.Exit(1)
case opts.CaCertFile == "" && opts.KeyFile == "":
fmt.Printf("Error : Missing Certificate : Root-CA and and client key (cacerts.crt,client.key) required.\n")
os.Exit(1)
case opts.CertFile == "" && opts.KeyFile == "":
fmt.Printf("Error : Missing Certificate : Client certificate and client key (client.pem,client.key) required.\n")
os.Exit(1)
}
switch {
case opts.CaCertFile == "":
fmt.Printf("Error : Missing Certificate : Client Root-CA (cacerts.crt) required.\n")
os.Exit(1)
case opts.CertFile == "":
fmt.Printf("Error : Missing Certificate : Client certificate(client.pem) required.\n")
os.Exit(1)
case opts.KeyFile == "":
fmt.Printf("Error : Missing Certificate : Client keyfile (client.key) required.\n")
os.Exit(1)
}
}
}
return opts, nil
}

@ -69,5 +69,9 @@ func ExtractHostname(addr string) (string, error) {
if err != nil { if err != nil {
return "", err return "", err
} }
if u.Port() != "" {
return u.Hostname() + ":" + u.Port(), nil
} else {
return u.Hostname(), nil return u.Hostname(), nil
} }
}

@ -117,6 +117,7 @@ type ChartPathOptions struct {
Username string // --username Username string // --username
Verify bool // --verify Verify bool // --verify
Version string // --version Version string // --version
TlsEnabled bool // --tls-enabled
// registryClient provides a registry client but is not added with // registryClient provides a registry client but is not added with
// options from a flag // options from a flag

@ -86,6 +86,7 @@ func (p *Pull) Run(chartRef string) (string, error) {
getter.WithPassCredentialsAll(p.PassCredentialsAll), getter.WithPassCredentialsAll(p.PassCredentialsAll),
getter.WithTLSClientConfig(p.CertFile, p.KeyFile, p.CaFile), getter.WithTLSClientConfig(p.CertFile, p.KeyFile, p.CaFile),
getter.WithInsecureSkipVerifyTLS(p.InsecureSkipTLSverify), getter.WithInsecureSkipVerifyTLS(p.InsecureSkipTLSverify),
getter.WithTwoWayTLSEnable(p.TlsEnabled),
}, },
RegistryClient: p.cfg.RegistryClient, RegistryClient: p.cfg.RegistryClient,
RepositoryConfig: p.Settings.RepositoryConfig, RepositoryConfig: p.Settings.RepositoryConfig,
@ -93,8 +94,25 @@ func (p *Pull) Run(chartRef string) (string, error) {
} }
if registry.IsOCI(chartRef) { if registry.IsOCI(chartRef) {
//fmt.Println("pull.go :===> tls enabled", p.TlsEnabled)
//fmt.Println("pull.go : ====> ", chartRef)
if !p.TlsEnabled {
c.Options = append(c.Options, c.Options = append(c.Options,
getter.WithRegistryClient(p.cfg.RegistryClient)) getter.WithRegistryClient(p.cfg.RegistryClient),
)
} else {
registryClient, err := registry.NewClient(
registry.ClientOptDebug(p.Settings.Debug),
registry.ClientOptCredentialsFile(p.Settings.RegistryConfig),
registry.ClientOptWriter(&out),
registry.ClientOptTwoWayTLSEnable(p.TlsEnabled),
registry.ClientOptChartRef(chartRef),
)
if err != nil {
return out.String(), err
}
c.Options = append(c.Options, getter.WithRegistryClient(registryClient))
}
} }
if p.Verify { if p.Verify {

@ -74,6 +74,8 @@ type EnvSettings struct {
MaxHistory int MaxHistory int
// BurstLimit is the default client-side throttling limit. // BurstLimit is the default client-side throttling limit.
BurstLimit int BurstLimit int
// Secondary Certificate directory for helm oci pull
ClientSecCertDirectory string
} }
func New() *EnvSettings { func New() *EnvSettings {
@ -86,6 +88,7 @@ func New() *EnvSettings {
KubeAsGroups: envCSV("HELM_KUBEASGROUPS"), KubeAsGroups: envCSV("HELM_KUBEASGROUPS"),
KubeAPIServer: os.Getenv("HELM_KUBEAPISERVER"), KubeAPIServer: os.Getenv("HELM_KUBEAPISERVER"),
KubeCaFile: os.Getenv("HELM_KUBECAFILE"), KubeCaFile: os.Getenv("HELM_KUBECAFILE"),
ClientSecCertDirectory: envOr("HELM_CLIENT_TLS_CERT_DIR", ""),
PluginsDirectory: envOr("HELM_PLUGINS", helmpath.DataPath("plugins")), PluginsDirectory: envOr("HELM_PLUGINS", helmpath.DataPath("plugins")),
RegistryConfig: envOr("HELM_REGISTRY_CONFIG", helmpath.ConfigPath("registry/config.json")), RegistryConfig: envOr("HELM_REGISTRY_CONFIG", helmpath.ConfigPath("registry/config.json")),
RepositoryConfig: envOr("HELM_REPOSITORY_CONFIG", helmpath.ConfigPath("repositories.yaml")), RepositoryConfig: envOr("HELM_REPOSITORY_CONFIG", helmpath.ConfigPath("repositories.yaml")),
@ -170,6 +173,7 @@ func (s *EnvSettings) EnvVars() map[string]string {
"HELM_NAMESPACE": s.Namespace(), "HELM_NAMESPACE": s.Namespace(),
"HELM_MAX_HISTORY": strconv.Itoa(s.MaxHistory), "HELM_MAX_HISTORY": strconv.Itoa(s.MaxHistory),
"HELM_BURST_LIMIT": strconv.Itoa(s.BurstLimit), "HELM_BURST_LIMIT": strconv.Itoa(s.BurstLimit),
"HELM_CLIENT_TLS_CERT_DIR": s.ClientSecCertDirectory,
// broken, these are populated from helm flags and not kubeconfig. // broken, these are populated from helm flags and not kubeconfig.
"HELM_KUBECONTEXT": s.KubeContext, "HELM_KUBECONTEXT": s.KubeContext,

@ -42,6 +42,7 @@ type options struct {
passCredentialsAll bool passCredentialsAll bool
userAgent string userAgent string
version string version string
tlsEnabled bool
registryClient *registry.Client registryClient *registry.Client
timeout time.Duration timeout time.Duration
transport *http.Transport transport *http.Transport
@ -87,6 +88,12 @@ func WithInsecureSkipVerifyTLS(insecureSkipVerifyTLS bool) Option {
} }
} }
func WithTwoWayTLSEnable(tlsEnabled bool) Option {
return func(opts *options) {
opts.tlsEnabled = tlsEnabled
}
}
// WithTLSClientConfig sets the client auth with the provided credentials. // WithTLSClientConfig sets the client auth with the provided credentials.
func WithTLSClientConfig(certFile, keyFile, caFile string) Option { func WithTLSClientConfig(certFile, keyFile, caFile string) Option {
return func(opts *options) { return func(opts *options) {

@ -22,9 +22,11 @@ import (
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"net"
"net/http" "net/http"
"sort" "sort"
"strings" "strings"
"time"
"github.com/Masterminds/semver/v3" "github.com/Masterminds/semver/v3"
"github.com/containerd/containerd/remotes" "github.com/containerd/containerd/remotes"
@ -38,6 +40,8 @@ import (
registryremote "oras.land/oras-go/pkg/registry/remote" registryremote "oras.land/oras-go/pkg/registry/remote"
registryauth "oras.land/oras-go/pkg/registry/remote/auth" registryauth "oras.land/oras-go/pkg/registry/remote/auth"
"helm.sh/helm/v3/internal/tlsutil"
"helm.sh/helm/v3/internal/urlutil"
"helm.sh/helm/v3/internal/version" "helm.sh/helm/v3/internal/version"
"helm.sh/helm/v3/pkg/chart" "helm.sh/helm/v3/pkg/chart"
"helm.sh/helm/v3/pkg/helmpath" "helm.sh/helm/v3/pkg/helmpath"
@ -60,6 +64,8 @@ type (
authorizer auth.Client authorizer auth.Client
registryAuthorizer *registryauth.Client registryAuthorizer *registryauth.Client
resolver remotes.Resolver resolver remotes.Resolver
tlsEnabled bool
chartRef string
} }
// ClientOption allows specifying various settings configurable by the user for overriding the defaults // ClientOption allows specifying various settings configurable by the user for overriding the defaults
@ -85,7 +91,43 @@ func NewClient(options ...ClientOption) (*Client, error) {
} }
client.authorizer = authClient client.authorizer = authClient
} }
if client.resolver == nil { if client.resolver == nil {
if client.tlsEnabled {
host, err := urlutil.ExtractHostname(client.chartRef)
fmt.Println("host name : ", host)
if err != nil {
fmt.Printf("error :%v\n", err)
}
clientOpts, err := tlsutil.ReadCertFromSecDir(host)
if err != nil {
return client, errors.Wrapf(err, "Client certificate/directory Not Exist !!")
}
cfgtls, err := tlsutil.ClientConfig(clientOpts)
if err != nil {
fmt.Printf("error :%v\n", err)
}
var rt http.RoundTripper = &http.Transport{
Dial: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).Dial,
TLSHandshakeTimeout: 30 * time.Second,
TLSClientConfig: cfgtls,
ResponseHeaderTimeout: time.Duration(30 * time.Second),
DisableKeepAlives: true,
}
crosclient := http.Client{Transport: rt, Timeout: 30 * time.Second}
headers := http.Header{}
headers.Set("User-Agent", version.GetUserAgent())
opts := []auth.ResolverOption{auth.WithResolverHeaders(headers), auth.WithResolverClient(&crosclient)}
resolver, err := client.authorizer.ResolverWithOpts(opts...)
if err != nil {
return nil, err
}
client.resolver = resolver
} else {
headers := http.Header{} headers := http.Header{}
headers.Set("User-Agent", version.GetUserAgent()) headers.Set("User-Agent", version.GetUserAgent())
opts := []auth.ResolverOption{auth.WithResolverHeaders(headers)} opts := []auth.ResolverOption{auth.WithResolverHeaders(headers)}
@ -95,6 +137,8 @@ func NewClient(options ...ClientOption) (*Client, error) {
} }
client.resolver = resolver client.resolver = resolver
} }
}
if client.registryAuthorizer == nil { if client.registryAuthorizer == nil {
client.registryAuthorizer = &registryauth.Client{ client.registryAuthorizer = &registryauth.Client{
Header: http.Header{ Header: http.Header{
@ -145,6 +189,12 @@ func ClientOptWriter(out io.Writer) ClientOption {
} }
} }
func ClientOptChartRef(chartRef string) ClientOption {
return func(client *Client) {
client.chartRef = chartRef
}
}
// ClientOptCredentialsFile returns a function that sets the credentialsFile setting on a client options set // ClientOptCredentialsFile returns a function that sets the credentialsFile setting on a client options set
func ClientOptCredentialsFile(credentialsFile string) ClientOption { func ClientOptCredentialsFile(credentialsFile string) ClientOption {
return func(client *Client) { return func(client *Client) {
@ -152,6 +202,13 @@ func ClientOptCredentialsFile(credentialsFile string) ClientOption {
} }
} }
//ClientOptTwoWayTLSEnable returns a function that sets the client certificate when two-way tls authentication enable
func ClientOptTwoWayTLSEnable(tlsEnabled bool) ClientOption {
return func(client *Client) {
client.tlsEnabled = tlsEnabled
}
}
type ( type (
// LoginOption allows specifying various settings on login // LoginOption allows specifying various settings on login
LoginOption func(*loginOperation) LoginOption func(*loginOperation)
@ -303,8 +360,9 @@ func (c *Client) Pull(ref string, options ...PullOption) (*PullResult, error) {
numDescriptors := len(descriptors) numDescriptors := len(descriptors)
if numDescriptors < minNumDescriptors { if numDescriptors < minNumDescriptors {
return nil, fmt.Errorf("manifest does not contain minimum number of descriptors (%d), descriptors found: %d", return nil, errors.New(
minNumDescriptors, numDescriptors) fmt.Sprintf("manifest does not contain minimum number of descriptors (%d), descriptors found: %d",
minNumDescriptors, numDescriptors))
} }
var configDescriptor *ocispec.Descriptor var configDescriptor *ocispec.Descriptor
var chartDescriptor *ocispec.Descriptor var chartDescriptor *ocispec.Descriptor
@ -324,19 +382,22 @@ func (c *Client) Pull(ref string, options ...PullOption) (*PullResult, error) {
} }
} }
if configDescriptor == nil { if configDescriptor == nil {
return nil, fmt.Errorf("could not load config with mediatype %s", ConfigMediaType) return nil, errors.New(
fmt.Sprintf("could not load config with mediatype %s", ConfigMediaType))
} }
if operation.withChart && chartDescriptor == nil { if operation.withChart && chartDescriptor == nil {
return nil, fmt.Errorf("manifest does not contain a layer with mediatype %s", return nil, errors.New(
ChartLayerMediaType) fmt.Sprintf("manifest does not contain a layer with mediatype %s",
ChartLayerMediaType))
} }
var provMissing bool var provMissing bool
if operation.withProv && provDescriptor == nil { if operation.withProv && provDescriptor == nil {
if operation.ignoreMissingProv { if operation.ignoreMissingProv {
provMissing = true provMissing = true
} else { } else {
return nil, fmt.Errorf("manifest does not contain a layer with mediatype %s", return nil, errors.New(
ProvLayerMediaType) fmt.Sprintf("manifest does not contain a layer with mediatype %s",
ProvLayerMediaType))
} }
} }
result := &PullResult{ result := &PullResult{

Loading…
Cancel
Save