474 lines
11 KiB
Go
474 lines
11 KiB
Go
/*
	Copyright NetFoundry Inc.

	Licensed under the Apache License, Version 2.0 (the "License");
	you may not use this file except in compliance with the License.
	You may obtain a copy of the License at

	https://www.apache.org/licenses/LICENSE-2.0

	Unless required by applicable law or agreed to in writing, software
	distributed under the License is distributed on an "AS IS" BASIS,
	WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
	See the License for the specific language governing permissions and
	limitations under the License.
*/
|
|
|
|
package xgress
|
|
|
|
import (
|
|
"encoding/binary"
|
|
"fmt"
|
|
"github.com/openziti/channel/v4"
|
|
"github.com/openziti/foundation/v2/info"
|
|
"github.com/openziti/foundation/v2/uuidz"
|
|
"github.com/pkg/errors"
|
|
"github.com/sirupsen/logrus"
|
|
"math"
|
|
)
|
|
|
|
// Bounds of the message header key range used for per-payload headers.
// A payload header with uint8 key k travels as message header MinHeaderKey+k,
// so the range spans exactly math.MaxUint8+1 keys.
const (
	MinHeaderKey = 2000
	MaxHeaderKey = MinHeaderKey + int32(math.MaxUint8)
)

// Message header keys for the fixed xgress fields. They start just above
// MaxHeaderKey, so they never collide with the per-payload header range.
const (
	HeaderKeyCircuitId      = 2256
	HeaderKeySequence       = 2257
	HeaderKeyFlags          = 2258
	HeaderKeyRecvBufferSize = 2259
	HeaderKeyRTT            = 2260
	HeaderPayloadRaw        = 2261
)

// Channel content types for the xgress message kinds.
const (
	ContentTypePayloadType         = 1100
	ContentTypeAcknowledgementType = 1101
	ContentTypeControlType         = 1102
)
|
|
|
|
var ContentTypeValue = map[string]int32{
|
|
"PayloadType": ContentTypePayloadType,
|
|
"AcknowledgementType": ContentTypeAcknowledgementType,
|
|
"ControlType": ContentTypeControlType,
|
|
}
|
|
|
|
// Originator identifies which side of a circuit a message came from.
type Originator int32

// The two circuit endpoints.
const (
	Initiator  Originator = 0
	Terminator Originator = 1
)

// String returns "Initiator" for Initiator and "Terminator" for every other
// value.
func (o Originator) String() string {
	switch o {
	case Initiator:
		return "Initiator"
	default:
		return "Terminator"
	}
}

// Invert returns the opposite endpoint: Terminator for Initiator, and
// Initiator for every other value.
func (o Originator) Invert() Originator {
	if o != Initiator {
		return Initiator
	}
	return Terminator
}
|
|
|
|
// Flag is a single bit in the Flags field of a Payload or Acknowledgement.
type Flag uint32

// Payload flag bits. Each constant occupies its own bit, in declaration
// order: 1, 2, 4, 8, 16.
const (
	PayloadFlagCircuitEnd Flag = 1 << iota
	PayloadFlagOriginator
	PayloadFlagCircuitStart
	PayloadFlagChunk
	PayloadFlagRetransmit
)
|
|
|
|
func NewAcknowledgement(circuitId string, originator Originator) *Acknowledgement {
|
|
return &Acknowledgement{
|
|
CircuitId: circuitId,
|
|
Flags: SetOriginatorFlag(0, originator),
|
|
}
|
|
}
|
|
|
|
// Acknowledgement is the xgress message acknowledging received payload
// sequences on a circuit.
type Acknowledgement struct {
	// CircuitId is the circuit the acknowledgement belongs to.
	CircuitId string
	// Flags is a bit set of Flag values; the originator bit is read by
	// GetOriginator.
	Flags uint32
	// RecvBufferSize is carried in the HeaderKeyRecvBufferSize header.
	RecvBufferSize uint32
	// RTT is carried in the HeaderKeyRTT header.
	RTT uint16
	// Sequence lists the acknowledged sequence numbers; it is carried as the
	// message body.
	Sequence []int32
}
|
|
|
|
func (ack *Acknowledgement) GetCircuitId() string {
|
|
return ack.CircuitId
|
|
}
|
|
|
|
func (ack *Acknowledgement) GetFlags() uint32 {
|
|
return ack.Flags
|
|
}
|
|
|
|
func (ack *Acknowledgement) GetOriginator() Originator {
|
|
if isFlagSet(ack.Flags, PayloadFlagOriginator) {
|
|
return Terminator
|
|
}
|
|
return Initiator
|
|
}
|
|
|
|
func (ack *Acknowledgement) GetSequence() []int32 {
|
|
return ack.Sequence
|
|
}
|
|
|
|
func (ack *Acknowledgement) marshallSequence() []byte {
|
|
if len(ack.Sequence) == 0 {
|
|
return nil
|
|
}
|
|
buf := make([]byte, len(ack.Sequence)*4)
|
|
nextWriteBuf := buf
|
|
for _, seq := range ack.Sequence {
|
|
binary.BigEndian.PutUint32(nextWriteBuf, uint32(seq))
|
|
nextWriteBuf = nextWriteBuf[4:]
|
|
}
|
|
return buf
|
|
}
|
|
|
|
func (ack *Acknowledgement) unmarshallSequence(data []byte) error {
|
|
if len(data) == 0 {
|
|
return nil
|
|
}
|
|
|
|
if len(data)%4 != 0 {
|
|
return fmt.Errorf("received sequence with wrong number of bytes: %v", len(data))
|
|
}
|
|
ack.Sequence = make([]int32, len(data)/4)
|
|
|
|
nextReadBuf := data
|
|
for i := range ack.Sequence {
|
|
ack.Sequence[i] = int32(binary.BigEndian.Uint32(nextReadBuf))
|
|
nextReadBuf = nextReadBuf[4:]
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (ack *Acknowledgement) Marshall() *channel.Message {
|
|
msg := channel.NewMessage(ContentTypeAcknowledgementType, ack.marshallSequence())
|
|
msg.PutUint16Header(HeaderKeyRTT, ack.RTT)
|
|
msg.Headers[HeaderKeyCircuitId] = []byte(ack.CircuitId)
|
|
if ack.Flags != 0 {
|
|
msg.PutUint32Header(HeaderKeyFlags, ack.Flags)
|
|
}
|
|
msg.PutUint32Header(HeaderKeyRecvBufferSize, ack.RecvBufferSize)
|
|
return msg
|
|
}
|
|
|
|
func UnmarshallAcknowledgement(msg *channel.Message) (*Acknowledgement, error) {
|
|
ack := &Acknowledgement{}
|
|
|
|
circuitId, ok := msg.Headers[HeaderKeyCircuitId]
|
|
if !ok {
|
|
return nil, fmt.Errorf("no circuitId found in xgress payload message")
|
|
}
|
|
|
|
// If no flags are present, it just means no flags have been set
|
|
flags, _ := msg.GetUint32Header(HeaderKeyFlags)
|
|
|
|
ack.CircuitId = string(circuitId)
|
|
ack.Flags = flags
|
|
if ack.RecvBufferSize, ok = msg.GetUint32Header(HeaderKeyRecvBufferSize); !ok {
|
|
ack.RecvBufferSize = math.MaxUint32
|
|
}
|
|
|
|
ack.RTT, _ = msg.GetUint16Header(HeaderKeyRTT)
|
|
|
|
if err := ack.unmarshallSequence(msg.Body); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return ack, nil
|
|
}
|
|
|
|
func (ack *Acknowledgement) GetLoggerFields() logrus.Fields {
|
|
return logrus.Fields{
|
|
"circuitId": ack.CircuitId,
|
|
"linkRecvBufferSize": ack.RecvBufferSize,
|
|
"seq": fmt.Sprintf("%+v", ack.Sequence),
|
|
"RTT": ack.RTT,
|
|
}
|
|
}
|
|
|
|
// PayloadType is a one-byte payload kind discriminator.
type PayloadType byte

// Known payload types.
const (
	PayloadTypeXg  PayloadType = 1
	PayloadTypeRtx PayloadType = 2
	PayloadTypeFwd PayloadType = 3
)
|
|
|
|
// Payload is the xgress message carrying circuit data.
type Payload struct {
	// CircuitId is the circuit the payload belongs to.
	CircuitId string
	// Flags is a bit set of Flag values.
	Flags uint32
	// RTT is the value of the HeaderKeyRTT header.
	RTT uint16
	// Sequence is the payload's sequence number.
	Sequence int32
	// Headers holds the per-payload headers; key k is carried as message
	// header MinHeaderKey+k.
	Headers map[uint8][]byte
	// Data is the payload body.
	Data []byte
	// raw, when non-nil, is sent by Marshall as a raw message in place of a
	// message built from the other fields.
	raw []byte
}
|
|
|
|
func (payload *Payload) GetSequence() int32 {
|
|
return payload.Sequence
|
|
}
|
|
|
|
func (payload *Payload) Marshall() *channel.Message {
|
|
if payload.raw != nil {
|
|
if payload.raw[0]&RttFlagMask != 0 {
|
|
rtt := uint16(info.NowInMilliseconds())
|
|
b0 := byte(rtt)
|
|
b1 := byte(rtt >> 8)
|
|
payload.raw[2] = b0
|
|
payload.raw[3] = b1
|
|
}
|
|
return channel.NewMessage(channel.ContentTypeRaw, payload.raw)
|
|
}
|
|
|
|
msg := channel.NewMessage(ContentTypePayloadType, payload.Data)
|
|
addPayloadHeadersToMsg(msg, payload.Headers)
|
|
msg.Headers[HeaderKeyCircuitId] = []byte(payload.CircuitId)
|
|
if payload.Flags != 0 {
|
|
msg.PutUint32Header(HeaderKeyFlags, payload.Flags)
|
|
}
|
|
|
|
msg.PutUint64Header(HeaderKeySequence, uint64(payload.Sequence))
|
|
msg.PutUint16Header(HeaderKeyRTT, uint16(info.NowInMilliseconds()))
|
|
|
|
return msg
|
|
}
|
|
|
|
func addPayloadHeadersToMsg(msg *channel.Message, headers map[uint8][]byte) {
|
|
for key, value := range headers {
|
|
msgHeaderKey := MinHeaderKey + int32(key)
|
|
msg.Headers[msgHeaderKey] = value
|
|
}
|
|
}
|
|
|
|
func UnmarshallPayload(msg *channel.Message) (*Payload, error) {
|
|
var headers map[uint8][]byte
|
|
for key, val := range msg.Headers {
|
|
if key >= MinHeaderKey && key <= MaxHeaderKey {
|
|
if headers == nil {
|
|
headers = make(map[uint8][]byte)
|
|
}
|
|
xgressHeaderKey := uint8(key - MinHeaderKey)
|
|
headers[xgressHeaderKey] = val
|
|
}
|
|
}
|
|
|
|
payload := &Payload{
|
|
Headers: headers,
|
|
Data: msg.Body,
|
|
}
|
|
|
|
circuitId, ok := msg.Headers[HeaderKeyCircuitId]
|
|
if !ok {
|
|
return nil, fmt.Errorf("no circuitId found in xgress payload message")
|
|
}
|
|
|
|
// If no flags are present, it just means no flags have been set
|
|
flags, _ := msg.GetUint32Header(HeaderKeyFlags)
|
|
|
|
payload.CircuitId = string(circuitId)
|
|
payload.Flags = flags
|
|
|
|
payload.RTT, _ = msg.GetUint16Header(HeaderKeyRTT)
|
|
|
|
sequence, ok := msg.GetUint64Header(HeaderKeySequence)
|
|
if !ok {
|
|
return nil, fmt.Errorf("no sequence found in xgress payload message")
|
|
}
|
|
payload.Sequence = int32(sequence)
|
|
|
|
if raw, ok := msg.Headers[HeaderPayloadRaw]; ok {
|
|
payload.raw = raw
|
|
}
|
|
|
|
return payload, nil
|
|
}
|
|
|
|
func isFlagSet(flags uint32, flag Flag) bool {
|
|
return Flag(flags)&flag == flag
|
|
}
|
|
|
|
func setPayloadFlag(flags uint32, flag Flag) uint32 {
|
|
return uint32(Flag(flags) | flag)
|
|
}
|
|
|
|
func (payload *Payload) GetCircuitId() string {
|
|
return payload.CircuitId
|
|
}
|
|
|
|
func (payload *Payload) GetFlags() uint32 {
|
|
return payload.Flags
|
|
}
|
|
|
|
func (payload *Payload) IsCircuitEndFlagSet() bool {
|
|
return isFlagSet(payload.Flags, PayloadFlagCircuitEnd)
|
|
}
|
|
|
|
func (payload *Payload) IsCircuitStartFlagSet() bool {
|
|
return isFlagSet(payload.Flags, PayloadFlagCircuitStart)
|
|
}
|
|
|
|
func (payload *Payload) IsRetransmitFlagSet() bool {
|
|
return isFlagSet(payload.Flags, PayloadFlagRetransmit)
|
|
}
|
|
|
|
func (payload *Payload) MarkAsRetransmit() {
|
|
payload.Flags = setPayloadFlag(payload.Flags, PayloadFlagRetransmit)
|
|
}
|
|
|
|
func (payload *Payload) GetOriginator() Originator {
|
|
if isFlagSet(payload.Flags, PayloadFlagOriginator) {
|
|
return Terminator
|
|
}
|
|
return Initiator
|
|
}
|
|
|
|
func SetOriginatorFlag(flags uint32, originator Originator) uint32 {
|
|
if originator == Initiator {
|
|
return ^uint32(PayloadFlagOriginator) & flags
|
|
}
|
|
return uint32(PayloadFlagOriginator) | flags
|
|
}
|
|
|
|
func (payload *Payload) GetLoggerFields() logrus.Fields {
|
|
result := logrus.Fields{
|
|
"circuitId": payload.CircuitId,
|
|
"seq": payload.Sequence,
|
|
"origin": payload.GetOriginator(),
|
|
}
|
|
|
|
if uuidVal, found := payload.Headers[HeaderKeyUUID]; found {
|
|
result["uuid"] = uuidz.ToString(uuidVal)
|
|
}
|
|
|
|
return result
|
|
}
|
|
|
|
// ControlType is the one-byte discriminator stored at the start of a control
// message body.
type ControlType byte

// Known control message types.
const (
	ControlTypeTraceRoute         ControlType = 1
	ControlTypeTraceRouteResponse ControlType = 2
)

// String returns a readable name for the control type, or
// "unhandled: <n>" for values without a name.
func (self ControlType) String() string {
	if self == ControlTypeTraceRoute {
		return "traceroute"
	}
	if self == ControlTypeTraceRouteResponse {
		return "traceroute_response"
	}
	return fmt.Sprintf("unhandled: %v", byte(self))
}
|
|
|
|
// Header keys used on control messages.
const (
	// ControlHopCount holds the remaining hop count as a uint32.
	ControlHopCount = 20
	// ControlHopType holds the type of the responding hop as a string.
	ControlHopType = 21
	// ControlHopId holds the id of the responding hop as a string.
	ControlHopId = 22

	ControlTimestamp = 23
	ControlUserVal   = 24
	ControlError     = 25
)
|
|
|
|
type Control struct {
|
|
Type ControlType
|
|
CircuitId string
|
|
Headers channel.Headers
|
|
}
|
|
|
|
func (self *Control) Marshall() *channel.Message {
|
|
msg := channel.NewMessage(ContentTypeControlType, append([]byte{byte(self.Type)}, self.CircuitId...))
|
|
msg.Headers = self.Headers
|
|
return msg
|
|
}
|
|
|
|
func UnmarshallControl(msg *channel.Message) (*Control, error) {
|
|
if len(msg.Body) < 2 {
|
|
return nil, errors.New("control message body too short")
|
|
}
|
|
return &Control{
|
|
Type: ControlType(msg.Body[0]),
|
|
CircuitId: string(msg.Body[1:]),
|
|
Headers: msg.Headers,
|
|
}, nil
|
|
}
|
|
|
|
func (self *Control) IsTypeTraceRoute() bool {
|
|
return self.Type == ControlTypeTraceRoute
|
|
}
|
|
|
|
func (self *Control) IsTypeTraceRouteResponse() bool {
|
|
return self.Type == ControlTypeTraceRouteResponse
|
|
}
|
|
|
|
func (self *Control) DecrementAndGetHop() uint32 {
|
|
hop, _ := self.Headers.GetUint32Header(ControlHopCount)
|
|
if hop == 0 {
|
|
return 0
|
|
}
|
|
hop--
|
|
self.Headers.PutUint32Header(ControlHopCount, hop)
|
|
return hop
|
|
}
|
|
|
|
func (self *Control) CreateTraceResponse(hopType, hopId string) *Control {
|
|
resp := &Control{
|
|
Type: ControlTypeTraceRouteResponse,
|
|
CircuitId: self.CircuitId,
|
|
Headers: self.Headers,
|
|
}
|
|
resp.Headers.PutStringHeader(ControlHopType, hopType)
|
|
resp.Headers.PutStringHeader(ControlHopId, hopId)
|
|
return resp
|
|
}
|
|
|
|
func (self *Control) GetLoggerFields() logrus.Fields {
|
|
result := logrus.Fields{
|
|
"circuitId": self.CircuitId,
|
|
"type": self.Type,
|
|
}
|
|
|
|
if uuidVal, found := self.Headers[HeaderKeyUUID]; found {
|
|
result["uuid"] = uuidz.ToString(uuidVal)
|
|
}
|
|
|
|
return result
|
|
}
|
|
|
|
func RespondToTraceRequest(headers channel.Headers, hopType, hopId string, response ControlReceiver) {
|
|
resp := &Control{Headers: headers}
|
|
resp.DecrementAndGetHop()
|
|
resp.Headers.PutStringHeader(ControlHopType, hopType)
|
|
resp.Headers.PutStringHeader(ControlHopId, hopId)
|
|
response.HandleControlReceive(ControlTypeTraceRouteResponse, headers)
|
|
}
|
|
|
|
// InvalidTerminatorError marks an underlying error as being caused by an
// invalid terminator. It exposes the wrapped error through Unwrap.
type InvalidTerminatorError struct {
	InnerError error
}

// Error returns the message of the wrapped error. When no error is wrapped
// (for example on the zero value) it returns a generic message instead of
// dereferencing a nil error.
func (e InvalidTerminatorError) Error() string {
	if e.InnerError == nil {
		return "invalid terminator"
	}
	return e.InnerError.Error()
}

// Unwrap returns the wrapped error, which may be nil.
func (e InvalidTerminatorError) Unwrap() error {
	return e.InnerError
}
|
|
|
|
// MisconfiguredTerminatorError marks an underlying error as being caused by a
// misconfigured terminator. It exposes the wrapped error through Unwrap.
type MisconfiguredTerminatorError struct {
	InnerError error
}

// Error returns the message of the wrapped error. When no error is wrapped
// (for example on the zero value) it returns a generic message instead of
// dereferencing a nil error.
func (e MisconfiguredTerminatorError) Error() string {
	if e.InnerError == nil {
		return "misconfigured terminator"
	}
	return e.InnerError.Error()
}

// Unwrap returns the wrapped error, which may be nil.
func (e MisconfiguredTerminatorError) Unwrap() error {
	return e.InnerError
}
|