554 lines
13 KiB
Go
554 lines
13 KiB
Go
/*
|
|
Copyright NetFoundry Inc.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
*/
|
|
|
|
package channel
|
|
|
|
import (
|
|
"container/heap"
|
|
"crypto/x509"
|
|
"fmt"
|
|
"github.com/michaelquigley/pfxlog"
|
|
"github.com/openziti/foundation/v2/concurrenz"
|
|
"github.com/openziti/foundation/v2/info"
|
|
"github.com/openziti/foundation/v2/sequence"
|
|
"github.com/pkg/errors"
|
|
"io"
|
|
"net"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
)
|
|
|
|
// Bit indexes into channelImpl.flags.
const (
	// flagClosed is set by the first successful Close call.
	flagClosed = iota
	// flagRxStarted is set when rxer starts, so a second receive loop cannot run.
	flagRxStarted
)
|
|
|
|
var connectionSeq = sequence.NewSequence()
|
|
|
|
func NextConnectionId() (string, error) {
|
|
return connectionSeq.NextHash()
|
|
}
|
|
|
|
// Note: if altering this struct, be sure to account for 64 bit alignment on 32 bit arm arch
// https://pkg.go.dev/sync/atomic#pkg-note-BUG
// https://github.com/golang/go/issues/36606
//
// channelImpl is the Channel implementation that multiplexes messages over a single
// Underlay, using one receive goroutine (rxer) and one transmit goroutine (txer).
type channelImpl struct {
	// lastRead is the time of the most recently received message in milliseconds, as
	// returned by info.NowInMilliseconds. It is accessed with sync/atomic and must stay
	// the first field to remain 64-bit aligned on 32-bit platforms.
	lastRead int64
	// logicalName is the name given at construction, replaceable via SetLogicalName.
	logicalName string
	// underlay is the transport messages are read from and written to.
	underlay Underlay
	// options may be nil; readers check before dereferencing.
	options *Options
	// sequence supplies the sequence number assigned to each outbound Sendable.
	sequence *sequence.Sequence
	// outQueue is the buffered queue Send/TrySend feed and txer drains.
	outQueue chan Sendable
	// outPriority is the heap txer uses to order each batch drained from outQueue.
	outPriority *priorityHeap
	// waiters holds reply receivers keyed by the sequence of the message awaiting a reply.
	waiters waiterMap
	// flags holds the flagClosed and flagRxStarted bits.
	flags concurrenz.AtomicBitSet
	// closeNotify is closed by Close to wake anything selecting on it.
	closeNotify chan struct{}
	// peekHandlers observe inbound and outbound messages and connect/close events.
	peekHandlers []PeekHandler
	// transformHandlers are applied to every message on receive and before transmit.
	transformHandlers []TransformHandler
	// receiveHandlers route inbound messages by content type.
	receiveHandlers map[int32]ReceiveHandler
	// errorHandlers are notified when an underlay write fails.
	errorHandlers []ErrorHandler
	// closeHandlers are notified once when the channel closes.
	closeHandlers []CloseHandler
	// userData is an opaque caller-supplied value (see SetUserData/GetUserData).
	userData interface{}
	// replyCounter counts received replies; only rx touches it, to throttle waiter reaping.
	replyCounter uint32
}
|
|
|
|
func NewChannel(logicalName string, underlayFactory UnderlayFactory, bindHandler BindHandler, options *Options) (Channel, error) {
|
|
timeout := time.Duration(0)
|
|
if options != nil {
|
|
timeout = options.ConnectTimeout
|
|
}
|
|
|
|
underlay, err := underlayFactory.Create(timeout)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return NewChannelWithUnderlay(logicalName, underlay, bindHandler, options)
|
|
}
|
|
|
|
func NewChannelWithUnderlay(logicalName string, underlay Underlay, bindHandler BindHandler, options *Options) (Channel, error) {
|
|
outQueueSize := DefaultOutQueueSize
|
|
if options != nil {
|
|
outQueueSize = options.OutQueueSize
|
|
}
|
|
|
|
impl := &channelImpl{
|
|
logicalName: logicalName,
|
|
options: options,
|
|
sequence: sequence.NewSequence(),
|
|
outQueue: make(chan Sendable, outQueueSize),
|
|
outPriority: &priorityHeap{},
|
|
receiveHandlers: map[int32]ReceiveHandler{},
|
|
closeNotify: make(chan struct{}),
|
|
underlay: underlay,
|
|
}
|
|
|
|
heap.Init(impl.outPriority)
|
|
impl.AddTypedReceiveHandler(&pingHandler{})
|
|
|
|
if err := bind(bindHandler, impl); err != nil {
|
|
if closeErr := underlay.Close(); closeErr != nil {
|
|
if !errors.Is(closeErr, net.ErrClosed) {
|
|
pfxlog.ContextLogger(impl.Label()).WithError(err).Warn("error closing underlay")
|
|
}
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
impl.startMultiplex()
|
|
|
|
return impl, nil
|
|
}
|
|
|
|
func AcceptNextChannel(logicalName string, underlayFactory UnderlayFactory, bindHandler BindHandler, options *Options) error {
|
|
underlay, err := underlayFactory.Create(options.ConnectTimeout)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
go func() {
|
|
_, err = NewChannelWithUnderlay(logicalName, underlay, bindHandler, options)
|
|
if err != nil {
|
|
pfxlog.Logger().WithError(err).Errorf("failure accepting channel %v with underlay %v", logicalName, underlay.Label())
|
|
}
|
|
}()
|
|
|
|
return nil
|
|
}
|
|
|
|
func (channel *channelImpl) Id() string {
|
|
return channel.underlay.Id()
|
|
}
|
|
|
|
func (channel *channelImpl) LogicalName() string {
|
|
return channel.logicalName
|
|
}
|
|
|
|
func (channel *channelImpl) SetLogicalName(logicalName string) {
|
|
channel.logicalName = logicalName
|
|
}
|
|
|
|
func (channel *channelImpl) ConnectionId() string {
|
|
return channel.underlay.ConnectionId()
|
|
}
|
|
|
|
func (channel *channelImpl) Certificates() []*x509.Certificate {
|
|
return channel.underlay.Certificates()
|
|
}
|
|
|
|
func (channel *channelImpl) Headers() map[int32][]byte {
|
|
return channel.underlay.Headers()
|
|
}
|
|
|
|
func (channel *channelImpl) Label() string {
|
|
if channel.underlay != nil {
|
|
return fmt.Sprintf("ch{%s}->%s", channel.LogicalName(), channel.underlay.Label())
|
|
} else {
|
|
return fmt.Sprintf("ch{%s}->{}", channel.LogicalName())
|
|
}
|
|
}
|
|
|
|
func (channel *channelImpl) GetChannel() Channel {
|
|
return channel
|
|
}
|
|
|
|
func (channel *channelImpl) Bind(h BindHandler) error {
|
|
return h.BindChannel(channel)
|
|
}
|
|
|
|
func (channel *channelImpl) AddPeekHandler(h PeekHandler) {
|
|
channel.peekHandlers = append(channel.peekHandlers, h)
|
|
}
|
|
|
|
func (channel *channelImpl) AddTransformHandler(h TransformHandler) {
|
|
channel.transformHandlers = append(channel.transformHandlers, h)
|
|
}
|
|
|
|
func (channel *channelImpl) AddTypedReceiveHandler(h TypedReceiveHandler) {
|
|
channel.receiveHandlers[h.ContentType()] = h
|
|
}
|
|
|
|
func (channel *channelImpl) AddReceiveHandler(contentType int32, h ReceiveHandler) {
|
|
channel.receiveHandlers[contentType] = h
|
|
}
|
|
|
|
func (channel *channelImpl) AddReceiveHandlerF(contentType int32, h ReceiveHandlerF) {
|
|
channel.AddReceiveHandler(contentType, h)
|
|
}
|
|
|
|
func (channel *channelImpl) AddErrorHandler(h ErrorHandler) {
|
|
channel.errorHandlers = append(channel.errorHandlers, h)
|
|
}
|
|
|
|
func (channel *channelImpl) AddCloseHandler(h CloseHandler) {
|
|
channel.closeHandlers = append(channel.closeHandlers, h)
|
|
}
|
|
|
|
func (channel *channelImpl) SetUserData(data interface{}) {
|
|
channel.userData = data
|
|
}
|
|
|
|
func (channel *channelImpl) GetUserData() interface{} {
|
|
return channel.userData
|
|
}
|
|
|
|
func (channel *channelImpl) Close() error {
|
|
if channel.flags.CompareAndSet(flagClosed, false, true) {
|
|
pfxlog.ContextLogger(channel.Label()).Debug("closing channel")
|
|
|
|
close(channel.closeNotify)
|
|
|
|
for _, peekHandler := range channel.peekHandlers {
|
|
peekHandler.Close(channel)
|
|
}
|
|
|
|
if len(channel.closeHandlers) > 0 {
|
|
for _, closeHandler := range channel.closeHandlers {
|
|
closeHandler.HandleClose(channel)
|
|
}
|
|
} else {
|
|
pfxlog.ContextLogger(channel.Label()).Debug("no close handlers")
|
|
}
|
|
|
|
return channel.underlay.Close()
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (channel *channelImpl) IsClosed() bool {
|
|
return channel.flags.IsSet(flagClosed)
|
|
}
|
|
|
|
func (channel *channelImpl) Send(s Sendable) error {
|
|
if err := s.Context().Err(); err != nil {
|
|
return err
|
|
}
|
|
|
|
s.SetSequence(int32(channel.sequence.Next()))
|
|
|
|
select {
|
|
case <-s.Context().Done():
|
|
if err := s.Context().Err(); err != nil {
|
|
return TimeoutError{error: errors.Wrap(err, "timeout waiting to put message in send queue")}
|
|
}
|
|
return TimeoutError{error: errors.New("timeout waiting to put message in send queue")}
|
|
case <-channel.closeNotify:
|
|
return ClosedError{}
|
|
case channel.outQueue <- s:
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (channel *channelImpl) TrySend(s Sendable) (bool, error) {
|
|
if err := s.Context().Err(); err != nil {
|
|
return false, err
|
|
}
|
|
|
|
s.SetSequence(int32(channel.sequence.Next()))
|
|
|
|
select {
|
|
case <-s.Context().Done():
|
|
if err := s.Context().Err(); err != nil {
|
|
return false, TimeoutError{errors.Wrap(err, "timeout waiting to put message in send queue")}
|
|
}
|
|
return false, TimeoutError{error: errors.New("timeout waiting to put message in send queue")}
|
|
case <-channel.closeNotify:
|
|
return false, ClosedError{}
|
|
case channel.outQueue <- s:
|
|
return true, nil
|
|
default:
|
|
return false, nil
|
|
}
|
|
}
|
|
|
|
func (channel *channelImpl) Underlay() Underlay {
|
|
return channel.underlay
|
|
}
|
|
|
|
func (channel *channelImpl) startMultiplex() {
|
|
for _, peekHandler := range channel.peekHandlers {
|
|
peekHandler.Connect(channel, "")
|
|
}
|
|
|
|
go channel.rxer()
|
|
go channel.txer()
|
|
}
|
|
|
|
// rxer is the receive loop, launched as a goroutine by startMultiplex. It reads
// messages from the underlay and dispatches each one through rx until Rx returns an
// error.
//
// On a normal exit the deferred calls run in reverse order: pending waiters are
// cleared, the channel is closed, then "exited" is logged.
func (channel *channelImpl) rxer() {
	// Only the first caller proceeds; flagRxStarted prevents a second receive loop.
	if !channel.flags.CompareAndSet(flagRxStarted, false, true) {
		return
	}

	log := pfxlog.ContextLogger(channel.Label())
	log.Debug("started")
	defer log.Debug("exited")

	// NOTE(review): a recovered panic is re-raised immediately, so Close below runs only
	// on a normal return, never while panicking — confirm this is intended.
	defer func() {
		if r := recover(); r != nil {
			panic(r)
		}
		_ = channel.Close()
	}()

	defer channel.waiters.clear()

	for {
		m, err := channel.underlay.Rx()
		if err != nil {
			// EOF, and errors seen once the channel is already closed, are logged at debug.
			// Anything else is logged as an error. The loop ends in all cases.
			if err == io.EOF {
				log.WithError(err).Debug("EOF")
			} else if channel.IsClosed() {
				log.WithError(err).Debug("rx error")
			} else {
				log.WithError(err).Error("rx error")
			}
			return
		}
		channel.rx(m)
	}
}
|
|
|
|
func (channel *channelImpl) rx(m *Message) {
|
|
log := pfxlog.ContextLogger(channel.Label())
|
|
|
|
now := info.NowInMilliseconds()
|
|
atomic.StoreInt64(&channel.lastRead, now)
|
|
|
|
for _, transformHandler := range channel.transformHandlers {
|
|
transformHandler.Rx(m, channel)
|
|
}
|
|
|
|
for _, peekHandler := range channel.peekHandlers {
|
|
peekHandler.Rx(m, channel)
|
|
}
|
|
|
|
handled := false
|
|
if m.IsReply() {
|
|
channel.replyCounter++
|
|
if channel.replyCounter%100 == 0 && channel.waiters.Size() > 1000 {
|
|
channel.waiters.reapExpired(now)
|
|
}
|
|
replyFor := m.ReplyFor()
|
|
if replyReceiver := channel.waiters.RemoveWaiter(replyFor); replyReceiver != nil {
|
|
log.Tracef("waiter found for message. type [%v], sequence [%v], replyFor [%v]", m.ContentType, m.sequence, replyFor)
|
|
replyReceiver.AcceptReply(m)
|
|
handled = true
|
|
} else {
|
|
log.Debugf("no waiter for message. type [%v], sequence [%v], replyFor [%v]", m.ContentType, m.sequence, replyFor)
|
|
}
|
|
}
|
|
|
|
if !handled {
|
|
if receiveHandler, found := channel.receiveHandlers[m.ContentType]; found {
|
|
receiveHandler.HandleReceive(m, channel)
|
|
|
|
} else if anyHandler, found := channel.receiveHandlers[AnyContentType]; found {
|
|
anyHandler.HandleReceive(m, channel)
|
|
} else {
|
|
log.Warnf("dropped message. type [%d], sequence [%v], replyFor [%v]", m.ContentType, m.sequence, m.ReplyFor())
|
|
}
|
|
}
|
|
}
|
|
|
|
// txer is the transmit loop, launched as a goroutine by startMultiplex. Each pass:
//   - blocks until at least one Sendable is queued (or closeNotify fires);
//   - pulls more already-queued Sendables into outPriority without blocking, up to 64
//     per pass;
//   - writes them via tx in the order heap.Pop yields them (priorityHeap is defined
//     elsewhere; presumably by priority).
//
// It returns after the pass in which closeNotify fired, or as soon as tx fails. Either
// way, the deferred Close shuts the channel down.
func (channel *channelImpl) txer() {
	log := pfxlog.ContextLogger(channel.Label())
	defer log.Debug("exited")
	log.Debug("started")

	defer func() { _ = channel.Close() }()

	// options may be nil. A zero writeTimeout means tx never sets a write timeout.
	var writeTimeout time.Duration
	if channel.options != nil {
		writeTimeout = channel.options.WriteTimeout
	}

	for {
		done := false
		selecting := true

		// number of Sendables moved into outPriority during this pass
		count := 0

		// Block until there is something to send or the channel is closing.
		select {
		case pm := <-channel.outQueue:
			heap.Push(channel.outPriority, pm)
			count++
		case <-channel.closeNotify:
			done = true
			selecting = false
		}

		// Without blocking, pull in whatever else is already queued (up to the batch
		// limit) so it is ordered together by the heap.
		for selecting && count < 64 {
			select {
			case pm := <-channel.outQueue:
				heap.Push(channel.outPriority, pm)
				count++
			case <-channel.closeNotify:
				done = true
				selecting = false
			default:
				selecting = false
			}
		}

		// Flush everything in the heap; a write failure ends the loop.
		for channel.outPriority.Len() > 0 {
			sendable := heap.Pop(channel.outPriority).(Sendable)
			if err := channel.tx(sendable, writeTimeout); err != nil {
				return
			}
		}

		// Sendables still sitting in outQueue at this point are never written.
		if done {
			return
		}
	}
}
|
|
|
|
func (channel *channelImpl) tx(sendable Sendable, writeTimeout time.Duration) error {
|
|
log := pfxlog.ContextLogger(channel.Label())
|
|
|
|
sendListener := sendable.SendListener()
|
|
m := sendable.Msg()
|
|
|
|
if err := sendable.Context().Err(); err != nil {
|
|
sendListener.NotifyErr(TimeoutError{err})
|
|
return nil
|
|
}
|
|
|
|
sendListener.NotifyBeforeWrite()
|
|
|
|
if m == nil { // allow nil message in Sendable so we can send tracers to check time from send to write
|
|
return nil
|
|
}
|
|
|
|
for _, transformHandler := range channel.transformHandlers {
|
|
transformHandler.Tx(m, channel)
|
|
}
|
|
|
|
channel.waiters.AddWaiter(sendable)
|
|
|
|
var err error
|
|
if writeTimeout > 0 {
|
|
if err = channel.underlay.SetWriteTimeout(writeTimeout); err != nil {
|
|
log.WithError(err).Errorf("unable to set write timeout")
|
|
sendListener.NotifyErr(err)
|
|
return err
|
|
}
|
|
}
|
|
|
|
err = channel.underlay.Tx(m)
|
|
|
|
if err != nil {
|
|
log.WithError(err).Errorf("write error")
|
|
sendListener.NotifyErr(err)
|
|
|
|
for _, errorHandler := range channel.errorHandlers {
|
|
errorHandler.HandleError(err, channel)
|
|
}
|
|
|
|
sendListener.NotifyAfterWrite()
|
|
|
|
return err
|
|
}
|
|
|
|
for _, peekHandler := range channel.peekHandlers {
|
|
peekHandler.Tx(m, channel)
|
|
}
|
|
|
|
sendListener.NotifyAfterWrite()
|
|
|
|
return nil
|
|
}
|
|
|
|
func (channel *channelImpl) GetTimeSinceLastRead() time.Duration {
|
|
return time.Duration(info.NowInMilliseconds()-atomic.LoadInt64(&channel.lastRead)) * time.Millisecond
|
|
}
|
|
|
|
// waiter is a pending reply receiver for one outbound message.
type waiter struct {
	// replyReceiver is given the reply message when rx matches it.
	replyReceiver ReplyReceiver
	// ttlMs is an absolute expiry time in Unix milliseconds, not a duration despite the
	// name. It is the sender context's deadline, or 30s after registration when there is
	// none. reapExpired drops the waiter once it is in the past.
	ttlMs int64
}
|
|
|
|
// waiterMap holds pending reply waiters keyed by the outbound message's sequence number.
type waiterMap struct {
	// m maps a message sequence (int32) to its *waiter.
	m sync.Map
	// size is the entry count kept alongside m with atomic adds, since sync.Map has no
	// length.
	size int32
}
|
|
|
|
func (self *waiterMap) Size() int32 {
|
|
return atomic.LoadInt32(&self.size)
|
|
}
|
|
|
|
func (self *waiterMap) AddWaiter(sendable Sendable) {
|
|
if replyReceiver := sendable.ReplyReceiver(); replyReceiver != nil {
|
|
w := &waiter{
|
|
replyReceiver: replyReceiver,
|
|
}
|
|
|
|
if deadline, hasDeadline := sendable.Context().Deadline(); hasDeadline {
|
|
w.ttlMs = deadline.UnixMilli()
|
|
} else {
|
|
w.ttlMs = info.NowInMilliseconds() + 30_000
|
|
}
|
|
|
|
self.m.Store(sendable.Msg().Sequence(), w)
|
|
atomic.AddInt32(&self.size, 1)
|
|
}
|
|
}
|
|
|
|
func (self *waiterMap) RemoveWaiter(seq int32) ReplyReceiver {
|
|
if result, found := self.m.LoadAndDelete(seq); found {
|
|
w := result.(*waiter)
|
|
atomic.AddInt32(&self.size, -1)
|
|
return w.replyReceiver
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (self *waiterMap) reapExpired(now int64) {
|
|
var deleteCount int32
|
|
self.m.Range(func(key, value interface{}) bool {
|
|
if w, ok := value.(*waiter); !ok || w.ttlMs < now {
|
|
self.m.Delete(key)
|
|
deleteCount++
|
|
pfxlog.Logger().Debugf("removed waiter for %v. ttl: %v, now: %v", key, w.ttlMs, now)
|
|
}
|
|
return true
|
|
})
|
|
atomic.AddInt32(&self.size, -deleteCount)
|
|
}
|
|
|
|
func (self *waiterMap) clear() {
|
|
self.m.Range(func(k, v interface{}) bool {
|
|
self.m.Delete(k)
|
|
return true
|
|
})
|
|
}
|
|
|
|
func bind(bindHandler BindHandler, binding Binding) error {
|
|
if bindHandler == nil {
|
|
return nil
|
|
}
|
|
|
|
if err := bindHandler.BindChannel(binding); err != nil {
|
|
if closeErr := binding.GetChannel().Close(); closeErr != nil {
|
|
pfxlog.ContextLogger(binding.GetChannel().Label()).WithError(err).Warn("error closing channel after bind failure")
|
|
}
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|