From 25e89e276ec50f11382e0d661dc0c2868c0dffbe Mon Sep 17 00:00:00 2001 From: aJC7737 <275433314+aJC7737@users.noreply.github.com> Date: Sun, 12 Apr 2026 10:22:35 +0000 Subject: [PATCH] feat: support passing proxy_protocol to reality backend --- docs/config.yaml | 7 +- listener/inbound/reality.go | 2 + listener/reality/reality.go | 138 ++++++++++++++++++++++++++++- listener/reality/reality_test.go | 144 +++++++++++++++++++++++++++++++ 4 files changed, 288 insertions(+), 3 deletions(-) create mode 100644 listener/reality/reality_test.go diff --git a/docs/config.yaml b/docs/config.yaml index 9f2aa6a920..ce1a0373d9 100644 --- a/docs/config.yaml +++ b/docs/config.yaml @@ -771,7 +771,7 @@ proxies: # socks5 # max-connections: 1 # Maximum connections. Conflict with max-streams. # min-streams: 0 # Minimum multiplexed streams in a connection before opening a new connection. Conflict with max-streams. # max-streams: 0 # Maximum multiplexed streams in a connection before opening a new connection. Conflict with max-connections and min-streams. - + reality-opts: public-key: CrrQSjAG_YkHLwvM2M-7XkKJilgL5upBKCp0od0tLhE short-id: 10f897e26c4b9478 @@ -1182,7 +1182,7 @@ proxies: # socks5 - name: sudoku type: sudoku server: server_ip/domain # 1.2.3.4 or domain - port: 443 + port: 443 key: "" # 如果你使用sudoku生成的ED25519密钥对,请填写密钥对中的私钥,否则填入和服务端相同的uuid aead-method: chacha20-poly1305 # 可选:chacha20-poly1305、aes-128-gcm、none(不建议;且 enable-pure-downlink=false 时不可用) padding-min: 2 # 最小填充率(0-100) @@ -1614,6 +1614,7 @@ listeners: # - 0123456789abcdef # server-names: # - test.com + # proxy-protocol: 0 # 可选, 默认是0为关闭, 1/2为向后端发送Proxy Protocol v1/v2 # #下列两个 limit 为选填,可对未通过验证的回落连接限速,bytesPerSec 默认为 0 即不启用 # #回落限速是一种特征,不建议启用,如果您是面板/一键脚本开发者,务必让这些参数随机化 # limit-fallback-upload: @@ -1717,6 +1718,7 @@ listeners: - 0123456789abcdef server-names: - test.com + proxy-protocol: 0 # 可选, 默认是0为关闭, 1/2为向后端发送Proxy Protocol v1/v2 #下列两个 limit 为选填,可对未通过验证的回落连接限速,bytesPerSec 默认为 0 即不启用 #回落限速是一种特征,不建议启用,如果您是面板/一键脚本开发者,务必让这些参数随机化 limit-fallback-upload: @@ -1817,6 +1819,7 @@ listeners: # - 0123456789abcdef # server-names: # - test.com + # proxy-protocol: 0 # 可选, 默认是0为关闭, 1/2为向后端发送Proxy Protocol v1/v2 # #下列两个 limit 为选填,可对未通过验证的回落连接限速,bytesPerSec 默认为 0 即不启用 # #回落限速是一种特征,不建议启用,如果您是面板/一键脚本开发者,务必让这些参数随机化 # limit-fallback-upload: diff --git a/listener/inbound/reality.go b/listener/inbound/reality.go index 7932b06029..ef419c2557 100644 --- a/listener/inbound/reality.go +++ b/listener/inbound/reality.go @@ -9,6 +9,7 @@ type RealityConfig struct { ServerNames []string `inbound:"server-names"` MaxTimeDifference int `inbound:"max-time-difference,omitempty"` Proxy string `inbound:"proxy,omitempty"` + ProxyProtocol int `inbound:"proxy-protocol,omitempty"` LimitFallbackUpload RealityLimitFallback `inbound:"limit-fallback-upload,omitempty"` LimitFallbackDownload RealityLimitFallback `inbound:"limit-fallback-download,omitempty"` @@ -28,6 +29,7 @@ func (c RealityConfig) Build() reality.Config { ServerNames: c.ServerNames, MaxTimeDifference: c.MaxTimeDifference, Proxy: c.Proxy, + ProxyProtocol: c.ProxyProtocol, LimitFallbackUpload: reality.LimitFallback{ AfterBytes: c.LimitFallbackUpload.AfterBytes, diff --git a/listener/reality/reality.go b/listener/reality/reality.go index 256dbcf6a9..66ed982e45 100644 --- a/listener/reality/reality.go +++ b/listener/reality/reality.go @@ -3,6 +3,7 @@ package reality import ( "context" "encoding/base64" + "encoding/binary" "encoding/hex" "errors" "fmt" @@ -29,12 +30,17 @@ type Config struct { ServerNames []string MaxTimeDifference int Proxy string + ProxyProtocol int LimitFallbackUpload LimitFallback LimitFallbackDownload LimitFallback } func (c Config) Build(tunnel C.Tunnel) (*Builder, error) { + if c.ProxyProtocol < 0 || c.ProxyProtocol > 2 { + return nil, fmt.Errorf("invalid proxy-protocol version: %d", c.ProxyProtocol) + } + realityConfig := &utls.RealityConfig{} realityConfig.SessionTicketsDisabled = true realityConfig.Type = "tcp" @@ -74,7 +80,15 @@ func (c Config) Build(tunnel C.Tunnel) (*Builder, error) { } realityConfig.DialContext = func(ctx context.Context, network, address string) (net.Conn, error) { - return inner.HandleTcp(tunnel, address, c.Proxy) + target, err := inner.HandleTcp(tunnel, address, c.Proxy) + if err != nil { + return nil, err + } + if err := writeProxyProtocolHeader(ctx, target, c.ProxyProtocol); err != nil { + _ = target.Close() + return nil, err + } + return target, nil } realityConfig.LimitFallbackUpload = c.LimitFallbackUpload @@ -89,6 +103,8 @@ type Builder struct { func (b Builder) NewListener(l net.Listener) net.Listener { return N.NewHandleContextListener(context.Background(), l, func(ctx context.Context, conn net.Conn) (net.Conn, error) { + ctx = context.WithValue(ctx, sourceAddrContextKey{}, conn.RemoteAddr()) + ctx = context.WithValue(ctx, destinationAddrContextKey{}, conn.LocalAddr()) c, err := utls.RealityServer(ctx, conn, b.realityConfig) if err != nil { return nil, err @@ -121,3 +137,123 @@ func (c realityConnWrapper) ReaderReplaceable() bool { func (c realityConnWrapper) WriterReplaceable() bool { return true } + +type sourceAddrContextKey struct{} + +type destinationAddrContextKey struct{} + +func writeProxyProtocolHeader(ctx context.Context, conn net.Conn, version int) error { + if version == 0 { + return nil + } + + sourceAddr, destinationAddr, err := sourceAndDestinationAddrsFromContext(ctx) + if err != nil { + return err + } + + var payload []byte + switch version { + case 1: + payload = buildProxyProtocolV1Header(sourceAddr, destinationAddr) + case 2: + payload, err = buildProxyProtocolV2Header(sourceAddr, destinationAddr) + if err != nil { + return err + } + default: + return fmt.Errorf("invalid proxy-protocol version: %d", version) + } + + _, err = conn.Write(payload) + if err != nil { + return fmt.Errorf("write proxy-protocol header: %w", err) + } + return nil +} + +func sourceAndDestinationAddrsFromContext(ctx context.Context) (*net.TCPAddr, *net.TCPAddr, error) { + sourceAddr, ok := ctx.Value(sourceAddrContextKey{}).(net.Addr) + if !ok { + return nil, nil, errors.New("missing source address for proxy-protocol") + } + destinationAddr, ok := ctx.Value(destinationAddrContextKey{}).(net.Addr) + if !ok { + return nil, nil, errors.New("missing destination address for proxy-protocol") + } + + sourceTCPAddr, ok := sourceAddr.(*net.TCPAddr) + if !ok { + return nil, nil, fmt.Errorf("unsupported source address type for proxy-protocol: %T", sourceAddr) + } + destinationTCPAddr, ok := destinationAddr.(*net.TCPAddr) + if !ok { + return nil, nil, fmt.Errorf("unsupported destination address type for proxy-protocol: %T", destinationAddr) + } + + if sourceTCPAddr.IP == nil || destinationTCPAddr.IP == nil { + return nil, nil, errors.New("invalid source or destination IP for proxy-protocol") + } + + return sourceTCPAddr, destinationTCPAddr, nil +} + +func buildProxyProtocolV1Header(sourceAddr, destinationAddr *net.TCPAddr) []byte { + family := "TCP6" + sourceIP := sourceAddr.IP + destinationIP := destinationAddr.IP + if sourceIPv4 := sourceAddr.IP.To4(); sourceIPv4 != nil { + if destinationIPv4 := destinationAddr.IP.To4(); destinationIPv4 != nil { + family = "TCP4" + sourceIP = sourceIPv4 + destinationIP = destinationIPv4 + } + } + + return []byte(fmt.Sprintf("PROXY %s %s %s %d %d\r\n", family, sourceIP.String(), destinationIP.String(), sourceAddr.Port, destinationAddr.Port)) +} + +func buildProxyProtocolV2Header(sourceAddr, destinationAddr *net.TCPAddr) ([]byte, error) { + const ( + versionAndCommandProxy = 0x21 + familyAndProtocolTCPv4 = 0x11 + familyAndProtocolTCPv6 = 0x21 + ) + + signature := []byte{0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 0x51, 0x55, 0x49, 0x54, 0x0a} + header := make([]byte, 16) + copy(header, signature) + header[12] = versionAndCommandProxy + + if sourceIPv4 := sourceAddr.IP.To4(); sourceIPv4 != nil { + if destinationIPv4 := destinationAddr.IP.To4(); destinationIPv4 != nil { + header[13] = familyAndProtocolTCPv4 + binary.BigEndian.PutUint16(header[14:16], 12) + + payload := make([]byte, 12) + copy(payload[0:4], sourceIPv4) + copy(payload[4:8], destinationIPv4) + binary.BigEndian.PutUint16(payload[8:10], uint16(sourceAddr.Port)) + binary.BigEndian.PutUint16(payload[10:12], uint16(destinationAddr.Port)) + + return append(header, payload...), nil + } + } + + sourceIPv6 := sourceAddr.IP.To16() + destinationIPv6 := destinationAddr.IP.To16() + if sourceIPv6 == nil || destinationIPv6 == nil { + return nil, errors.New("invalid IP address for proxy-protocol v2") + } + + header[13] = familyAndProtocolTCPv6 + binary.BigEndian.PutUint16(header[14:16], 36) + + payload := make([]byte, 36) + copy(payload[0:16], sourceIPv6) + copy(payload[16:32], destinationIPv6) + binary.BigEndian.PutUint16(payload[32:34], uint16(sourceAddr.Port)) + binary.BigEndian.PutUint16(payload[34:36], uint16(destinationAddr.Port)) + + return append(header, payload...), nil +} diff --git a/listener/reality/reality_test.go b/listener/reality/reality_test.go new file mode 100644 index 0000000000..e7c86dbaa5 --- /dev/null +++ b/listener/reality/reality_test.go @@ -0,0 +1,144 @@ +package reality + +import ( + "bytes" + "context" + "encoding/binary" + "io" + "net" + "testing" + "time" +) + +func TestWriteProxyProtocolHeaderDisabled(t *testing.T) { + client, server := net.Pipe() + defer client.Close() + defer server.Close() + + if err := writeProxyProtocolHeader(context.Background(), client, 0); err != nil { + t.Fatalf("writeProxyProtocolHeader() error = %v", err) + } + + if err := server.SetReadDeadline(time.Now().Add(30 * time.Millisecond)); err != nil { + t.Fatalf("SetReadDeadline() error = %v", err) + } + + buf := make([]byte, 1) + _, err := server.Read(buf) + if err == nil { + t.Fatal("expected no payload when proxy-protocol is disabled") + } + if ne, ok := err.(net.Error); !ok || !ne.Timeout() { + t.Fatalf("expected timeout read error, got %v", err) + } +} + +func TestWriteProxyProtocolHeaderV1(t *testing.T) { + client, server := net.Pipe() + defer client.Close() + defer server.Close() + + expected := []byte("PROXY TCP4 192.0.2.1 198.51.100.2 12345 443\r\n") + payloadCh := make(chan []byte, 1) + errCh := make(chan error, 1) + go func() { + buf := make([]byte, len(expected)) + _, err := io.ReadFull(server, buf) + if err != nil { + errCh <- err + return + } + payloadCh <- buf + }() + + ctx := context.Background() + ctx = context.WithValue(ctx, sourceAddrContextKey{}, &net.TCPAddr{IP: net.ParseIP("192.0.2.1"), Port: 12345}) + ctx = context.WithValue(ctx, destinationAddrContextKey{}, &net.TCPAddr{IP: net.ParseIP("198.51.100.2"), Port: 443}) + + if err := writeProxyProtocolHeader(ctx, client, 1); err != nil { + t.Fatalf("writeProxyProtocolHeader() error = %v", err) + } + + select { + case err := <-errCh: + t.Fatalf("read header error: %v", err) + case payload := <-payloadCh: + if !bytes.Equal(payload, expected) { + t.Fatalf("unexpected v1 header: got %q want %q", payload, expected) + } + } +} + +func TestWriteProxyProtocolHeaderV2(t *testing.T) { + client, server := net.Pipe() + defer client.Close() + defer server.Close() + + payloadCh := make(chan []byte, 1) + errCh := make(chan error, 1) + go func() { + buf := make([]byte, 28) + _, err := io.ReadFull(server, buf) + if err != nil { + errCh <- err + return + } + payloadCh <- buf + }() + + ctx := context.Background() + ctx = context.WithValue(ctx, sourceAddrContextKey{}, &net.TCPAddr{IP: net.ParseIP("192.0.2.10"), Port: 4567}) + ctx = context.WithValue(ctx, destinationAddrContextKey{}, &net.TCPAddr{IP: net.ParseIP("198.51.100.20"), Port: 8443}) + + if err := writeProxyProtocolHeader(ctx, client, 2); err != nil { + t.Fatalf("writeProxyProtocolHeader() error = %v", err) + } + + select { + case err := <-errCh: + t.Fatalf("read header error: %v", err) + case payload := <-payloadCh: + signature := []byte{0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 0x51, 0x55, 0x49, 0x54, 0x0a} + if !bytes.Equal(payload[:12], signature) { + t.Fatalf("invalid signature: %x", payload[:12]) + } + if payload[12] != 0x21 || payload[13] != 0x11 { + t.Fatalf("invalid version/family bytes: %x %x", payload[12], payload[13]) + } + if binary.BigEndian.Uint16(payload[14:16]) != 12 { + t.Fatalf("invalid address length: %d", binary.BigEndian.Uint16(payload[14:16])) + } + if binary.BigEndian.Uint16(payload[24:26]) != 4567 { + t.Fatalf("invalid source port: %d", binary.BigEndian.Uint16(payload[24:26])) + } + if binary.BigEndian.Uint16(payload[26:28]) != 8443 { + t.Fatalf("invalid destination port: %d", binary.BigEndian.Uint16(payload[26:28])) + } + } +} + +func TestWriteProxyProtocolHeaderMissingAddresses(t *testing.T) { + client, server := net.Pipe() + defer client.Close() + defer server.Close() + + err := writeProxyProtocolHeader(context.Background(), client, 1) + if err == nil { + t.Fatal("expected error for missing addresses") + } +} + +func TestWriteProxyProtocolHeaderInvalidVersion(t *testing.T) { + client, server := net.Pipe() + defer client.Close() + defer server.Close() + + ctx := context.Background() + ctx = context.WithValue(ctx, sourceAddrContextKey{}, &net.TCPAddr{IP: net.ParseIP("192.0.2.1"), Port: 12345}) + ctx = context.WithValue(ctx, destinationAddrContextKey{}, &net.TCPAddr{IP: net.ParseIP("198.51.100.2"), Port: 443}) + + err := writeProxyProtocolHeader(ctx, client, 3) + if err == nil { + t.Fatal("expected error for invalid version") + } +}