2195 lines
66 KiB
Go
2195 lines
66 KiB
Go
/*
|
|
Copyright 2019 NetFoundry Inc.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
*/
|
|
|
|
package ziti
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"github.com/go-openapi/strfmt"
|
|
"github.com/google/uuid"
|
|
"github.com/kataras/go-events"
|
|
"github.com/openziti/edge-api/rest_client_api_client/authentication"
|
|
"github.com/openziti/edge-api/rest_client_api_client/service"
|
|
rest_session "github.com/openziti/edge-api/rest_client_api_client/session"
|
|
"github.com/openziti/foundation/v2/concurrenz"
|
|
"github.com/openziti/foundation/v2/errorz"
|
|
"github.com/openziti/foundation/v2/stringz"
|
|
apis "github.com/openziti/sdk-golang/edge-apis"
|
|
"github.com/openziti/secretstream/kx"
|
|
"math"
|
|
"math/rand"
|
|
"net"
|
|
"net/url"
|
|
"reflect"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/cenkalti/backoff/v4"
|
|
"github.com/michaelquigley/pfxlog"
|
|
"github.com/openziti/channel/v3"
|
|
"github.com/openziti/channel/v3/latency"
|
|
"github.com/openziti/edge-api/rest_client_api_client/current_api_session"
|
|
"github.com/openziti/edge-api/rest_model"
|
|
"github.com/openziti/foundation/v2/versions"
|
|
"github.com/openziti/identity"
|
|
"github.com/openziti/metrics"
|
|
"github.com/openziti/sdk-golang/ziti/edge"
|
|
"github.com/openziti/sdk-golang/ziti/edge/network"
|
|
"github.com/openziti/sdk-golang/ziti/signing"
|
|
"github.com/openziti/transport/v2"
|
|
cmap "github.com/orcaman/concurrent-map/v2"
|
|
"github.com/pkg/errors"
|
|
metrics2 "github.com/rcrowley/go-metrics"
|
|
"github.com/sirupsen/logrus"
|
|
)
|
|
|
|
type SessionType rest_model.DialBind
|
|
|
|
const (
|
|
LatencyCheckInterval = 30 * time.Second
|
|
LatencyCheckTimeout = 10 * time.Second
|
|
|
|
ClientConfigV1 = "ziti-tunneler-client.v1"
|
|
InterceptV1 = "intercept.v1"
|
|
|
|
SessionDial = rest_model.DialBindDial
|
|
SessionBind = rest_model.DialBindBind
|
|
)
|
|
|
|
// MfaCodeResponse is a handler used to return a string (TOTP) code
|
|
type MfaCodeResponse func(code string) error
|
|
|
|
// Context is the main interface for SDK instances that may be used to authenticate, connect to services, or host
|
|
// services.
|
|
type Context interface {
|
|
// Authenticate attempts to use credentials configured on the Context to perform authentication. The authentication
|
|
// implementation used is configured via the Credentials field on an Option struct provided during Context
|
|
// creation.
|
|
Authenticate() error
|
|
|
|
// SetCredentials sets the credentials used to authenticate against the Edge Client API.
|
|
SetCredentials(authenticator apis.Credentials)
|
|
|
|
// GetCredentials returns the currently set credentials used to authenticate against the Edge Client API.
|
|
GetCredentials() apis.Credentials
|
|
|
|
// GetCurrentIdentity returns the Edge API details of the currently authenticated identity.
|
|
GetCurrentIdentity() (*rest_model.IdentityDetail, error)
|
|
|
|
// GetCurrentIdentityWithBackoff returns the Edge API details of the currently authenticated identity. with retry if necessary
|
|
GetCurrentIdentityWithBackoff() (*rest_model.IdentityDetail, error)
|
|
|
|
// Dial attempts to connect to a service using a given service name; authenticating as necessary in order to obtain
|
|
// a service session, attach to Edge Routers, and connect to a service.
|
|
Dial(serviceName string) (edge.Conn, error)
|
|
|
|
// DialWithOptions performs the same logic as Dial but allows specification of DialOptions.
|
|
DialWithOptions(serviceName string, options *DialOptions) (edge.Conn, error)
|
|
|
|
// DialAddr finds the service for given address and performs a Dial for it.
|
|
DialAddr(network string, addr string) (edge.Conn, error)
|
|
|
|
// Listen attempts to host a service by the given service name; authenticating as necessary in order to obtain
|
|
// a service session, attach to Edge Routers, and bind (host) the service.
|
|
Listen(serviceName string) (edge.Listener, error)
|
|
|
|
// ListenWithOptions performs the same logic as Listen, but allows the specification of ListenOptions.
|
|
ListenWithOptions(serviceName string, options *ListenOptions) (edge.Listener, error)
|
|
|
|
// GetServiceId will return the id of a specific service by service name. If not found, false, will be returned
|
|
// with an empty string.
|
|
GetServiceId(serviceName string) (string, bool, error)
|
|
|
|
// GetServices will return a slice of service details that the current authenticating identity can access for
|
|
// dial (connect) or bind (host/listen).
|
|
GetServices() ([]rest_model.ServiceDetail, error)
|
|
|
|
// GetService will return the service details of a specific service by service name.
|
|
GetService(serviceName string) (*rest_model.ServiceDetail, bool)
|
|
|
|
// GetServiceForAddr finds the service with intercept that matches best to given address
|
|
GetServiceForAddr(network, hostname string, port uint16) (*rest_model.ServiceDetail, int, error)
|
|
|
|
// RefreshServices forces the context to refresh the list of services the current authenticating identity has access
|
|
// to.
|
|
RefreshServices() error
|
|
|
|
// RefreshService forces the context to refresh just the service with the given name. If the given service isn't
|
|
// found, a nil will be returned
|
|
RefreshService(serviceName string) (*rest_model.ServiceDetail, error)
|
|
|
|
// GetServiceTerminators will return a slice of rest_model.TerminatorClientDetail for a specific service name.
|
|
// The offset and limit options can be used to page through excessive lists of items. A max of 500 is imposed on
|
|
// limit.
|
|
GetServiceTerminators(serviceName string, offset, limit int) ([]*rest_model.TerminatorClientDetail, int, error)
|
|
|
|
// GetSession will return the session detail associated with a specific session id.
|
|
GetSession(id string) (*rest_model.SessionDetail, error)
|
|
|
|
// Metrics will return the current context's metrics Registry.
|
|
Metrics() metrics.Registry
|
|
|
|
// Close closes any connections open to edge routers
|
|
Close()
|
|
|
|
// Deprecated: AddZitiMfaHandler adds a Ziti MFA handler, invoked during authentication.
|
|
// Replaced with event functionality. Use `zitiContext.Events().AddMfaTotpCodeListener(func(Context, *rest_model.AuthQueryDetail, MfaCodeResponse))` instead.
|
|
AddZitiMfaHandler(handler func(query *rest_model.AuthQueryDetail, resp MfaCodeResponse) error)
|
|
|
|
// EnrollZitiMfa will attempt to enable TOTP 2FA on the currently authenticating identity if not already enrolled.
|
|
EnrollZitiMfa() (*rest_model.DetailMfa, error)
|
|
|
|
// VerifyZitiMfa will attempt to complete enrollment of TOTP 2FA with the given code.
|
|
VerifyZitiMfa(code string) error
|
|
|
|
// RemoveZitiMfa will attempt to remove TOTP 2FA for the current identity
|
|
RemoveZitiMfa(code string) error
|
|
|
|
// GetId returns a unique context id
|
|
GetId() string
|
|
|
|
// SetId allows the setting of a context's id
|
|
SetId(id string)
|
|
|
|
Events() Eventer
|
|
}
|
|
|
|
var _ Context = &ContextImpl{}
|
|
|
|
type ContextImpl struct {
|
|
options *Options
|
|
Id string
|
|
routerConnections cmap.ConcurrentMap[string, edge.RouterConn]
|
|
|
|
CtrlClt *CtrlClient
|
|
|
|
services cmap.ConcurrentMap[string, *rest_model.ServiceDetail] // name -> Service
|
|
sessions cmap.ConcurrentMap[string, *rest_model.SessionDetail] // svcID:type -> Session
|
|
intercepts cmap.ConcurrentMap[string, *edge.InterceptV1Config]
|
|
|
|
metrics metrics.Registry
|
|
|
|
firstAuthOnce sync.Once
|
|
|
|
closed atomic.Bool
|
|
closeNotify chan struct{}
|
|
authQueryHandlers map[string]func(query *rest_model.AuthQueryDetail, response MfaCodeResponse) error
|
|
|
|
events.EventEmmiter
|
|
lastSuccessfulApiSessionRefresh time.Time
|
|
routerProxy func(addr string) *transport.ProxyConfiguration
|
|
}
|
|
|
|
func (context *ContextImpl) AddServiceAddedListener(handler func(Context, *rest_model.ServiceDetail)) func() {
|
|
listener := func(args ...interface{}) {
|
|
details, ok := args[0].(*rest_model.ServiceDetail)
|
|
|
|
if !ok {
|
|
pfxlog.Logger().Fatalf("could not convert args[0] to %T was %T", details, args[0])
|
|
}
|
|
|
|
if details == nil {
|
|
pfxlog.Logger().Fatalf("expected arg[0] was nil, unexpected")
|
|
}
|
|
|
|
handler(context, details)
|
|
}
|
|
|
|
context.AddListener(EventServiceAdded, listener)
|
|
|
|
return func() {
|
|
context.RemoveListener(EventServiceAdded, listener)
|
|
}
|
|
}
|
|
|
|
func (context *ContextImpl) AddServiceChangedListener(handler func(Context, *rest_model.ServiceDetail)) func() {
|
|
listener := func(args ...interface{}) {
|
|
details, ok := args[0].(*rest_model.ServiceDetail)
|
|
|
|
if !ok {
|
|
pfxlog.Logger().Fatalf("could not convert args[0] to %T was %T", details, args[0])
|
|
}
|
|
|
|
if details == nil {
|
|
pfxlog.Logger().Fatalf("expected arg[0] was nil, unexpected")
|
|
}
|
|
|
|
handler(context, details)
|
|
}
|
|
|
|
context.AddListener(EventServiceChanged, listener)
|
|
|
|
return func() {
|
|
context.RemoveListener(EventServiceChanged, listener)
|
|
}
|
|
}
|
|
|
|
func (context *ContextImpl) AddServiceRemovedListener(handler func(Context, *rest_model.ServiceDetail)) func() {
|
|
listener := func(args ...interface{}) {
|
|
details, ok := args[0].(*rest_model.ServiceDetail)
|
|
|
|
if !ok {
|
|
pfxlog.Logger().Fatalf("could not convert args[0] to %T was %T", details, args[0])
|
|
}
|
|
|
|
if details == nil {
|
|
pfxlog.Logger().Fatalf("expected arg[0] was nil, unexpected")
|
|
}
|
|
|
|
handler(context, details)
|
|
}
|
|
|
|
context.AddListener(EventServiceRemoved, listener)
|
|
|
|
return func() {
|
|
context.RemoveListener(EventServiceRemoved, listener)
|
|
}
|
|
}
|
|
|
|
func (context *ContextImpl) AddRouterConnectedListener(handler func(Context, string, string)) func() {
|
|
listener := func(args ...interface{}) {
|
|
name, ok := args[0].(string)
|
|
|
|
if !ok {
|
|
pfxlog.Logger().Fatalf("could not convert args[0] to %T was %T", name, args[0])
|
|
}
|
|
|
|
addr, ok := args[1].(string)
|
|
|
|
if !ok {
|
|
pfxlog.Logger().Fatalf("could not convert args[1] to %T was %T", addr, args[1])
|
|
}
|
|
|
|
handler(context, name, addr)
|
|
}
|
|
|
|
context.AddListener(EventRouterConnected, listener)
|
|
|
|
return func() {
|
|
context.RemoveListener(EventRouterConnected, listener)
|
|
}
|
|
}
|
|
|
|
func (context *ContextImpl) AddRouterDisconnectedListener(handler func(Context, string, string)) func() {
|
|
listener := func(args ...interface{}) {
|
|
name, ok := args[0].(string)
|
|
|
|
if !ok {
|
|
pfxlog.Logger().Fatalf("could not convert args[0] to %T was %T", name, args[0])
|
|
}
|
|
|
|
addr, ok := args[1].(string)
|
|
|
|
if !ok {
|
|
pfxlog.Logger().Fatalf("could not convert args[1] to %T was %T", addr, args[1])
|
|
}
|
|
|
|
handler(context, name, addr)
|
|
}
|
|
|
|
context.AddListener(EventRouterDisconnected, listener)
|
|
|
|
return func() {
|
|
context.RemoveListener(EventRouterDisconnected, listener)
|
|
}
|
|
}
|
|
|
|
func (context *ContextImpl) AddMfaTotpCodeListener(handler func(Context, *rest_model.AuthQueryDetail, MfaCodeResponse)) func() {
|
|
listener := func(args ...interface{}) {
|
|
authQuery, ok := args[0].(*rest_model.AuthQueryDetail)
|
|
|
|
if !ok {
|
|
pfxlog.Logger().Fatalf("could not convert args[0] to %T was %T", authQuery, args[0])
|
|
}
|
|
|
|
if authQuery == nil {
|
|
pfxlog.Logger().Fatalf("expected arg[0] was nil, unexpected")
|
|
}
|
|
|
|
responder, ok := args[1].(MfaCodeResponse)
|
|
|
|
if !ok {
|
|
pfxlog.Logger().Fatalf("could not convert args[1] to %T was %T", responder, args[1])
|
|
}
|
|
|
|
if responder == nil {
|
|
pfxlog.Logger().Fatalf("expected arg[0] was nil, unexpected")
|
|
}
|
|
|
|
handler(context, authQuery, responder)
|
|
}
|
|
|
|
context.AddListener(EventMfaTotpCode, listener)
|
|
|
|
return func() {
|
|
context.RemoveListener(EventMfaTotpCode, listener)
|
|
}
|
|
}
|
|
|
|
func (context *ContextImpl) AddAuthQueryListener(handler func(Context, *rest_model.AuthQueryDetail)) func() {
|
|
listener := func(args ...interface{}) {
|
|
authQuery, ok := args[0].(*rest_model.AuthQueryDetail)
|
|
|
|
if !ok {
|
|
pfxlog.Logger().Fatalf("could not convert args[0] to %T was %T", authQuery, args[0])
|
|
}
|
|
|
|
if authQuery == nil {
|
|
pfxlog.Logger().Fatalf("expected arg[0] was nil, unexpected")
|
|
}
|
|
|
|
handler(context, authQuery)
|
|
}
|
|
|
|
context.AddListener(EventAuthQuery, listener)
|
|
|
|
return func() {
|
|
context.RemoveListener(EventAuthQuery, listener)
|
|
}
|
|
}
|
|
|
|
func (context *ContextImpl) AddAuthenticationStatePartialListener(handler func(Context, apis.ApiSession)) func() {
|
|
listener := func(args ...interface{}) {
|
|
apiSession, ok := args[0].(apis.ApiSession)
|
|
|
|
if !ok {
|
|
pfxlog.Logger().Fatalf("could not convert args[0] to %T was %T", apiSession, args[0])
|
|
}
|
|
|
|
if apiSession == nil {
|
|
pfxlog.Logger().Fatalf("expected arg[0] was nil, unexpected")
|
|
}
|
|
|
|
handler(context, apiSession)
|
|
}
|
|
|
|
context.AddListener(EventAuthenticationStatePartial, listener)
|
|
|
|
return func() {
|
|
context.RemoveListener(EventAuthenticationStatePartial, listener)
|
|
}
|
|
}
|
|
|
|
func (context *ContextImpl) AddAuthenticationStateFullListener(handler func(Context, apis.ApiSession)) func() {
|
|
listener := func(args ...interface{}) {
|
|
apiSession, ok := args[0].(apis.ApiSession)
|
|
|
|
if !ok {
|
|
pfxlog.Logger().Fatalf("could not convert args[0] to %T was %T", apiSession, args[0])
|
|
}
|
|
|
|
if apiSession == nil {
|
|
pfxlog.Logger().Fatalf("expected arg[0] was nil, unexpected")
|
|
}
|
|
|
|
handler(context, apiSession)
|
|
}
|
|
|
|
context.AddListener(EventAuthenticationStateFull, listener)
|
|
|
|
return func() {
|
|
context.RemoveListener(EventAuthenticationStateFull, listener)
|
|
}
|
|
}
|
|
|
|
func (context *ContextImpl) AddAuthenticationStateUnauthenticatedListener(handler func(Context, apis.ApiSession)) func() {
|
|
listener := func(args ...interface{}) {
|
|
var apiSession apis.ApiSession
|
|
|
|
if args[0] != nil {
|
|
var ok bool
|
|
apiSession, ok = args[0].(apis.ApiSession)
|
|
|
|
if !ok {
|
|
pfxlog.Logger().Fatalf("could not convert args[0] to %T was %T", apiSession, args[0])
|
|
}
|
|
}
|
|
|
|
handler(context, apiSession)
|
|
}
|
|
|
|
context.AddListener(EventAuthenticationStateUnauthenticated, listener)
|
|
|
|
return func() {
|
|
context.RemoveListener(EventAuthenticationStateUnauthenticated, listener)
|
|
}
|
|
}
|
|
|
|
func (context *ContextImpl) AddControllerUrlsUpdateListener(handler func(Context, []*url.URL)) func() {
|
|
listener := func(args ...interface{}) {
|
|
var apiUrls []*url.URL
|
|
|
|
if args[0] != nil {
|
|
var ok bool
|
|
apiUrls, ok = args[0].([]*url.URL)
|
|
|
|
if !ok {
|
|
pfxlog.Logger().Fatalf("could not convert args[0] to %T was %T", apiUrls, args[0])
|
|
}
|
|
}
|
|
|
|
handler(context, apiUrls)
|
|
}
|
|
|
|
context.AddListener(EventControllerUrlsUpdated, listener)
|
|
|
|
return func() {
|
|
context.RemoveListener(EventAuthenticationStateUnauthenticated, listener)
|
|
}
|
|
}
|
|
|
|
func (context *ContextImpl) Events() Eventer {
|
|
return context
|
|
}
|
|
|
|
func (context *ContextImpl) GetId() string {
|
|
return context.Id
|
|
}
|
|
|
|
func (context *ContextImpl) SetId(id string) {
|
|
context.Id = id
|
|
}
|
|
|
|
func (context *ContextImpl) SetCredentials(credentials apis.Credentials) {
|
|
context.CtrlClt.Credentials = credentials
|
|
}
|
|
|
|
func (context *ContextImpl) GetCredentials() apis.Credentials {
|
|
return context.CtrlClt.Credentials
|
|
}
|
|
|
|
func (context *ContextImpl) Sessions() ([]*rest_model.SessionDetail, error) {
|
|
var sessions []*rest_model.SessionDetail
|
|
context.sessions.IterCb(func(key string, s *rest_model.SessionDetail) {
|
|
sessions = append(sessions, s)
|
|
})
|
|
return sessions, nil
|
|
}
|
|
|
|
func (context *ContextImpl) OnClose(routerConn edge.RouterConn) {
|
|
logrus.Debugf("connection to router [%s] was closed", routerConn.Key())
|
|
context.Emit(EventRouterDisconnected, routerConn.GetRouterName(), routerConn.Key())
|
|
context.routerConnections.Remove(routerConn.Key())
|
|
}
|
|
|
|
func (context *ContextImpl) processServiceUpdates(services []*rest_model.ServiceDetail) {
|
|
pfxlog.Logger().Debugf("processing service updates with %v services", len(services))
|
|
|
|
idMap := make(map[string]*rest_model.ServiceDetail)
|
|
for _, s := range services {
|
|
idMap[*s.ID] = s
|
|
}
|
|
|
|
// process Deletes
|
|
var deletes []string
|
|
context.services.IterCb(func(key string, svc *rest_model.ServiceDetail) {
|
|
if _, found := idMap[*svc.ID]; !found {
|
|
deletes = append(deletes, key)
|
|
if context.options.OnServiceUpdate != nil {
|
|
context.options.OnServiceUpdate(ServiceRemoved, svc)
|
|
}
|
|
context.Emit(EventServiceRemoved, svc)
|
|
|
|
context.deleteServiceSessions(*svc.ID)
|
|
|
|
}
|
|
})
|
|
|
|
for _, deletedKey := range deletes {
|
|
context.services.Remove(deletedKey)
|
|
context.intercepts.Remove(deletedKey)
|
|
}
|
|
|
|
// Adds and Updates
|
|
for _, s := range services {
|
|
context.processServiceAddOrUpdated(s)
|
|
}
|
|
|
|
context.refreshServiceQueryMap()
|
|
}
|
|
|
|
func (context *ContextImpl) processSingleServiceUpdate(name string, s *rest_model.ServiceDetail) {
|
|
// process Deletes
|
|
if s == nil {
|
|
var deletes []string
|
|
context.services.IterCb(func(key string, svc *rest_model.ServiceDetail) {
|
|
if *svc.Name == name {
|
|
deletes = append(deletes, key)
|
|
if context.options.OnServiceUpdate != nil {
|
|
context.options.OnServiceUpdate(ServiceRemoved, svc)
|
|
}
|
|
context.Emit(EventServiceRemoved, svc)
|
|
context.deleteServiceSessions(*svc.ID)
|
|
}
|
|
})
|
|
|
|
for _, deletedKey := range deletes {
|
|
context.services.Remove(deletedKey)
|
|
context.intercepts.Remove(deletedKey)
|
|
}
|
|
} else {
|
|
// Adds and Updates
|
|
context.processServiceAddOrUpdated(s)
|
|
}
|
|
|
|
context.refreshServiceQueryMap()
|
|
}
|
|
|
|
func (context *ContextImpl) processServiceAddOrUpdated(s *rest_model.ServiceDetail) {
|
|
isChange := false
|
|
valuesDiffer := false
|
|
|
|
_ = context.services.Upsert(*s.Name, s, func(exist bool, valueInMap *rest_model.ServiceDetail, newValue *rest_model.ServiceDetail) *rest_model.ServiceDetail {
|
|
isChange = exist
|
|
if isChange {
|
|
valuesDiffer = !reflect.DeepEqual(newValue, valueInMap)
|
|
}
|
|
|
|
return newValue
|
|
})
|
|
|
|
if isChange {
|
|
context.Emit(EventServiceChanged, s)
|
|
} else {
|
|
context.Emit(EventServiceAdded, s)
|
|
}
|
|
|
|
if context.options.OnServiceUpdate != nil {
|
|
if isChange {
|
|
if valuesDiffer {
|
|
context.options.OnServiceUpdate(ServiceChanged, s)
|
|
}
|
|
} else {
|
|
context.services.Set(*s.Name, s)
|
|
context.options.OnServiceUpdate(ServiceAdded, s)
|
|
}
|
|
}
|
|
|
|
intercept := &edge.InterceptV1Config{}
|
|
ok, err := edge.ParseServiceConfig(s, InterceptV1, intercept)
|
|
if err != nil {
|
|
pfxlog.Logger().Warnf("failed to parse config[%s] for service[%s]", InterceptV1, *s.Name)
|
|
} else if ok {
|
|
intercept.Service = s
|
|
context.intercepts.Set(*s.Name, intercept)
|
|
} else {
|
|
cltCfg := &edge.ClientConfig{}
|
|
ok, err := edge.ParseServiceConfig(s, ClientConfigV1, cltCfg)
|
|
if err == nil && ok {
|
|
intercept = cltCfg.ToInterceptV1Config()
|
|
intercept.Service = s
|
|
context.intercepts.Set(*s.Name, intercept)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (context *ContextImpl) refreshServiceQueryMap() {
|
|
serviceQueryMap := map[string]map[string]rest_model.PostureQuery{} //serviceId -> queryId -> query
|
|
|
|
context.services.IterCb(func(key string, svc *rest_model.ServiceDetail) {
|
|
for _, querySets := range svc.PostureQueries {
|
|
for _, query := range querySets.PostureQueries {
|
|
var queryMap map[string]rest_model.PostureQuery
|
|
var ok bool
|
|
if queryMap, ok = serviceQueryMap[*svc.ID]; !ok {
|
|
queryMap = map[string]rest_model.PostureQuery{}
|
|
serviceQueryMap[*svc.ID] = queryMap
|
|
}
|
|
queryMap[*query.ID] = *query
|
|
}
|
|
}
|
|
})
|
|
|
|
context.CtrlClt.PostureCache.SetServiceQueryMap(serviceQueryMap)
|
|
}
|
|
|
|
func (context *ContextImpl) refreshSessions() {
|
|
log := pfxlog.Logger()
|
|
edgeRouters := make(map[string]string)
|
|
var toDelete []string
|
|
for entry := range context.sessions.IterBuffered() {
|
|
key := entry.Key
|
|
session := entry.Val
|
|
log.Debugf("refreshing session for %s", key)
|
|
|
|
if s, err := context.refreshSession(session); err != nil {
|
|
log.WithError(err).Errorf("failed to refresh session for %s", key)
|
|
toDelete = append(toDelete, *session.ID)
|
|
} else {
|
|
for _, er := range s.EdgeRouters {
|
|
for _, u := range er.SupportedProtocols {
|
|
if context.options.isEdgeRouterUrlAccepted(u) {
|
|
edgeRouters[u] = *er.Name
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
for _, id := range toDelete {
|
|
context.sessions.Remove(id)
|
|
}
|
|
|
|
for u, name := range edgeRouters {
|
|
go context.handleConnectEdgeRouter(name, u, nil)
|
|
}
|
|
}
|
|
|
|
func (context *ContextImpl) RefreshServices() error {
|
|
return context.refreshServices(true)
|
|
}
|
|
|
|
func (context *ContextImpl) refreshServices(forceCheck bool) error {
|
|
if err := context.ensureApiSession(); err != nil {
|
|
return fmt.Errorf("failed to refresh services: %v", err)
|
|
}
|
|
|
|
var checkService bool
|
|
var lastServiceUpdate *strfmt.DateTime
|
|
var err error
|
|
|
|
log := pfxlog.Logger()
|
|
log.Debug("checking if service updates available")
|
|
if checkService, lastServiceUpdate, err = context.CtrlClt.IsServiceListUpdateAvailable(); err != nil {
|
|
log.WithError(err).Error("failed to check if service list update is available")
|
|
target := ¤t_api_session.ListServiceUpdatesUnauthorized{}
|
|
if errors.As(err, &target) {
|
|
checkService = true
|
|
} else {
|
|
if err = context.Authenticate(); err != nil {
|
|
log.WithError(err).Error("unable to re-authenticate during session refresh")
|
|
} else {
|
|
if checkService, lastServiceUpdate, err = context.CtrlClt.IsServiceListUpdateAvailable(); err != nil {
|
|
checkService = true
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if checkService || forceCheck {
|
|
log.Debug("refreshing services")
|
|
|
|
services, err := context.CtrlClt.GetServices()
|
|
if err != nil {
|
|
target := &service.ListServicesUnauthorized{}
|
|
if errors.As(err, &target) {
|
|
log.Info("attempting to re-authenticate")
|
|
if authErr := context.Authenticate(); authErr != nil {
|
|
log.WithError(authErr).Error("unable to re-authenticate during services refresh")
|
|
return err
|
|
}
|
|
if services, err = context.CtrlClt.GetServices(); err != nil {
|
|
return err
|
|
}
|
|
|
|
} else {
|
|
return err
|
|
}
|
|
}
|
|
context.CtrlClt.lastServiceUpdate = lastServiceUpdate
|
|
context.processServiceUpdates(services)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (context *ContextImpl) RefreshService(serviceName string) (*rest_model.ServiceDetail, error) {
|
|
if err := context.ensureApiSession(); err != nil {
|
|
return nil, fmt.Errorf("failed to refresh service: %v", err)
|
|
}
|
|
|
|
var err error
|
|
|
|
log := pfxlog.Logger().WithField("serviceName", serviceName)
|
|
|
|
log.Debug("refreshing service")
|
|
|
|
serviceDetail, err := context.CtrlClt.GetService(serviceName)
|
|
if err != nil {
|
|
target := &service.ListServicesUnauthorized{}
|
|
if errors.As(err, &target) {
|
|
log.Info("attempting to re-authenticate")
|
|
if authErr := context.Authenticate(); authErr != nil {
|
|
log.WithError(authErr).Error("unable to re-authenticate during service refresh")
|
|
return nil, err
|
|
}
|
|
if serviceDetail, err = context.CtrlClt.GetService(serviceName); err != nil {
|
|
return nil, err
|
|
}
|
|
} else {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
context.processSingleServiceUpdate(serviceName, serviceDetail)
|
|
|
|
return serviceDetail, nil
|
|
}
|
|
|
|
func (context *ContextImpl) updateTokenOnAllErs(apiSession apis.ApiSession) {
|
|
if apiSession.RequiresRouterTokenUpdate() {
|
|
for tpl := range context.routerConnections.IterBuffered() {
|
|
erConn := tpl.Val
|
|
erKey := tpl.Key
|
|
go func() {
|
|
if err := erConn.UpdateToken(apiSession.GetToken(), 10*time.Second); err != nil {
|
|
pfxlog.Logger().WithError(err).WithField("er", erKey).Warn("error updating apiSession token to connected ER")
|
|
}
|
|
}()
|
|
}
|
|
}
|
|
}
|
|
|
|
func (context *ContextImpl) runRefreshes() {
|
|
log := pfxlog.Logger()
|
|
svcRefreshInterval := context.options.RefreshInterval
|
|
|
|
if svcRefreshInterval == 0 {
|
|
svcRefreshInterval = DefaultServiceRefreshInterval
|
|
}
|
|
if svcRefreshInterval < MinRefreshInterval {
|
|
svcRefreshInterval = MinRefreshInterval
|
|
}
|
|
svcRefreshTick := time.NewTicker(svcRefreshInterval)
|
|
defer svcRefreshTick.Stop()
|
|
|
|
sessionRefreshInterval := context.options.SessionRefreshInterval
|
|
if sessionRefreshInterval == 0 {
|
|
sessionRefreshInterval = DefaultSessionRefreshInterval
|
|
}
|
|
if sessionRefreshInterval < MinRefreshInterval {
|
|
sessionRefreshInterval = MinRefreshInterval
|
|
}
|
|
|
|
sessionRefreshTick := time.NewTicker(sessionRefreshInterval)
|
|
defer sessionRefreshTick.Stop()
|
|
|
|
refreshAt := time.Now().Add(30 * time.Second)
|
|
|
|
if currentApiSession := context.CtrlClt.GetCurrentApiSession(); currentApiSession != nil && currentApiSession.GetExpiresAt() != nil {
|
|
refreshAt = (*currentApiSession.GetExpiresAt()).Add(-10 * time.Second)
|
|
}
|
|
|
|
for {
|
|
select {
|
|
case <-context.closeNotify:
|
|
return
|
|
|
|
case <-time.After(time.Until(refreshAt)):
|
|
apiSession := context.CtrlClt.GetCurrentApiSession()
|
|
|
|
if apiSession == nil {
|
|
pfxlog.Logger().Warn("could not refresh api session, current api session is nil, authenticating")
|
|
if err := context.Authenticate(); err != nil {
|
|
pfxlog.Logger().WithError(err).Error("failed to authenticate")
|
|
}
|
|
refreshAt = time.Now().Add(5 * time.Second)
|
|
continue
|
|
}
|
|
|
|
newApiSession, err := context.CtrlClt.Refresh()
|
|
|
|
if err != nil {
|
|
log.Errorf("could not refresh apiSession: %v", err)
|
|
|
|
refreshAt = time.Now().Add(5 * time.Second)
|
|
} else {
|
|
exp := newApiSession.GetExpiresAt()
|
|
refreshAt = exp.Add(-10 * time.Second)
|
|
log.Debugf("apiSession refreshed, new expiration[%s]", *exp)
|
|
|
|
context.updateTokenOnAllErs(newApiSession)
|
|
}
|
|
|
|
case <-svcRefreshTick.C:
|
|
log.Debug("refreshing services")
|
|
if err := context.refreshServices(false); err != nil {
|
|
log.WithError(err).Error("failed to load service updates")
|
|
}
|
|
|
|
case <-sessionRefreshTick.C:
|
|
log.Debug("refreshing sessions")
|
|
context.refreshSessions()
|
|
}
|
|
}
|
|
}
|
|
|
|
func (context *ContextImpl) EnsureAuthenticated(options edge.ConnOptions) error {
|
|
operation := func() error {
|
|
pfxlog.Logger().Info("attempting to establish new api session")
|
|
err := context.Authenticate()
|
|
if err != nil {
|
|
return backoff.Permanent(err)
|
|
}
|
|
|
|
return err
|
|
}
|
|
expBackoff := backoff.NewExponentialBackOff()
|
|
expBackoff.MaxInterval = 10 * time.Second
|
|
expBackoff.MaxElapsedTime = options.GetConnectTimeout()
|
|
|
|
return backoff.Retry(operation, expBackoff)
|
|
}
|
|
|
|
func (context *ContextImpl) GetCurrentIdentity() (*rest_model.IdentityDetail, error) {
|
|
if err := context.ensureApiSession(); err != nil {
|
|
return nil, errors.Wrap(err, "failed to establish api session")
|
|
}
|
|
|
|
return context.CtrlClt.GetCurrentIdentity()
|
|
}
|
|
|
|
func (context *ContextImpl) GetCurrentIdentityWithBackoff() (*rest_model.IdentityDetail, error) {
|
|
expBackoff := backoff.NewExponentialBackOff()
|
|
|
|
expBackoff.InitialInterval = time.Second
|
|
expBackoff.MaxInterval = 30 * time.Second
|
|
expBackoff.MaxElapsedTime = 5 * time.Minute
|
|
|
|
var detail *rest_model.IdentityDetail
|
|
operation := func() error {
|
|
var err error
|
|
detail, err = context.GetCurrentIdentity()
|
|
return err
|
|
}
|
|
|
|
if err := backoff.Retry(operation, expBackoff); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return detail, nil
|
|
}
|
|
|
|
func (context *ContextImpl) setUnauthenticated() {
|
|
prevApiSessionPtr := context.CtrlClt.ApiSession.Swap(nil)
|
|
willEmit := prevApiSessionPtr != nil
|
|
|
|
context.CtrlClt.ApiSessionCertificate = nil
|
|
|
|
context.CloseAllEdgeRouterConns()
|
|
context.sessions.Clear()
|
|
|
|
if willEmit {
|
|
context.Emit(EventAuthenticationStateUnauthenticated, *prevApiSessionPtr)
|
|
}
|
|
}
|
|
|
|
func (context *ContextImpl) authenticate() error {
|
|
logrus.Debug("attempting to authenticate")
|
|
context.services = cmap.New[*rest_model.ServiceDetail]()
|
|
context.sessions = cmap.New[*rest_model.SessionDetail]()
|
|
context.intercepts = cmap.New[*edge.InterceptV1Config]()
|
|
|
|
context.setUnauthenticated()
|
|
|
|
apiSession, err := context.CtrlClt.Authenticate()
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
authQueries := apiSession.GetAuthQueries()
|
|
if len(authQueries) != 0 {
|
|
context.Emit(EventAuthenticationStatePartial, apiSession)
|
|
for _, authQuery := range apiSession.GetAuthQueries() {
|
|
if err := context.handleAuthQuery(authQuery); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
return context.onFullAuth(apiSession)
|
|
}
|
|
|
|
func (context *ContextImpl) Reauthenticate() error {
|
|
context.CtrlClt.ApiSession.Store(nil)
|
|
context.CtrlClt.ApiSessionCertificate = nil
|
|
|
|
return context.authenticate()
|
|
}
|
|
|
|
func (context *ContextImpl) Authenticate() error {
|
|
if context.CtrlClt.GetCurrentApiSession() != nil {
|
|
if time.Since(context.lastSuccessfulApiSessionRefresh) < 5*time.Second {
|
|
return nil
|
|
}
|
|
logrus.Debug("previous apiSession detected, checking if valid")
|
|
if err := context.RefreshApiSessionWithBackoff(); err == nil {
|
|
logrus.Info("previous apiSession refreshed")
|
|
context.lastSuccessfulApiSessionRefresh = time.Now()
|
|
return nil
|
|
} else {
|
|
logrus.WithError(err).Info("previous apiSession failed to refresh, attempting to authenticate")
|
|
}
|
|
}
|
|
|
|
return context.authenticate()
|
|
}
|
|
|
|
func (context *ContextImpl) RefreshApiSessionWithBackoff() error {
|
|
expBackoff := backoff.NewExponentialBackOff()
|
|
|
|
expBackoff.InitialInterval = 5 * time.Second
|
|
expBackoff.MaxInterval = 5 * time.Minute
|
|
expBackoff.MaxElapsedTime = 24 * time.Hour
|
|
|
|
operation := func() error {
|
|
newApiSession, err := context.CtrlClt.Refresh()
|
|
if err == nil {
|
|
context.updateTokenOnAllErs(newApiSession)
|
|
return nil
|
|
}
|
|
|
|
unauthorizedErr := ¤t_api_session.GetCurrentAPISessionUnauthorized{}
|
|
if errors.As(err, &unauthorizedErr) {
|
|
logrus.Info("previous apiSession expired")
|
|
return backoff.Permanent(err)
|
|
}
|
|
logrus.WithError(err).Info("unable to refresh apiSession, will retry")
|
|
return err
|
|
}
|
|
|
|
return backoff.Retry(operation, expBackoff)
|
|
}
|
|
|
|
func (context *ContextImpl) CloseAllEdgeRouterConns() {
|
|
for entry := range context.routerConnections.IterBuffered() {
|
|
key, val := entry.Key, entry.Val
|
|
if !val.IsClosed() {
|
|
if err := val.Close(); err != nil {
|
|
pfxlog.Logger().WithError(err).Error("error while closing edge router connection")
|
|
}
|
|
}
|
|
|
|
context.routerConnections.Remove(key)
|
|
}
|
|
}
|
|
|
|
func (context *ContextImpl) onFullAuth(apiSession apis.ApiSession) error {
|
|
var doOnceErr error
|
|
context.firstAuthOnce.Do(func() {
|
|
if context.options.OnContextReady != nil {
|
|
context.options.OnContextReady(context)
|
|
}
|
|
go context.runRefreshes()
|
|
|
|
metricsTags := map[string]string{
|
|
"srcId": apiSession.GetIdentityId(),
|
|
}
|
|
|
|
context.metrics = metrics.NewRegistry(apiSession.GetIdentityName(), metricsTags)
|
|
})
|
|
|
|
context.Emit(EventAuthenticationStateFull, apiSession)
|
|
|
|
// get services
|
|
if err := context.RefreshServices(); err != nil {
|
|
doOnceErr = err
|
|
}
|
|
|
|
return doOnceErr
|
|
}
|
|
|
|
func (context *ContextImpl) AddZitiMfaHandler(handler func(query *rest_model.AuthQueryDetail, response MfaCodeResponse) error) {
|
|
context.authQueryHandlers[string(rest_model.MfaProvidersZiti)] = handler
|
|
}
|
|
|
|
func (context *ContextImpl) authenticateMfa(code string) error {
|
|
if err := context.CtrlClt.AuthenticateMFA(code); err != nil {
|
|
return err
|
|
}
|
|
|
|
newApiSession, err := context.CtrlClt.Refresh()
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
context.updateTokenOnAllErs(newApiSession)
|
|
|
|
apiSession := context.CtrlClt.GetCurrentApiSession()
|
|
|
|
if apiSession != nil && len(apiSession.GetAuthQueries()) == 0 {
|
|
return context.onFullAuth(apiSession)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (context *ContextImpl) handleAuthQuery(authQuery *rest_model.AuthQueryDetail) error {
|
|
context.Emit(EventAuthQuery, authQuery)
|
|
|
|
if authQuery.Provider == nil {
|
|
return fmt.Errorf("unhandled response from controller: authentication query has no provider specified")
|
|
}
|
|
|
|
if *authQuery.Provider == rest_model.MfaProvidersZiti {
|
|
handler := context.authQueryHandlers[string(rest_model.MfaProvidersZiti)]
|
|
|
|
context.Emit(EventMfaTotpCode, authQuery, MfaCodeResponse(context.authenticateMfa))
|
|
|
|
if handler == nil {
|
|
pfxlog.Logger().Debugf("no callback handler registered for provider: %v, event will still be emitted", *authQuery.Provider)
|
|
} else {
|
|
return handler(authQuery, context.authenticateMfa)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
return fmt.Errorf("unsupported MFA provider: %v", *authQuery.Provider)
|
|
}
|
|
|
|
func (context *ContextImpl) Dial(serviceName string) (edge.Conn, error) {
|
|
defaultOptions := &DialOptions{ConnectTimeout: 5 * time.Second}
|
|
return context.DialWithOptions(serviceName, defaultOptions)
|
|
}
|
|
|
|
func (context *ContextImpl) DialWithOptions(serviceName string, options *DialOptions) (edge.Conn, error) {
|
|
edgeDialOptions := &edge.DialOptions{
|
|
ConnectTimeout: options.ConnectTimeout,
|
|
Identity: options.Identity,
|
|
AppData: options.AppData,
|
|
StickinessToken: options.StickinessToken,
|
|
}
|
|
if edgeDialOptions.GetConnectTimeout() == 0 {
|
|
edgeDialOptions.ConnectTimeout = 15 * time.Second
|
|
}
|
|
|
|
if err := context.ensureApiSession(); err != nil {
|
|
return nil, fmt.Errorf("failed to dial: %v", err)
|
|
}
|
|
|
|
svc, ok := context.GetService(serviceName)
|
|
if !ok {
|
|
return nil, errors.Errorf("service '%s' not found", serviceName)
|
|
}
|
|
|
|
context.CtrlClt.PostureCache.AddActiveService(*svc.ID)
|
|
|
|
edgeDialOptions.CallerId = context.CtrlClt.GetCurrentApiSession().GetIdentityName()
|
|
|
|
session, err := context.GetSession(*svc.ID)
|
|
if err != nil {
|
|
context.deleteServiceSessions(*svc.ID)
|
|
if session, err = context.createSessionWithBackoff(svc, SessionType(SessionDial), options); err != nil {
|
|
return nil, errors.Wrapf(err, "unable to dial service '%v'", serviceName)
|
|
}
|
|
}
|
|
|
|
pfxlog.Logger().WithField("sessionId", *session.ID).WithField("sessionToken", session.Token).Debug("connecting with session")
|
|
conn, err := context.dialSession(svc, session, edgeDialOptions)
|
|
if err == nil {
|
|
return conn, nil
|
|
}
|
|
|
|
var refreshErr error
|
|
if _, refreshErr = context.refreshSession(session); refreshErr == nil {
|
|
// if the session wasn't expired, no reason to try again, return the failure
|
|
return nil, errors.Wrapf(err, "unable to dial service '%s'", serviceName)
|
|
}
|
|
|
|
context.deleteServiceSessions(*svc.ID)
|
|
if session, refreshErr = context.createSessionWithBackoff(svc, SessionType(SessionDial), options); refreshErr != nil {
|
|
// couldn't create a new session, report the error
|
|
return nil, errors.Wrapf(refreshErr, "unable to dial service '%s'", serviceName)
|
|
}
|
|
|
|
// retry with new session
|
|
conn, err = context.dialSession(svc, session, edgeDialOptions)
|
|
if err == nil {
|
|
return conn, nil
|
|
}
|
|
|
|
return nil, errors.Wrapf(err, "unable to dial service '%s'", serviceName)
|
|
}
|
|
|
|
// GetServiceForAddr finds the service with intercept that matches best to given address
|
|
func (context *ContextImpl) GetServiceForAddr(network, hostname string, port uint16) (*rest_model.ServiceDetail, int, error) {
|
|
var svc *rest_model.ServiceDetail
|
|
score := math.MaxInt
|
|
lowestFound := false
|
|
context.intercepts.IterCb(func(key string, intercept *edge.InterceptV1Config) {
|
|
if lowestFound {
|
|
return
|
|
}
|
|
|
|
sc := intercept.Match(network, hostname, port)
|
|
if sc != -1 {
|
|
if score > sc {
|
|
score = sc
|
|
svc = intercept.Service
|
|
} else if score == sc && *intercept.Service.Name < *svc.Name { // if score is the same, pick alphabetically first service
|
|
score = sc
|
|
svc = intercept.Service
|
|
}
|
|
|
|
if sc == 0 {
|
|
lowestFound = true
|
|
}
|
|
}
|
|
})
|
|
|
|
if svc == nil {
|
|
return nil, -1, errors.Errorf("no service for address[%s:%s:%d]", network, hostname, port)
|
|
}
|
|
|
|
return svc, score, nil
|
|
}
|
|
|
|
func (context *ContextImpl) dialServiceFromAddr(service, network, host string, port uint16) (edge.Conn, error) {
|
|
appdata := make(map[string]any)
|
|
appdata["dst_protocol"] = network
|
|
appdata["dst_port"] = strconv.Itoa(int(port))
|
|
ip := net.ParseIP(host)
|
|
if len(ip) != 0 {
|
|
appdata["dst_ip"] = host
|
|
} else {
|
|
appdata["dst_hostname"] = host
|
|
}
|
|
|
|
options := &DialOptions{
|
|
ConnectTimeout: 5 * time.Second,
|
|
}
|
|
appdataJson, _ := json.Marshal(appdata)
|
|
options.AppData = appdataJson
|
|
|
|
return context.DialWithOptions(service, options)
|
|
}
|
|
|
|
func (context *ContextImpl) DialAddr(network string, addr string) (edge.Conn, error) {
|
|
host, portStr, err := net.SplitHostPort(addr)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
port, err := strconv.Atoi(portStr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
network = normalizeProtocol(network)
|
|
|
|
svc, _, err := context.GetServiceForAddr(network, host, uint16(port))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return context.dialServiceFromAddr(*svc.Name, network, host, uint16(port))
|
|
}
|
|
|
|
func (context *ContextImpl) dialSession(service *rest_model.ServiceDetail, session *rest_model.SessionDetail, options *edge.DialOptions) (edge.Conn, error) {
|
|
edgeConnFactory, err := context.getEdgeRouterConn(session, options)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return edgeConnFactory.Connect(service, session, options)
|
|
}
|
|
|
|
func (context *ContextImpl) ensureApiSession() error {
|
|
if context.CtrlClt.GetCurrentApiSession() == nil {
|
|
if err := context.Authenticate(); err != nil {
|
|
return fmt.Errorf("no apiSession, authentication attempt failed: %v", err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (context *ContextImpl) Listen(serviceName string) (edge.Listener, error) {
|
|
return context.ListenWithOptions(serviceName, DefaultListenOptions())
|
|
}
|
|
|
|
func (context *ContextImpl) ListenWithOptions(serviceName string, options *ListenOptions) (edge.Listener, error) {
|
|
if err := context.ensureApiSession(); err != nil {
|
|
return nil, fmt.Errorf("failed to listen: %v", err)
|
|
}
|
|
|
|
if s, ok := context.GetService(serviceName); ok {
|
|
return context.listenSession(s, options)
|
|
}
|
|
return nil, errors.Errorf("service '%s' not found in ziti network", serviceName)
|
|
}
|
|
|
|
func (context *ContextImpl) listenSession(service *rest_model.ServiceDetail, options *ListenOptions) (edge.Listener, error) {
|
|
edgeListenOptions := edge.NewListenOptions()
|
|
edgeListenOptions.Cost = options.Cost
|
|
edgeListenOptions.Precedence = edge.Precedence(options.Precedence)
|
|
edgeListenOptions.ConnectTimeout = options.ConnectTimeout
|
|
if options.MaxTerminators != 0 {
|
|
edgeListenOptions.MaxTerminators = options.MaxTerminators
|
|
} else {
|
|
edgeListenOptions.MaxTerminators = options.MaxConnections
|
|
}
|
|
edgeListenOptions.Identity = options.Identity
|
|
edgeListenOptions.BindUsingEdgeIdentity = options.BindUsingEdgeIdentity
|
|
edgeListenOptions.ManualStart = options.ManualStart
|
|
|
|
if edgeListenOptions.ConnectTimeout == 0 {
|
|
edgeListenOptions.ConnectTimeout = time.Minute
|
|
}
|
|
|
|
if edgeListenOptions.MaxTerminators < 1 {
|
|
edgeListenOptions.MaxTerminators = 1
|
|
}
|
|
|
|
if listenerMgr, err := newListenerManager(service, context, edgeListenOptions, options.WaitForNEstablishedListeners); err != nil {
|
|
return nil, err
|
|
} else {
|
|
return listenerMgr.listener, nil
|
|
}
|
|
}
|
|
|
|
func (context *ContextImpl) getEdgeRouterConn(session *rest_model.SessionDetail, options edge.ConnOptions) (edge.RouterConn, error) {
|
|
logger := pfxlog.Logger().WithField("sessionId", *session.ID)
|
|
|
|
if len(session.EdgeRouters) == 0 {
|
|
if refreshedSession, err := context.refreshSession(session); err != nil {
|
|
target := &rest_session.DetailSessionNotFound{}
|
|
if errors.As(err, &target) {
|
|
sessionKey := fmt.Sprintf("%s:%s", session.Service.ID, *session.Type)
|
|
context.sessions.Remove(sessionKey)
|
|
}
|
|
|
|
return nil, fmt.Errorf("no edge routers available, refresh errored: %v", err)
|
|
} else {
|
|
if len(refreshedSession.EdgeRouters) == 0 {
|
|
return nil, errors.New("no edge routers available, refresh yielded no new edge routers")
|
|
}
|
|
|
|
session = refreshedSession
|
|
}
|
|
}
|
|
|
|
// go through connected routers first
|
|
bestLatency := time.Duration(math.MaxInt64)
|
|
var bestER edge.RouterConn
|
|
var unconnected []*rest_model.SessionEdgeRouter
|
|
for _, edgeRouter := range session.EdgeRouters {
|
|
for proto, addr := range edgeRouter.SupportedProtocols {
|
|
addr = strings.Replace(addr, "://", ":", 1)
|
|
edgeRouter.SupportedProtocols[proto] = addr
|
|
if er, found := context.routerConnections.Get(addr); found {
|
|
h := context.metrics.Histogram("latency." + addr).(metrics2.Histogram)
|
|
if h.Mean() < float64(bestLatency) {
|
|
bestLatency = time.Duration(int64(h.Mean()))
|
|
bestER = er
|
|
}
|
|
} else {
|
|
unconnected = append(unconnected, edgeRouter)
|
|
}
|
|
}
|
|
}
|
|
|
|
var ch chan *edgeRouterConnResult
|
|
if bestER == nil {
|
|
ch = make(chan *edgeRouterConnResult, len(unconnected))
|
|
}
|
|
|
|
for _, edgeRouter := range unconnected {
|
|
for _, addr := range edgeRouter.SupportedProtocols {
|
|
if context.options.isEdgeRouterUrlAccepted(addr) {
|
|
go context.handleConnectEdgeRouter(*edgeRouter.Name, addr, ch)
|
|
}
|
|
}
|
|
}
|
|
|
|
if bestER != nil {
|
|
logger.Debugf("selected router[%s@%s] for best latency(%d ms)",
|
|
bestER.GetRouterName(), bestER.Key(), bestLatency.Milliseconds())
|
|
return bestER, nil
|
|
}
|
|
|
|
timeout := time.After(options.GetConnectTimeout())
|
|
for {
|
|
select {
|
|
case f := <-ch:
|
|
if f.routerConnection != nil {
|
|
logger.Debugf("using edgeRouter[%s]", f.routerConnection.Key())
|
|
return f.routerConnection, nil
|
|
}
|
|
case <-timeout:
|
|
return nil, errors.New("no edge routers connected in time")
|
|
}
|
|
}
|
|
}
|
|
|
|
func (context *ContextImpl) handleConnectEdgeRouter(routerName, ingressUrl string, ret chan *edgeRouterConnResult) {
|
|
result := context.connectEdgeRouter(routerName, ingressUrl)
|
|
|
|
if ret != nil {
|
|
select {
|
|
case ret <- result:
|
|
case <-time.After(10 * time.Second):
|
|
}
|
|
}
|
|
}
|
|
|
|
func (context *ContextImpl) connectEdgeRouter(routerName, ingressUrl string) *edgeRouterConnResult {
|
|
logger := pfxlog.Logger().WithField("router", routerName)
|
|
|
|
if conn, found := context.routerConnections.Get(ingressUrl); found {
|
|
if !conn.IsClosed() {
|
|
return &edgeRouterConnResult{
|
|
routerUrl: ingressUrl,
|
|
routerName: routerName,
|
|
routerConnection: conn,
|
|
}
|
|
} else {
|
|
context.routerConnections.Remove(ingressUrl)
|
|
}
|
|
}
|
|
|
|
ingAddr, err := transport.ParseAddress(ingressUrl)
|
|
if err != nil {
|
|
logger.WithError(err).Errorf("failed to parse url[%s]", ingressUrl)
|
|
return &edgeRouterConnResult{
|
|
routerUrl: ingressUrl,
|
|
routerName: routerName,
|
|
err: err,
|
|
}
|
|
}
|
|
|
|
currentApiSession := context.CtrlClt.GetCurrentApiSession()
|
|
if currentApiSession == nil {
|
|
return &edgeRouterConnResult{
|
|
routerUrl: ingressUrl,
|
|
routerName: routerName,
|
|
err: errors.New("not authenticated to controller"),
|
|
}
|
|
}
|
|
|
|
logger.Debugf("connection to edge router using api session token %s", string(currentApiSession.GetToken()))
|
|
id, err := context.CtrlClt.GetIdentity()
|
|
|
|
if err != nil {
|
|
return &edgeRouterConnResult{
|
|
routerUrl: ingressUrl,
|
|
routerName: routerName,
|
|
err: err,
|
|
}
|
|
}
|
|
|
|
dialerConfig := channel.DialerConfig{
|
|
Identity: identity.NewIdentity(id),
|
|
Endpoint: ingAddr,
|
|
Headers: map[int32][]byte{
|
|
edge.SessionTokenHeader: context.CtrlClt.GetCurrentApiSession().GetToken(),
|
|
},
|
|
TransportConfig: map[interface{}]interface{}{},
|
|
}
|
|
|
|
if context.routerProxy != nil {
|
|
if proxyConfig := context.routerProxy(ingressUrl); proxyConfig != nil {
|
|
dialerConfig.TransportConfig[transport.KeyCachedProxyConfiguration] = proxyConfig
|
|
}
|
|
}
|
|
|
|
dialer := channel.NewClassicDialer(dialerConfig)
|
|
|
|
start := time.Now().UnixNano()
|
|
edgeConn := network.NewEdgeConnFactory(routerName, ingressUrl, context)
|
|
options := channel.DefaultOptions()
|
|
options.ConnectTimeout = 15 * time.Second
|
|
ch, err := channel.NewChannel(fmt.Sprintf("ziti-sdk[router=%v]", ingressUrl), dialer, edgeConn, options)
|
|
if err != nil {
|
|
logger.Error(err)
|
|
return &edgeRouterConnResult{
|
|
routerUrl: ingressUrl,
|
|
routerName: routerName,
|
|
err: err,
|
|
}
|
|
}
|
|
connectTime := time.Duration(time.Now().UnixNano() - start)
|
|
logger.Debugf("routerConn[%s@%s] connected in %d ms", routerName, ingressUrl, connectTime.Milliseconds())
|
|
|
|
if versionHeader, found := ch.Underlay().Headers()[channel.HelloVersionHeader]; found {
|
|
versionInfo, err := versions.StdVersionEncDec.Decode(versionHeader)
|
|
if err != nil {
|
|
pfxlog.Logger().Errorf("could not parse hello version header: %v", err)
|
|
} else {
|
|
pfxlog.Logger().
|
|
WithField("os", versionInfo.OS).
|
|
WithField("arch", versionInfo.Arch).
|
|
WithField("version", versionInfo.Version).
|
|
WithField("revision", versionInfo.Revision).
|
|
WithField("buildDate", versionInfo.BuildDate).
|
|
Debug("connected to edge router")
|
|
}
|
|
}
|
|
|
|
logger.Debugf("connected to %s", ingressUrl)
|
|
|
|
context.Emit(EventRouterConnected, edgeConn.GetRouterName(), edgeConn.Key())
|
|
|
|
useConn := context.routerConnections.Upsert(ingressUrl, edgeConn,
|
|
func(exist bool, oldV edge.RouterConn, newV edge.RouterConn) edge.RouterConn {
|
|
if exist { // use the routerConnection already in the map, close new one
|
|
pfxlog.Logger().Infof("connection to %s already established, closing duplicate connection", ingressUrl)
|
|
go func() {
|
|
if err := newV.Close(); err != nil {
|
|
pfxlog.Logger().Errorf("unable to close router connection (%v)", err)
|
|
}
|
|
}()
|
|
return oldV
|
|
}
|
|
h := context.metrics.Histogram("latency." + ingressUrl)
|
|
h.Update(int64(connectTime))
|
|
|
|
latencyProbeConfig := &latency.ProbeConfig{
|
|
Channel: ch,
|
|
Interval: LatencyCheckInterval,
|
|
Timeout: LatencyCheckTimeout,
|
|
ResultHandler: func(resultNanos int64) {
|
|
h.Update(resultNanos)
|
|
},
|
|
TimeoutHandler: func() {
|
|
logrus.Errorf("latency timeout after [%s]", LatencyCheckTimeout)
|
|
if ch.GetTimeSinceLastRead() > LatencyCheckInterval {
|
|
// No traffic on channel, no response. Close the channel
|
|
logrus.Error("no read traffic on channel since before latency probe was sent, closing channel")
|
|
_ = ch.Close()
|
|
}
|
|
},
|
|
ExitHandler: func() {
|
|
h.Dispose()
|
|
},
|
|
}
|
|
|
|
go latency.ProbeLatencyConfigurable(latencyProbeConfig)
|
|
return newV
|
|
})
|
|
|
|
return &edgeRouterConnResult{
|
|
routerUrl: ingressUrl,
|
|
routerName: routerName,
|
|
routerConnection: useConn,
|
|
}
|
|
}
|
|
|
|
func (context *ContextImpl) GetServiceId(name string) (string, bool, error) {
|
|
if err := context.ensureApiSession(); err != nil {
|
|
return "", false, fmt.Errorf("failed to get service id: %v", err)
|
|
}
|
|
|
|
if svc, found := context.GetService(name); found {
|
|
return *svc.ID, true, nil
|
|
}
|
|
|
|
return "", false, nil
|
|
}
|
|
|
|
func (context *ContextImpl) GetService(name string) (*rest_model.ServiceDetail, bool) {
|
|
if err := context.ensureApiSession(); err != nil {
|
|
pfxlog.Logger().Warnf("failed to get service: %v", err)
|
|
return nil, false
|
|
}
|
|
|
|
if svc, found := context.services.Get(name); !found {
|
|
return nil, false
|
|
} else {
|
|
return svc, true
|
|
}
|
|
}
|
|
|
|
func (context *ContextImpl) GetServices() ([]rest_model.ServiceDetail, error) {
|
|
if err := context.ensureApiSession(); err != nil {
|
|
return nil, fmt.Errorf("failed to get services: %v", err)
|
|
}
|
|
|
|
var res []rest_model.ServiceDetail
|
|
context.services.IterCb(func(key string, svc *rest_model.ServiceDetail) {
|
|
res = append(res, *svc)
|
|
})
|
|
|
|
return res, nil
|
|
}
|
|
|
|
func (context *ContextImpl) GetServiceTerminators(serviceName string, offset, limit int) ([]*rest_model.TerminatorClientDetail, int, error) {
|
|
svc, found := context.GetService(serviceName)
|
|
if !found {
|
|
return nil, 0, errors.Errorf("did not find service named %v", serviceName)
|
|
}
|
|
return context.CtrlClt.GetServiceTerminators(svc, offset, limit)
|
|
}
|
|
|
|
func (context *ContextImpl) GetSession(serviceId string) (*rest_model.SessionDetail, error) {
|
|
return context.getOrCreateSession(serviceId, SessionType(SessionDial))
|
|
}
|
|
|
|
func (context *ContextImpl) getOrCreateSession(serviceId string, sessionType SessionType) (*rest_model.SessionDetail, error) {
|
|
sessionKey := fmt.Sprintf("%s:%s", serviceId, sessionType)
|
|
|
|
cache := string(sessionType) == string(SessionDial)
|
|
|
|
// Can't cache Bind sessions, as we use session tokens for routing. If there are multiple binds on a single
|
|
// session routing information will get overwritten
|
|
if cache {
|
|
session, ok := context.sessions.Get(sessionKey)
|
|
if ok {
|
|
return session, nil
|
|
}
|
|
}
|
|
|
|
context.CtrlClt.PostureCache.AddActiveService(serviceId)
|
|
session, err := context.CtrlClt.CreateSession(serviceId, sessionType)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
context.cacheSession("create", session)
|
|
return session, nil
|
|
}
|
|
|
|
func (context *ContextImpl) createSessionWithBackoff(service *rest_model.ServiceDetail, sessionType SessionType, options edge.ConnOptions) (*rest_model.SessionDetail, error) {
|
|
expBackoff := backoff.NewExponentialBackOff()
|
|
|
|
if sessionType == SessionType(rest_model.DialBindDial) {
|
|
expBackoff.InitialInterval = 50 * time.Millisecond
|
|
} else {
|
|
expBackoff.InitialInterval = time.Second
|
|
}
|
|
|
|
expBackoff.MaxInterval = 10 * time.Second
|
|
expBackoff.MaxElapsedTime = options.GetConnectTimeout()
|
|
|
|
var session *rest_model.SessionDetail
|
|
operation := func() error {
|
|
latestSvc, _ := context.services.Get(*service.Name)
|
|
if latestSvc != nil && *latestSvc.ID != *service.ID {
|
|
pfxlog.Logger().
|
|
WithField("serviceName", *service.Name).
|
|
WithField("oldServiceId", *service.ID).
|
|
WithField("newServiceId", *latestSvc.ID).
|
|
Info("service id changed, service was recreated")
|
|
service = latestSvc
|
|
}
|
|
|
|
s, err := context.createSession(service, sessionType)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
session = s
|
|
return nil
|
|
}
|
|
|
|
if session != nil {
|
|
context.CtrlClt.PostureCache.AddActiveService(*service.ID)
|
|
context.cacheSession("create", session)
|
|
}
|
|
|
|
return session, backoff.Retry(operation, expBackoff)
|
|
}
|
|
|
|
func (context *ContextImpl) createSession(service *rest_model.ServiceDetail, sessionType SessionType) (*rest_model.SessionDetail, error) {
|
|
start := time.Now()
|
|
logger := pfxlog.Logger()
|
|
logger.Debugf("establishing %s session to service %s", sessionType, *service.Name)
|
|
session, err := context.getOrCreateSession(*service.ID, sessionType)
|
|
if err != nil {
|
|
logger.WithError(err).WithField("errorType", fmt.Sprintf("%T", err)).Warnf("failure creating %s session to service %s", sessionType, *service.Name)
|
|
var target error = &rest_session.CreateSessionUnauthorized{}
|
|
if errors.As(err, &target) {
|
|
if err := context.Authenticate(); err != nil {
|
|
target = &authentication.AuthenticateUnauthorized{}
|
|
if errors.As(err, &target) {
|
|
return nil, backoff.Permanent(err)
|
|
}
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
target = &rest_session.CreateSessionNotFound{}
|
|
if errors.As(err, &target) {
|
|
if refreshErr := context.refreshServices(false); refreshErr != nil {
|
|
logger.WithError(refreshErr).Info("failed to refresh services after create session returned 404 (likely for service)")
|
|
}
|
|
}
|
|
|
|
return nil, err
|
|
}
|
|
elapsed := time.Since(start)
|
|
logger.Debugf("successfully created %s session to service %s in %vms", sessionType, *service.Name, elapsed.Milliseconds())
|
|
return session, nil
|
|
}
|
|
|
|
func (context *ContextImpl) refreshSession(session *rest_model.SessionDetail) (*rest_model.SessionDetail, error) {
|
|
var refreshedSession *rest_model.SessionDetail
|
|
var err error
|
|
if strings.HasPrefix(*session.Token, apis.JwtTokenPrefix) {
|
|
refreshedSession, err = context.CtrlClt.GetSessionFromJwt(*session.Token)
|
|
} else {
|
|
refreshedSession, err = context.CtrlClt.GetSession(*session.ID)
|
|
}
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
context.cacheSession("refresh", refreshedSession)
|
|
return refreshedSession, nil
|
|
}
|
|
|
|
func (context *ContextImpl) cacheSession(op string, session *rest_model.SessionDetail) {
|
|
sessionKey := fmt.Sprintf("%s:%s", *session.ServiceID, *session.Type)
|
|
|
|
if *session.Type == SessionDial {
|
|
if op == "create" {
|
|
context.sessions.Set(sessionKey, session)
|
|
} else if op == "refresh" {
|
|
// N.B.: refreshed sessions do not contain token so update stored session object with updated edgeRouters
|
|
isUpdate := false
|
|
val := context.sessions.Upsert(sessionKey, session, func(exist bool, valueInMap *rest_model.SessionDetail, newValue *rest_model.SessionDetail) *rest_model.SessionDetail {
|
|
isUpdate = exist
|
|
return newValue
|
|
})
|
|
if isUpdate {
|
|
existingSession := val
|
|
existingSession.EdgeRouters = session.EdgeRouters
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (context *ContextImpl) deleteServiceSessions(svcId string) {
|
|
context.sessions.Remove(fmt.Sprintf("%s:%s", svcId, SessionBind))
|
|
context.sessions.Remove(fmt.Sprintf("%s:%s", svcId, SessionDial))
|
|
}
|
|
|
|
func (context *ContextImpl) Close() {
|
|
if context.closed.CompareAndSwap(false, true) {
|
|
close(context.closeNotify)
|
|
|
|
context.CloseAllEdgeRouterConns()
|
|
|
|
}
|
|
}
|
|
|
|
func (context *ContextImpl) Metrics() metrics.Registry {
|
|
return context.metrics
|
|
}
|
|
|
|
func (context *ContextImpl) EnrollZitiMfa() (*rest_model.DetailMfa, error) {
|
|
return context.CtrlClt.EnrollMfa()
|
|
}
|
|
|
|
func (context *ContextImpl) VerifyZitiMfa(code string) error {
|
|
return context.CtrlClt.VerifyMfa(code)
|
|
}
|
|
func (context *ContextImpl) RemoveZitiMfa(code string) error {
|
|
return context.CtrlClt.RemoveMfa(code)
|
|
}
|
|
|
|
type waitForNHelper struct {
|
|
count uint
|
|
mgr *listenerManager
|
|
notify chan struct{}
|
|
closed atomic.Bool
|
|
}
|
|
|
|
func (self *waitForNHelper) Notify(eventType ListenEventType) {
|
|
if eventType == ListenerEstablished && self.mgr.listener.GetEstablishedCount() >= self.count {
|
|
if self.closed.CompareAndSwap(false, true) {
|
|
close(self.notify)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (self *waitForNHelper) WaitForN(timeout time.Duration) error {
|
|
select {
|
|
case <-time.After(timeout):
|
|
return fmt.Errorf("timed out waiting for %v listeners to be established, only had %v", self.count, self.mgr.listener.GetEstablishedCount())
|
|
case <-self.notify:
|
|
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func newListenerManager(service *rest_model.ServiceDetail, context *ContextImpl, options *edge.ListenOptions, waitForN uint) (*listenerManager, error) {
|
|
now := time.Now()
|
|
|
|
var keyPair *kx.KeyPair
|
|
if service.EncryptionRequired != nil && *service.EncryptionRequired {
|
|
var err error
|
|
keyPair, err = kx.NewKeyPair()
|
|
if err != nil {
|
|
return nil, errors.Wrapf(err, "unable to create end-to-end encrytpion key-pair while hosting service '%s'", *service.Name)
|
|
}
|
|
}
|
|
|
|
options.KeyPair = keyPair
|
|
options.ListenerId = uuid.NewString()
|
|
|
|
listenerMgr := &listenerManager{
|
|
service: service,
|
|
context: context,
|
|
options: options,
|
|
routerConnections: map[string]edge.RouterConn{},
|
|
connects: map[string]time.Time{},
|
|
connectChan: make(chan *edgeRouterConnResult, 3),
|
|
eventChan: make(chan listenerEvent),
|
|
disconnectedTime: &now,
|
|
}
|
|
|
|
listenerMgr.listener = network.NewMultiListener(service, listenerMgr.GetCurrentSession)
|
|
|
|
var helper *waitForNHelper
|
|
if waitForN > 0 {
|
|
helper = &waitForNHelper{
|
|
count: waitForN,
|
|
mgr: listenerMgr,
|
|
notify: make(chan struct{}),
|
|
}
|
|
listenerMgr.AddObserver(helper)
|
|
defer listenerMgr.RemoveObserver(helper)
|
|
}
|
|
|
|
go listenerMgr.run()
|
|
|
|
if helper != nil {
|
|
if err := helper.WaitForN(options.ConnectTimeout); err != nil {
|
|
result := errorz.MultipleErrors{}
|
|
result = append(result, err)
|
|
if closeErr := listenerMgr.listener.Close(); closeErr != nil {
|
|
result = append(result, closeErr)
|
|
}
|
|
return nil, result.ToError()
|
|
}
|
|
}
|
|
|
|
return listenerMgr, nil
|
|
}
|
|
|
|
type listenerManager struct {
|
|
service *rest_model.ServiceDetail
|
|
context *ContextImpl
|
|
session *rest_model.SessionDetail
|
|
options *edge.ListenOptions
|
|
routerConnections map[string]edge.RouterConn
|
|
connects map[string]time.Time
|
|
listener network.MultiListener
|
|
connectChan chan *edgeRouterConnResult
|
|
eventChan chan listenerEvent
|
|
sessionRefreshInterval time.Duration
|
|
restartSessionRefresh bool
|
|
lastSessionRefresh time.Time
|
|
disconnectedTime *time.Time
|
|
observers concurrenz.CopyOnWriteSlice[ListenEventObserver]
|
|
sessionRefreshBaseLine time.Duration
|
|
}
|
|
|
|
func (mgr *listenerManager) AddObserver(observer ListenEventObserver) {
|
|
mgr.observers.Append(observer)
|
|
}
|
|
|
|
func (mgr *listenerManager) RemoveObserver(observer ListenEventObserver) {
|
|
mgr.observers.Delete(observer)
|
|
}
|
|
|
|
func (mgr *listenerManager) notify(eventType ListenEventType) {
|
|
for _, observer := range mgr.observers.Value() {
|
|
go observer.Notify(eventType)
|
|
}
|
|
}
|
|
|
|
func (mgr *listenerManager) run() {
|
|
log := pfxlog.Logger().WithField("service", stringz.OrEmpty(mgr.service.Name))
|
|
// need to either establish a session, or fail if we can't create one
|
|
for mgr.session == nil {
|
|
mgr.createSessionWithBackoff()
|
|
}
|
|
|
|
mgr.makeMoreListeners()
|
|
|
|
if mgr.options.BindUsingEdgeIdentity {
|
|
mgr.options.Identity = mgr.context.CtrlClt.GetCurrentApiSession().GetIdentityName()
|
|
}
|
|
|
|
if mgr.options.Identity != "" {
|
|
id, err := mgr.context.CtrlClt.GetIdentity()
|
|
|
|
if err != nil {
|
|
panic("could not get identity during run")
|
|
}
|
|
|
|
identitySecret, err := signing.AssertIdentityWithSecret(id.Cert().PrivateKey)
|
|
if err != nil {
|
|
log.WithError(err).Error("failed to sign identity")
|
|
} else {
|
|
mgr.options.IdentitySecret = string(identitySecret)
|
|
}
|
|
}
|
|
|
|
ticker := time.NewTicker(time.Second)
|
|
defer ticker.Stop()
|
|
|
|
var refreshSessionChan <-chan time.Time
|
|
|
|
for !mgr.listener.IsClosed() {
|
|
if mgr.restartSessionRefresh {
|
|
refreshSessionChan = time.After(mgr.sessionRefreshInterval)
|
|
mgr.restartSessionRefresh = false
|
|
}
|
|
|
|
//goland:noinspection GoNilness
|
|
select {
|
|
case routerConnectionResult := <-mgr.connectChan:
|
|
mgr.handleRouterConnectResult(routerConnectionResult)
|
|
case event := <-mgr.eventChan:
|
|
event.handle(mgr)
|
|
case <-refreshSessionChan:
|
|
mgr.refreshSession()
|
|
log.Debugf("next refresh in %s", mgr.sessionRefreshInterval.String())
|
|
refreshSessionChan = time.After(mgr.sessionRefreshInterval)
|
|
mgr.sessionRefreshInterval *= 2
|
|
if mgr.sessionRefreshInterval > 30*time.Minute {
|
|
mgr.sessionRefreshInterval = 30 * time.Minute
|
|
}
|
|
case <-ticker.C:
|
|
mgr.makeMoreListeners()
|
|
case <-mgr.options.GetEventChannel():
|
|
mgr.notify(ListenerEstablished)
|
|
case <-mgr.context.closeNotify:
|
|
mgr.listener.CloseWithError(errors.New("context closed"))
|
|
}
|
|
}
|
|
}
|
|
|
|
func (mgr *listenerManager) sessionRefreshed(session *rest_model.SessionDetail) {
|
|
oldUsableCount := mgr.getUsableEndpointCount(mgr.session)
|
|
newUsableCount := mgr.getUsableEndpointCount(session)
|
|
|
|
if oldUsableCount >= 0 && newUsableCount == 0 {
|
|
mgr.sessionRefreshInterval = time.Duration(5+rand.Intn(10)) * time.Second
|
|
} else if newUsableCount < mgr.options.MaxTerminators {
|
|
// if there's been a change, check reset baseline, as things seem to be influx
|
|
// we'll back-off if there's no further change
|
|
if oldUsableCount != newUsableCount {
|
|
mgr.sessionRefreshBaseLine = 30 * time.Second
|
|
}
|
|
|
|
// vary refresh by half the baseline refresh interval
|
|
halfInterval := mgr.sessionRefreshBaseLine / 2
|
|
wiggleFactor := time.Duration(rand.Int63n(int64(halfInterval)))
|
|
mgr.sessionRefreshInterval = halfInterval + (wiggleFactor * 2)
|
|
|
|
if mgr.sessionRefreshBaseLine < 5*time.Minute {
|
|
mgr.sessionRefreshBaseLine += 30 * time.Second
|
|
}
|
|
} else {
|
|
mgr.sessionRefreshInterval = 30 * time.Minute
|
|
}
|
|
|
|
mgr.session = session
|
|
mgr.restartSessionRefresh = true
|
|
mgr.lastSessionRefresh = time.Now()
|
|
|
|
log := pfxlog.Logger().
|
|
WithField("service", stringz.OrEmpty(mgr.service.Name)).
|
|
WithField("sessionId", stringz.OrEmpty(mgr.session.ID)).
|
|
WithField("usableEndpoints", newUsableCount).
|
|
WithField("nextRefresh", mgr.sessionRefreshInterval.String())
|
|
log.Debug("session refreshed")
|
|
}
|
|
|
|
func (mgr *listenerManager) getUsableEndpointCount(session *rest_model.SessionDetail) int {
|
|
if session == nil {
|
|
return 0
|
|
}
|
|
|
|
count := 0
|
|
for _, edgeRouter := range session.EdgeRouters {
|
|
for _, routerUrl := range edgeRouter.SupportedProtocols {
|
|
if mgr.context.options.isEdgeRouterUrlAccepted(routerUrl) {
|
|
count++
|
|
}
|
|
}
|
|
}
|
|
return count
|
|
}
|
|
|
|
func (mgr *listenerManager) handleRouterConnectResult(result *edgeRouterConnResult) {
|
|
log := pfxlog.Logger().
|
|
WithField("serviceName", *mgr.service.Name).
|
|
WithField("listenerCount", len(mgr.routerConnections)).
|
|
WithField("router", result.routerName).
|
|
WithField("routerUrl", result.routerUrl)
|
|
|
|
log.Debugf("handling router connect result, success? %v", result.routerConnection != nil)
|
|
|
|
delete(mgr.connects, result.routerUrl)
|
|
routerConnection := result.routerConnection
|
|
if routerConnection == nil {
|
|
return
|
|
}
|
|
|
|
if len(mgr.routerConnections) < mgr.options.MaxTerminators {
|
|
if _, ok := mgr.routerConnections[routerConnection.GetRouterName()]; !ok {
|
|
mgr.routerConnections[routerConnection.GetRouterName()] = routerConnection
|
|
log.WithField("listenerCount", len(mgr.routerConnections)).
|
|
Debugf("establishing listener to %s", routerConnection.Key())
|
|
go mgr.createListener(routerConnection, mgr.session)
|
|
}
|
|
} else {
|
|
log.Debug("ignoring connection, already have max connections")
|
|
}
|
|
}
|
|
|
|
func (mgr *listenerManager) createListener(routerConnection edge.RouterConn, session *rest_model.SessionDetail) {
|
|
start := time.Now()
|
|
logger := pfxlog.Logger().WithField("serviceName", *mgr.service.Name).
|
|
WithField("router", routerConnection.GetRouterName())
|
|
svc := mgr.listener.GetService()
|
|
listener, err := routerConnection.Listen(svc, session, mgr.options)
|
|
elapsed := time.Since(start)
|
|
if err == nil {
|
|
logger = logger.WithField("connId", listener.Id())
|
|
logger.Debugf("listener established to %v in %vms", routerConnection.Key(), elapsed.Milliseconds())
|
|
mgr.listener.AddListener(listener, func() {
|
|
select {
|
|
case mgr.eventChan <- &routerConnectionListenFailedEvent{router: routerConnection.GetRouterName()}:
|
|
case <-mgr.context.closeNotify:
|
|
logger.Debugf("listener closed, exiting from createListener")
|
|
}
|
|
})
|
|
mgr.eventChan <- listenSuccessEvent{}
|
|
if !routerConnection.GetBoolHeader(edge.SupportsBindSuccessHeader) {
|
|
select {
|
|
case mgr.options.GetEventChannel() <- &edge.ListenerEvent{EventType: edge.ListenerEstablished}:
|
|
default:
|
|
}
|
|
}
|
|
} else {
|
|
logger.Errorf("creating listener failed after %vms: %v", elapsed.Milliseconds(), err)
|
|
mgr.listener.NotifyOfChildError(err)
|
|
select {
|
|
case mgr.eventChan <- &routerConnectionListenFailedEvent{router: routerConnection.GetRouterName()}:
|
|
case <-mgr.context.closeNotify:
|
|
logger.Debugf("listener closed, exiting from createListener")
|
|
}
|
|
}
|
|
}
|
|
|
|
func (mgr *listenerManager) makeMoreListeners() {
|
|
log := pfxlog.Logger().WithField("service", *mgr.service.Name).WithField("erCount", len(mgr.session.EdgeRouters))
|
|
if mgr.listener.IsClosed() || len(mgr.routerConnections) >= mgr.options.MaxTerminators || len(mgr.session.EdgeRouters) <= len(mgr.routerConnections) {
|
|
log.Trace("not trying to make more connections")
|
|
return
|
|
}
|
|
|
|
for _, edgeRouter := range mgr.session.EdgeRouters {
|
|
if _, ok := mgr.routerConnections[*edgeRouter.Name]; ok {
|
|
log.WithField("router", *edgeRouter.Name).Trace("already connected")
|
|
// already connected to this router
|
|
continue
|
|
}
|
|
|
|
for _, routerUrl := range edgeRouter.SupportedProtocols {
|
|
if !mgr.context.options.isEdgeRouterUrlAccepted(routerUrl) {
|
|
log.WithField("router", *edgeRouter.Name).WithField("url", routerUrl).
|
|
Trace("skipping unusable url")
|
|
continue
|
|
}
|
|
|
|
if connectTime, ok := mgr.connects[routerUrl]; ok && time.Since(connectTime) < 30*time.Second {
|
|
// this url already has a connect in progress
|
|
log.WithField("router", *edgeRouter.Name).WithField("url", routerUrl).
|
|
Trace("connect already in progress")
|
|
continue
|
|
}
|
|
|
|
log.WithField("router", *edgeRouter.Name).WithField("url", routerUrl).
|
|
Trace("attempting to connect to router")
|
|
mgr.connects[routerUrl] = time.Now()
|
|
go mgr.context.handleConnectEdgeRouter(*edgeRouter.Name, routerUrl, mgr.connectChan)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (mgr *listenerManager) refreshSession() {
|
|
if time.Since(mgr.lastSessionRefresh) < 30*time.Second {
|
|
return
|
|
}
|
|
|
|
log := pfxlog.Logger().WithField("service", stringz.OrEmpty(mgr.service.Name))
|
|
if mgr.session == nil {
|
|
log.Debug("establishing initial session")
|
|
mgr.createSessionWithBackoff()
|
|
return
|
|
}
|
|
|
|
log = log.WithField("sessionId", stringz.OrEmpty(mgr.session.ID)).WithField("erCount", len(mgr.session.EdgeRouters))
|
|
log.Debug("starting session refresh")
|
|
session, err := mgr.context.refreshSession(mgr.session)
|
|
|
|
if err != nil {
|
|
var target error = &rest_session.DetailSessionNotFound{}
|
|
if errors.As(err, &target) {
|
|
// try to create new session
|
|
mgr.createSessionWithBackoff()
|
|
return
|
|
}
|
|
|
|
target = &rest_session.DetailSessionUnauthorized{}
|
|
if errors.As(err, &target) {
|
|
log.WithError(err).Debugf("failure refreshing bind session for service %v", mgr.listener.GetServiceName())
|
|
if err := mgr.context.EnsureAuthenticated(mgr.options); err != nil {
|
|
err := fmt.Errorf("unable to establish API session (%w)", err)
|
|
if len(mgr.routerConnections) == 0 {
|
|
mgr.listener.CloseWithError(err)
|
|
}
|
|
return
|
|
}
|
|
}
|
|
|
|
session, err = mgr.context.refreshSession(mgr.session)
|
|
if err != nil {
|
|
target = &rest_session.DetailSessionUnauthorized{}
|
|
if errors.As(err, &target) {
|
|
log.WithError(err).Errorf(
|
|
"failure refreshing bind session even after re-authenticating api session. service %v",
|
|
mgr.listener.GetServiceName())
|
|
if len(mgr.routerConnections) == 0 {
|
|
mgr.listener.CloseWithError(err)
|
|
}
|
|
return
|
|
}
|
|
|
|
log.WithError(err).Errorf("failed to to refresh session %v", *mgr.session.ID)
|
|
|
|
// try to create new session
|
|
mgr.createSessionWithBackoff()
|
|
}
|
|
}
|
|
|
|
// token only returned on created, so if we refreshed the session (as opposed to creating a new one) we have to back-fill it on lookups
|
|
if session != nil {
|
|
session.Token = mgr.session.Token
|
|
mgr.sessionRefreshed(session)
|
|
}
|
|
}
|
|
|
|
func (mgr *listenerManager) createSessionWithBackoff() {
|
|
latestSvc, _ := mgr.context.services.Get(*mgr.service.Name)
|
|
if latestSvc != nil && *latestSvc.ID != *mgr.service.ID {
|
|
pfxlog.Logger().
|
|
WithField("serviceName", *mgr.service.Name).
|
|
WithField("oldServiceId", *mgr.service.ID).
|
|
WithField("newServiceId", *latestSvc.ID).
|
|
Info("service id changed, service was recreated")
|
|
mgr.service = latestSvc
|
|
}
|
|
|
|
session, err := mgr.context.createSessionWithBackoff(mgr.service, SessionType(SessionBind), mgr.options)
|
|
if session != nil {
|
|
mgr.sessionRefreshed(session)
|
|
pfxlog.Logger().WithField("session token", *session.Token).Info("new service session")
|
|
} else {
|
|
pfxlog.Logger().WithError(err).Errorf("failed to create bind session for service %v", mgr.service.Name)
|
|
}
|
|
}
|
|
|
|
func (mgr *listenerManager) GetCurrentSession() *rest_model.SessionDetail {
|
|
if mgr.listener.IsClosed() {
|
|
return nil
|
|
}
|
|
event := &getSessionEvent{
|
|
doneC: make(chan struct{}),
|
|
}
|
|
timeout := time.After(5 * time.Second)
|
|
|
|
select {
|
|
case mgr.eventChan <- event:
|
|
case <-timeout:
|
|
return nil
|
|
}
|
|
|
|
select {
|
|
case <-event.doneC:
|
|
return event.session
|
|
case <-timeout:
|
|
}
|
|
return nil
|
|
}
|
|
|
|
type listenerEvent interface {
|
|
handle(mgr *listenerManager)
|
|
}
|
|
|
|
type routerConnectionListenFailedEvent struct {
|
|
router string
|
|
}
|
|
|
|
func (event *routerConnectionListenFailedEvent) handle(mgr *listenerManager) {
|
|
delete(mgr.routerConnections, event.router)
|
|
pfxlog.Logger().WithField("serviceName", *mgr.service.Name).
|
|
WithField("listenerCount", len(mgr.routerConnections)).
|
|
WithField("router", event.router).
|
|
Debugf("child listener connection closed. parent listener closed: %v", mgr.listener.IsClosed())
|
|
now := time.Now()
|
|
if len(mgr.routerConnections) == 0 {
|
|
mgr.disconnectedTime = &now
|
|
}
|
|
mgr.notify(ListenerRemoved)
|
|
if mgr.sessionRefreshInterval > 10*time.Second && time.Since(mgr.lastSessionRefresh) > 10*time.Second {
|
|
mgr.sessionRefreshInterval = time.Duration(100+(rand.Intn(10)*1000)) * time.Millisecond
|
|
mgr.restartSessionRefresh = true
|
|
}
|
|
mgr.refreshSession()
|
|
mgr.makeMoreListeners()
|
|
}
|
|
|
|
type edgeRouterConnResult struct {
|
|
routerUrl string
|
|
routerName string
|
|
routerConnection edge.RouterConn
|
|
err error
|
|
}
|
|
|
|
type listenSuccessEvent struct{}
|
|
|
|
func (event listenSuccessEvent) handle(mgr *listenerManager) {
|
|
mgr.disconnectedTime = nil
|
|
mgr.notify(ListenerAdded)
|
|
}
|
|
|
|
type getSessionEvent struct {
|
|
session *rest_model.SessionDetail
|
|
doneC chan struct{}
|
|
}
|
|
|
|
func (event *getSessionEvent) handle(mgr *listenerManager) {
|
|
defer close(event.doneC)
|
|
event.session = mgr.session
|
|
}
|
|
|
|
// ListenEventType identifies the kind of listener change reported to a
// ListenEventObserver.
type ListenEventType int

// Listener event kinds. The numeric values are 1, 2 and 3, in declaration
// order.
const (
	// ListenerAdded has the value 1.
	ListenerAdded ListenEventType = iota + 1
	// ListenerEstablished has the value 2.
	ListenerEstablished
	// ListenerRemoved has the value 3.
	ListenerRemoved
)
|
|
|
|
type ListenEventObserver interface {
|
|
Notify(eventType ListenEventType)
|
|
}
|