EdgexAgent/device-ble-go/vendor/github.com/openziti/channel/v4/impl.go
2025-07-10 20:40:32 +08:00

554 lines
13 KiB
Go

/*
Copyright NetFoundry Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package channel
import (
"container/heap"
"crypto/x509"
"fmt"
"github.com/michaelquigley/pfxlog"
"github.com/openziti/foundation/v2/concurrenz"
"github.com/openziti/foundation/v2/info"
"github.com/openziti/foundation/v2/sequence"
"github.com/pkg/errors"
"io"
"net"
"sync"
"sync/atomic"
"time"
)
// Bit indexes used with channelImpl.flags.
const (
	// flagClosed is set once by Close; IsClosed reports it.
	flagClosed = iota
	// flagRxStarted is set by rxer so that only one receive loop runs per channel.
	flagRxStarted
)
// connectionSeq is the package-level sequence backing NextConnectionId.
var connectionSeq = sequence.NewSequence()

// NextConnectionId returns whatever connectionSeq.NextHash() produces: a
// string id and an error, both passed through unchanged.
func NextConnectionId() (string, error) {
return connectionSeq.NextHash()
}
// channelImpl is the Channel implementation in this file. It multiplexes
// messages over a single Underlay using one receive goroutine (rxer) and one
// transmit goroutine (txer).
//
// Note: if altering this struct, be sure to account for 64 bit alignment on 32 bit arm arch.
// lastRead is accessed with 64-bit atomic operations and is therefore kept as the first field.
// https://pkg.go.dev/sync/atomic#pkg-note-BUG
// https://github.com/golang/go/issues/36606
type channelImpl struct {
lastRead int64 // time of the last received message, in ms; read and written via sync/atomic
logicalName string // name used in Label(); changeable via SetLogicalName
underlay Underlay // transport that messages are read from and written to
options *Options // may be nil; consulted for OutQueueSize and WriteTimeout
sequence *sequence.Sequence // source of outgoing message sequence numbers
outQueue chan Sendable // filled by Send/TrySend, drained by txer
outPriority *priorityHeap // heap txer uses to order queued sendables before writing
waiters waiterMap // reply receivers keyed by message sequence
flags concurrenz.AtomicBitSet // holds flagClosed and flagRxStarted
closeNotify chan struct{} // closed by Close to wake blocked senders and txer
peekHandlers []PeekHandler // invoked on connect, rx, tx and close
transformHandlers []TransformHandler // invoked on every rx and tx, before peek handlers
receiveHandlers map[int32]ReceiveHandler // keyed by content type; AnyContentType is the fallback
errorHandlers []ErrorHandler // invoked when an underlay write fails
closeHandlers []CloseHandler // invoked by Close
userData interface{} // opaque value stored via SetUserData
replyCounter uint32 // replies seen by rx; used to trigger periodic waiter reaping
}
// NewChannel asks underlayFactory for an underlay and wraps it in a channel
// via NewChannelWithUnderlay.
//
// The connect timeout handed to the factory is options.ConnectTimeout, or zero
// when options is nil. A factory error is returned unchanged.
func NewChannel(logicalName string, underlayFactory UnderlayFactory, bindHandler BindHandler, options *Options) (Channel, error) {
	var connectTimeout time.Duration
	if options != nil {
		connectTimeout = options.ConnectTimeout
	}

	u, createErr := underlayFactory.Create(connectTimeout)
	if createErr != nil {
		return nil, createErr
	}

	return NewChannelWithUnderlay(logicalName, u, bindHandler, options)
}
// NewChannelWithUnderlay builds a channel on top of an already created
// underlay, applies bindHandler to it and starts the rx/tx goroutines.
//
// The outbound queue capacity is options.OutQueueSize, or DefaultOutQueueSize
// when options is nil. A ping handler is always registered before binding.
//
// If binding fails, the underlay is closed and the bind error is returned; a
// failure to close the underlay is only logged.
func NewChannelWithUnderlay(logicalName string, underlay Underlay, bindHandler BindHandler, options *Options) (Channel, error) {
	outQueueSize := DefaultOutQueueSize
	if options != nil {
		outQueueSize = options.OutQueueSize
	}

	impl := &channelImpl{
		logicalName:     logicalName,
		options:         options,
		sequence:        sequence.NewSequence(),
		outQueue:        make(chan Sendable, outQueueSize),
		outPriority:     &priorityHeap{},
		receiveHandlers: map[int32]ReceiveHandler{},
		closeNotify:     make(chan struct{}),
		underlay:        underlay,
	}

	heap.Init(impl.outPriority)
	impl.AddTypedReceiveHandler(&pingHandler{})

	if err := bind(bindHandler, impl); err != nil {
		// bind closes the channel on failure, so an "already closed" result
		// here is expected and not worth reporting.
		if closeErr := underlay.Close(); closeErr != nil && !errors.Is(closeErr, net.ErrClosed) {
			// Report the close failure itself, not the bind error that led here.
			pfxlog.ContextLogger(impl.Label()).WithError(closeErr).Warn("error closing underlay")
		}
		return nil, err
	}

	impl.startMultiplex()
	return impl, nil
}
// AcceptNextChannel creates the next underlay from underlayFactory and then
// sets up the channel on it in a background goroutine.
//
// Only the underlay creation is synchronous: its error is returned to the
// caller. A failure during channel setup (binding) is logged and not returned.
// The connect timeout is options.ConnectTimeout, or zero when options is nil,
// matching NewChannel.
func AcceptNextChannel(logicalName string, underlayFactory UnderlayFactory, bindHandler BindHandler, options *Options) error {
	var timeout time.Duration
	if options != nil {
		timeout = options.ConnectTimeout
	}

	underlay, err := underlayFactory.Create(timeout)
	if err != nil {
		return err
	}

	go func() {
		// Use a goroutine-local error so the enclosing function's err is not
		// written after AcceptNextChannel has returned.
		if _, err := NewChannelWithUnderlay(logicalName, underlay, bindHandler, options); err != nil {
			pfxlog.Logger().WithError(err).Errorf("failure accepting channel %v with underlay %v", logicalName, underlay.Label())
		}
	}()

	return nil
}
// Id returns the id reported by the channel's underlay.
func (channel *channelImpl) Id() string {
return channel.underlay.Id()
}
// LogicalName returns the channel's current logical name.
func (channel *channelImpl) LogicalName() string {
return channel.logicalName
}
// SetLogicalName replaces the channel's logical name, which changes what
// Label returns from then on. The field is written directly; this method does
// no locking of its own.
func (channel *channelImpl) SetLogicalName(logicalName string) {
channel.logicalName = logicalName
}
// ConnectionId returns the connection id reported by the channel's underlay.
func (channel *channelImpl) ConnectionId() string {
return channel.underlay.ConnectionId()
}
// Certificates returns the certificates reported by the channel's underlay.
func (channel *channelImpl) Certificates() []*x509.Certificate {
return channel.underlay.Certificates()
}
// Headers returns the header map reported by the channel's underlay.
func (channel *channelImpl) Headers() map[int32][]byte {
return channel.underlay.Headers()
}
// Label returns a human-readable identifier for log output, of the form
// "ch{<logical name>}-><underlay label>". When the channel has no underlay
// the underlay part is rendered as "{}".
func (channel *channelImpl) Label() string {
	if channel.underlay == nil {
		return fmt.Sprintf("ch{%s}->{}", channel.LogicalName())
	}
	return fmt.Sprintf("ch{%s}->%s", channel.LogicalName(), channel.underlay.Label())
}
// GetChannel returns the channel itself.
func (channel *channelImpl) GetChannel() Channel {
return channel
}
// Bind applies h to this channel by calling h.BindChannel and returns its
// error unchanged.
func (channel *channelImpl) Bind(h BindHandler) error {
return h.BindChannel(channel)
}
// AddPeekHandler appends h to the channel's peek handlers. Handlers are
// invoked in the order they were added.
func (channel *channelImpl) AddPeekHandler(h PeekHandler) {
channel.peekHandlers = append(channel.peekHandlers, h)
}
// AddTransformHandler appends h to the channel's transform handlers. Handlers
// are invoked in the order they were added.
func (channel *channelImpl) AddTransformHandler(h TransformHandler) {
channel.transformHandlers = append(channel.transformHandlers, h)
}
// AddTypedReceiveHandler registers h under the content type it reports via
// h.ContentType(), replacing any handler already registered for that type.
func (channel *channelImpl) AddTypedReceiveHandler(h TypedReceiveHandler) {
channel.receiveHandlers[h.ContentType()] = h
}
// AddReceiveHandler registers h for contentType, replacing any handler already
// registered for that type.
func (channel *channelImpl) AddReceiveHandler(contentType int32, h ReceiveHandler) {
channel.receiveHandlers[contentType] = h
}
// AddReceiveHandlerF is the ReceiveHandlerF flavor of AddReceiveHandler; it
// forwards h to AddReceiveHandler for contentType.
func (channel *channelImpl) AddReceiveHandlerF(contentType int32, h ReceiveHandlerF) {
channel.AddReceiveHandler(contentType, h)
}
// AddErrorHandler appends h to the handlers invoked when a write to the
// underlay fails.
func (channel *channelImpl) AddErrorHandler(h ErrorHandler) {
channel.errorHandlers = append(channel.errorHandlers, h)
}
// AddCloseHandler appends h to the handlers invoked by Close.
func (channel *channelImpl) AddCloseHandler(h CloseHandler) {
channel.closeHandlers = append(channel.closeHandlers, h)
}
// SetUserData stores an arbitrary caller-supplied value on the channel,
// replacing any previous value.
func (channel *channelImpl) SetUserData(data interface{}) {
channel.userData = data
}
// GetUserData returns the value last stored with SetUserData, or nil if none
// was stored.
func (channel *channelImpl) GetUserData() interface{} {
return channel.userData
}
// Close shuts the channel down exactly once.
//
// The first caller to flip flagClosed closes closeNotify, calls Close on every
// peek handler, calls HandleClose on every close handler, and finally closes
// the underlay, returning the underlay's error. Every later call is a no-op
// that returns nil.
func (channel *channelImpl) Close() error {
	if !channel.flags.CompareAndSet(flagClosed, false, true) {
		// Somebody else already closed (or is closing) the channel.
		return nil
	}

	pfxlog.ContextLogger(channel.Label()).Debug("closing channel")

	close(channel.closeNotify)

	for _, h := range channel.peekHandlers {
		h.Close(channel)
	}

	if len(channel.closeHandlers) == 0 {
		pfxlog.ContextLogger(channel.Label()).Debug("no close handlers")
	} else {
		for _, h := range channel.closeHandlers {
			h.HandleClose(channel)
		}
	}

	return channel.underlay.Close()
}
// IsClosed reports whether flagClosed has been set, i.e. whether Close has
// been entered.
func (channel *channelImpl) IsClosed() bool {
return channel.flags.IsSet(flagClosed)
}
// Send assigns the next sequence number to s and queues it for transmission,
// blocking until there is room in the outbound queue.
//
// Returns:
//   - the context's own error if the context is already done on entry,
//   - a TimeoutError if the context finishes while waiting for queue space,
//   - a ClosedError if the channel is closed while waiting,
//   - nil once s has been placed on the queue (not once it has been written).
func (channel *channelImpl) Send(s Sendable) error {
	if ctxErr := s.Context().Err(); ctxErr != nil {
		return ctxErr
	}

	s.SetSequence(int32(channel.sequence.Next()))

	select {
	case <-s.Context().Done():
		ctxErr := s.Context().Err()
		if ctxErr == nil {
			return TimeoutError{error: errors.New("timeout waiting to put message in send queue")}
		}
		return TimeoutError{error: errors.Wrap(ctxErr, "timeout waiting to put message in send queue")}
	case <-channel.closeNotify:
		return ClosedError{}
	case channel.outQueue <- s:
		return nil
	}
}
// TrySend is the non-blocking counterpart of Send. It assigns the next
// sequence number to s and attempts to queue it without waiting.
//
// The boolean reports whether s was queued. (false, nil) means the outbound
// queue was full and nothing else was pending. A context that is already done
// on entry yields its own error; a context found done during the attempt
// yields a TimeoutError; a closed channel yields a ClosedError.
func (channel *channelImpl) TrySend(s Sendable) (bool, error) {
	if ctxErr := s.Context().Err(); ctxErr != nil {
		return false, ctxErr
	}

	s.SetSequence(int32(channel.sequence.Next()))

	select {
	case <-s.Context().Done():
		ctxErr := s.Context().Err()
		if ctxErr == nil {
			return false, TimeoutError{error: errors.New("timeout waiting to put message in send queue")}
		}
		return false, TimeoutError{error: errors.Wrap(ctxErr, "timeout waiting to put message in send queue")}
	case <-channel.closeNotify:
		return false, ClosedError{}
	case channel.outQueue <- s:
		return true, nil
	default:
		// Queue is full; report "not sent" rather than block.
		return false, nil
	}
}
// Underlay returns the underlay this channel reads from and writes to.
func (channel *channelImpl) Underlay() Underlay {
return channel.underlay
}
// startMultiplex notifies every peek handler that the channel is connected
// (passing an empty string as the second Connect argument) and then launches
// the receive and transmit goroutines.
func (channel *channelImpl) startMultiplex() {
for _, peekHandler := range channel.peekHandlers {
peekHandler.Connect(channel, "")
}
go channel.rxer()
go channel.txer()
}
// rxer is the receive loop. It reads messages from the underlay and hands each
// one to rx until a read fails.
//
// flagRxStarted guarantees that only the first invocation runs; later ones
// return immediately.
//
// On exit the deferred calls run in reverse order of registration: the waiter
// map is cleared, then the channel is closed (unless a panic is in flight, in
// which case the panic is re-raised and Close is skipped), then "exited" is
// logged.
func (channel *channelImpl) rxer() {
if !channel.flags.CompareAndSet(flagRxStarted, false, true) {
return
}
log := pfxlog.ContextLogger(channel.Label())
log.Debug("started")
defer log.Debug("exited")
defer func() {
// A recovered panic is re-raised as-is, so Close only runs on a normal return.
if r := recover(); r != nil {
panic(r)
}
_ = channel.Close()
}()
defer channel.waiters.clear()
for {
m, err := channel.underlay.Rx()
if err != nil {
// Any read error ends the loop. EOF and errors seen after the channel was
// closed are logged at debug level; everything else at error level.
if err == io.EOF {
log.WithError(err).Debug("EOF")
} else if channel.IsClosed() {
log.WithError(err).Debug("rx error")
} else {
log.WithError(err).Error("rx error")
}
return
}
channel.rx(m)
}
}
// rx processes a single received message.
//
// It records the receive time, runs the transform handlers and then the peek
// handlers, and finally routes the message:
//   - a reply whose sequence has a registered waiter goes to that waiter only;
//   - anything else goes to the handler registered for its content type,
//     falling back to the AnyContentType handler;
//   - if no handler matches, the message is dropped with a warning.
//
// Every 100th reply, if more than 1000 waiters are registered, expired waiters
// are reaped before the lookup.
func (channel *channelImpl) rx(m *Message) {
	log := pfxlog.ContextLogger(channel.Label())

	now := info.NowInMilliseconds()
	atomic.StoreInt64(&channel.lastRead, now)

	for _, th := range channel.transformHandlers {
		th.Rx(m, channel)
	}
	for _, ph := range channel.peekHandlers {
		ph.Rx(m, channel)
	}

	if m.IsReply() {
		channel.replyCounter++
		if channel.replyCounter%100 == 0 && channel.waiters.Size() > 1000 {
			channel.waiters.reapExpired(now)
		}

		replyFor := m.ReplyFor()
		if receiver := channel.waiters.RemoveWaiter(replyFor); receiver != nil {
			log.Tracef("waiter found for message. type [%v], sequence [%v], replyFor [%v]", m.ContentType, m.sequence, replyFor)
			receiver.AcceptReply(m)
			return
		}
		// Nobody is waiting for this reply; fall through to the regular handlers.
		log.Debugf("no waiter for message. type [%v], sequence [%v], replyFor [%v]", m.ContentType, m.sequence, replyFor)
	}

	if typed, ok := channel.receiveHandlers[m.ContentType]; ok {
		typed.HandleReceive(m, channel)
		return
	}
	if catchAll, ok := channel.receiveHandlers[AnyContentType]; ok {
		catchAll.HandleReceive(m, channel)
		return
	}
	log.Warnf("dropped message. type [%d], sequence [%v], replyFor [%v]", m.ContentType, m.sequence, m.ReplyFor())
}
// txer is the transmit loop. It moves sendables from outQueue into the
// priority heap in batches and writes the heap's contents to the underlay.
//
// Each iteration blocks for the first sendable (or for channel close), then
// takes further sendables without blocking until the queue is momentarily
// empty, close is signaled, or 64 have been collected. The heap is then
// drained through tx in heap order.
//
// The loop ends when tx returns an error or once close has been observed
// (after writing whatever was already in the heap). The channel is closed on
// exit.
func (channel *channelImpl) txer() {
log := pfxlog.ContextLogger(channel.Label())
defer log.Debug("exited")
log.Debug("started")
defer func() { _ = channel.Close() }()
var writeTimeout time.Duration
if channel.options != nil {
writeTimeout = channel.options.WriteTimeout
}
for {
done := false
selecting := true
count := 0
// Block until there is something to send or the channel is closing.
select {
case pm := <-channel.outQueue:
heap.Push(channel.outPriority, pm)
count++
case <-channel.closeNotify:
done = true
selecting = false
}
// Opportunistically batch whatever else is already queued, capped at 64
// per iteration, so the heap can order the batch before writing.
for selecting && count < 64 {
select {
case pm := <-channel.outQueue:
heap.Push(channel.outPriority, pm)
count++
case <-channel.closeNotify:
done = true
selecting = false
default:
selecting = false
}
}
// Write out everything collected so far, in heap order.
for channel.outPriority.Len() > 0 {
sendable := heap.Pop(channel.outPriority).(Sendable)
if err := channel.tx(sendable, writeTimeout); err != nil {
return
}
}
if done {
return
}
}
}
// tx writes a single sendable to the underlay and drives its send listener.
//
// A non-nil return value means the underlay could not be written to (or its
// write timeout could not be set); txer treats that as fatal for the channel.
// A sendable whose context has already expired, or whose message is nil, is
// skipped and nil is returned.
//
// writeTimeout is applied to the underlay before the write when it is greater
// than zero.
func (channel *channelImpl) tx(sendable Sendable, writeTimeout time.Duration) error {
log := pfxlog.ContextLogger(channel.Label())
sendListener := sendable.SendListener()
m := sendable.Msg()
// The context may have expired while the sendable sat in the queue; report
// that to the listener, but do not treat it as a channel failure.
if err := sendable.Context().Err(); err != nil {
sendListener.NotifyErr(TimeoutError{err})
return nil
}
sendListener.NotifyBeforeWrite()
if m == nil { // allow nil message in Sendable so we can send tracers to check time from send to write
return nil
}
for _, transformHandler := range channel.transformHandlers {
transformHandler.Tx(m, channel)
}
// Register the reply receiver (if any) before writing, so a reply cannot
// arrive before its waiter exists.
channel.waiters.AddWaiter(sendable)
var err error
if writeTimeout > 0 {
if err = channel.underlay.SetWriteTimeout(writeTimeout); err != nil {
log.WithError(err).Errorf("unable to set write timeout")
sendListener.NotifyErr(err)
return err
}
}
err = channel.underlay.Tx(m)
if err != nil {
// Write failed: tell the listener and the error handlers, then still
// signal NotifyAfterWrite before returning the error.
log.WithError(err).Errorf("write error")
sendListener.NotifyErr(err)
for _, errorHandler := range channel.errorHandlers {
errorHandler.HandleError(err, channel)
}
sendListener.NotifyAfterWrite()
return err
}
// Peek handlers only see messages that were written successfully.
for _, peekHandler := range channel.peekHandlers {
peekHandler.Tx(m, channel)
}
sendListener.NotifyAfterWrite()
return nil
}
// GetTimeSinceLastRead returns how long ago the last message was received,
// with millisecond resolution. lastRead starts at zero, so before any message
// has been received the result is measured from the zero timestamp.
func (channel *channelImpl) GetTimeSinceLastRead() time.Duration {
	nowMs := info.NowInMilliseconds()
	lastReadMs := atomic.LoadInt64(&channel.lastRead)
	elapsedMs := nowMs - lastReadMs
	return time.Duration(elapsedMs) * time.Millisecond
}
// waiter pairs a reply receiver with the time after which it may be reaped.
type waiter struct {
replyReceiver ReplyReceiver // receives the reply when it arrives
ttlMs int64 // expiry as an absolute timestamp in milliseconds (not a duration)
}
// waiterMap tracks the waiters for outstanding replies, keyed by the sequence
// number of the message that was sent.
type waiterMap struct {
m sync.Map // message sequence (int32) -> *waiter
size int32 // entry count, maintained separately from m via sync/atomic
}
// Size returns the current value of the waiter counter, read atomically.
func (self *waiterMap) Size() int32 {
return atomic.LoadInt32(&self.size)
}
// AddWaiter registers the reply receiver of sendable, if it has one, under the
// sequence number of its message. Sendables without a reply receiver are
// ignored.
//
// The waiter expires at the context deadline when there is one, otherwise 30
// seconds from now.
func (self *waiterMap) AddWaiter(sendable Sendable) {
	receiver := sendable.ReplyReceiver()
	if receiver == nil {
		return
	}

	entry := &waiter{replyReceiver: receiver}

	deadline, hasDeadline := sendable.Context().Deadline()
	if hasDeadline {
		entry.ttlMs = deadline.UnixMilli()
	} else {
		entry.ttlMs = info.NowInMilliseconds() + 30_000
	}

	self.m.Store(sendable.Msg().Sequence(), entry)
	atomic.AddInt32(&self.size, 1)
}
// RemoveWaiter removes the waiter registered for seq and returns its reply
// receiver. It returns nil when no waiter is registered for seq.
func (self *waiterMap) RemoveWaiter(seq int32) ReplyReceiver {
	entry, loaded := self.m.LoadAndDelete(seq)
	if !loaded {
		return nil
	}

	removed := entry.(*waiter)
	atomic.AddInt32(&self.size, -1)
	return removed.replyReceiver
}
// reapExpired removes every waiter whose expiry timestamp lies before now
// (both in milliseconds), as well as any entry that is not a usable *waiter,
// and lowers the size counter by the number of entries actually removed.
func (self *waiterMap) reapExpired(now int64) {
	var deleteCount int32

	self.m.Range(func(key, value interface{}) bool {
		w, ok := value.(*waiter)
		valid := ok && w != nil
		if valid && w.ttlMs >= now {
			return true // still alive
		}

		// Only count entries this call removed itself; one that was already
		// taken out by RemoveWaiter has had its size decrement applied.
		if _, loaded := self.m.LoadAndDelete(key); !loaded {
			return true
		}
		deleteCount++

		if valid {
			pfxlog.Logger().Debugf("removed waiter for %v. ttl: %v, now: %v", key, w.ttlMs, now)
		} else {
			// w must not be dereferenced here: it is nil for a foreign type.
			pfxlog.Logger().Debugf("removed invalid waiter entry for %v. type: %T, now: %v", key, value, now)
		}
		return true
	})

	atomic.AddInt32(&self.size, -deleteCount)
}
// clear removes every waiter from the map and lowers the size counter by the
// number of entries removed, so Size stays in step with the map contents.
func (self *waiterMap) clear() {
	var deleteCount int32

	self.m.Range(func(k, _ interface{}) bool {
		// Count only entries this call removed; a concurrent RemoveWaiter
		// accounts for its own deletion.
		if _, loaded := self.m.LoadAndDelete(k); loaded {
			deleteCount++
		}
		return true
	})

	atomic.AddInt32(&self.size, -deleteCount)
}
// bind applies bindHandler to binding. A nil bindHandler is a no-op.
//
// If binding fails, the channel behind binding is closed and the bind error is
// returned; a failure to close the channel is only logged.
func bind(bindHandler BindHandler, binding Binding) error {
	if bindHandler == nil {
		return nil
	}

	if err := bindHandler.BindChannel(binding); err != nil {
		ch := binding.GetChannel()
		if closeErr := ch.Close(); closeErr != nil {
			// Report the close failure itself; the bind error goes back to the caller.
			pfxlog.ContextLogger(ch.Label()).WithError(closeErr).Warn("error closing channel after bind failure")
		}
		return err
	}

	return nil
}