pax_global_header00006660000000000000000000000064151413705400014512gustar00rootroot0000000000000052 comment=d979e745f92994c5c475e3cb461b0fe26c1889c5 pires-go-proxyproto-04c9ad1/000077500000000000000000000000001514137054000161115ustar00rootroot00000000000000pires-go-proxyproto-04c9ad1/.github/000077500000000000000000000000001514137054000174515ustar00rootroot00000000000000pires-go-proxyproto-04c9ad1/.github/FUNDING.yml000066400000000000000000000000161514137054000212630ustar00rootroot00000000000000github: pires pires-go-proxyproto-04c9ad1/.github/workflows/000077500000000000000000000000001514137054000215065ustar00rootroot00000000000000pires-go-proxyproto-04c9ad1/.github/workflows/golangci-lint.yml000066400000000000000000000034371514137054000247670ustar00rootroot00000000000000name: golangci-lint on: push: tags: - v* branches: - main pull_request: permissions: # Required: allow read access to the content for analysis. contents: read # Optional: allow read access to pull request. Use with `only-new-issues` option. pull-requests: read # Optional: allow write access to checks to allow the action to annotate code in the PR. checks: write jobs: golangci: name: lint runs-on: ubuntu-latest env: GOTOOLCHAIN: local strategy: matrix: go: ['1.24', '1.25'] steps: - uses: actions/checkout@v4 - uses: actions/setup-go@v5 with: go-version: ${{ matrix.go }} - name: Tidy run: go mod tidy - name: Format run: go fmt - name: Vet run: go vet - name: lint uses: golangci/golangci-lint-action@v9 #with: # Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version #version: v1.29 # Optional: working directory, useful for monorepos # working-directory: somedir # Optional: golangci-lint command line arguments. # args: --issues-exit-code=0 # Optional: show only new issues if it's a pull request. The default value is `false`. 
# only-new-issues: true # Optional: if set to true then the all caching functionality will be complete disabled, # takes precedence over all other caching options. # skip-cache: true # Optional: if set to true then the action don't cache or restore ~/go/pkg. # skip-pkg-cache: true # Optional: if set to true then the action don't cache or restore ~/.cache/go-build. # skip-build-cache: true pires-go-proxyproto-04c9ad1/.github/workflows/release.yml000066400000000000000000000004101514137054000236440ustar00rootroot00000000000000name: release on: push: tags: - "v*.*.*" jobs: release: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - name: Release uses: softprops/action-gh-release@v2 with: generate_release_notes: true pires-go-proxyproto-04c9ad1/.github/workflows/test.yml000066400000000000000000000017121514137054000232110ustar00rootroot00000000000000name: test on: pull_request: jobs: test: runs-on: ubuntu-latest env: GOTOOLCHAIN: local strategy: fail-fast: false matrix: go: ['1.24', '1.25'] steps: - uses: actions/setup-go@v4 with: go-version: ${{ matrix.go }} - uses: actions/checkout@v4 - name: Get dependencies run: | go get golang.org/x/tools/cmd/cover go get github.com/mattn/goveralls - name: Test run: go test -race -v -covermode=atomic -coverprofile=coverage.out - name: Send coverage uses: shogo82148/actions-goveralls@v1 with: github-token: ${{ secrets.GITHUB_TOKEN }} path-to-profile: coverage.out flag-name: Go-${{ matrix.go }} parallel: true # notifies that all test jobs are finished. 
finish: needs: test runs-on: ubuntu-latest steps: - uses: shogo82148/actions-goveralls@v1 with: parallel-finished: true pires-go-proxyproto-04c9ad1/.gitignore000066400000000000000000000001571514137054000201040ustar00rootroot00000000000000# Compiled Object files, Static and Dynamic libs (Shared Objects) *.o *.a *.so # Folders .idea bin pkg *.out pires-go-proxyproto-04c9ad1/.golangci.yml000066400000000000000000000005741514137054000205030ustar00rootroot00000000000000version: "2" linters: default: standard enable: - asasalint - asciicheck - bidichk - bodyclose - canonicalheader - containedctx - copyloopvar - goconst - godot - gosec - modernize - misspell - revive - unconvert - usestdlibvars run: timeout: 5m allow-parallel-runners: true issues: max-issues-per-linter: 0 pires-go-proxyproto-04c9ad1/LICENSE000066400000000000000000000261151514137054000171230ustar00rootroot00000000000000 Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. 
"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." 
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. 
Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "{}" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright 2016 Paulo Pires Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. pires-go-proxyproto-04c9ad1/README.md000066400000000000000000000100311514137054000173630ustar00rootroot00000000000000# go-proxyproto [![Actions Status](https://github.com/pires/go-proxyproto/workflows/test/badge.svg)](https://github.com/pires/go-proxyproto/actions) [![Coverage Status](https://coveralls.io/repos/github/pires/go-proxyproto/badge.svg?branch=main)](https://coveralls.io/github/pires/go-proxyproto?branch=main) [![Go Report Card](https://goreportcard.com/badge/github.com/pires/go-proxyproto)](https://goreportcard.com/report/github.com/pires/go-proxyproto) [![](https://godoc.org/github.com/pires/go-proxyproto?status.svg)](https://pkg.go.dev/github.com/pires/go-proxyproto?tab=doc) A Go library implementation of the [PROXY protocol, versions 1 and 2](https://www.haproxy.org/download/2.3/doc/proxy-protocol.txt), which provides, as per specification: > (...) a convenient way to safely transport connection > information such as a client's address across multiple layers of NAT or TCP > proxies. It is designed to require little changes to existing components and > to limit the performance impact caused by the processing of the transported > information. This library is to be used in one of or both proxy clients and proxy servers that need to support said protocol. Both protocol versions, 1 (text-based) and 2 (binary-based) are supported. 
## Installation ```shell $ go get -u github.com/pires/go-proxyproto ``` ## Usage ### Client ```go package main import ( "io" "log" "net" proxyproto "github.com/pires/go-proxyproto" ) func chkErr(err error) { if err != nil { log.Fatalf("Error: %s", err.Error()) } } func main() { // Dial some proxy listener e.g. https://github.com/mailgun/proxyproto target, err := net.ResolveTCPAddr("tcp", "127.0.0.1:2319") chkErr(err) conn, err := net.DialTCP("tcp", nil, target) chkErr(err) defer conn.Close() // Create a proxyprotocol header or use HeaderProxyFromAddrs() if you // have two conn's header := &proxyproto.Header{ Version: 1, Command: proxyproto.PROXY, TransportProtocol: proxyproto.TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, } // After the connection was created write the proxy headers first _, err = header.WriteTo(conn) chkErr(err) // Then your data... e.g.: _, err = io.WriteString(conn, "HELO") chkErr(err) } ``` ### Server ```go package main import ( "log" "net" proxyproto "github.com/pires/go-proxyproto" ) func main() { // Create a listener addr := "localhost:9876" list, err := net.Listen("tcp", addr) if err != nil { log.Fatalf("couldn't listen to %q: %q\n", addr, err.Error()) } // Wrap listener in a proxyproto listener proxyListener := &proxyproto.Listener{Listener: list} defer proxyListener.Close() // Wait for a connection and accept it conn, err := proxyListener.Accept() defer conn.Close() // Print connection details if conn.LocalAddr() == nil { log.Fatal("couldn't retrieve local address") } log.Printf("local address: %q", conn.LocalAddr().String()) if conn.RemoteAddr() == nil { log.Fatal("couldn't retrieve remote address") } log.Printf("remote address: %q", conn.RemoteAddr().String()) } ``` ### HTTP Server ```go package main import ( "net" "net/http" "time" "github.com/pires/go-proxyproto" ) func main() { server := http.Server{ Addr: ":8080", } ln, err := 
net.Listen("tcp", server.Addr) if err != nil { panic(err) } proxyListener := &proxyproto.Listener{ Listener: ln, ReadHeaderTimeout: 10 * time.Second, } defer proxyListener.Close() server.Serve(proxyListener) } ``` ## Special notes ### AWS AWS Network Load Balancer (NLB) does not push the PPV2 header until the client starts sending the data. This is a problem if your server speaks first. e.g. SMTP, FTP, SSH etc. By default, NLB target group attribute `proxy_protocol_v2.client_to_server.header_placement` has the value `on_first_ack_with_payload`. You need to contact AWS support to change it to `on_first_ack`, instead. Just to be clear, you need this fix only if your server is designed to speak first. pires-go-proxyproto-04c9ad1/addr_proto.go000066400000000000000000000037631514137054000206060ustar00rootroot00000000000000package proxyproto // AddressFamilyAndProtocol represents address family and transport protocol. type AddressFamilyAndProtocol byte // AddressFamilyAndProtocol enum values. const ( UNSPEC AddressFamilyAndProtocol = '\x00' TCPv4 AddressFamilyAndProtocol = '\x11' UDPv4 AddressFamilyAndProtocol = '\x12' TCPv6 AddressFamilyAndProtocol = '\x21' UDPv6 AddressFamilyAndProtocol = '\x22' UnixStream AddressFamilyAndProtocol = '\x31' UnixDatagram AddressFamilyAndProtocol = '\x32' ) // IsIPv4 returns true if the address family is IPv4 (AF_INET4), false otherwise. func (ap AddressFamilyAndProtocol) IsIPv4() bool { return ap&0xF0 == 0x10 } // IsIPv6 returns true if the address family is IPv6 (AF_INET6), false otherwise. func (ap AddressFamilyAndProtocol) IsIPv6() bool { return ap&0xF0 == 0x20 } // IsUnix returns true if the address family is UNIX (AF_UNIX), false otherwise. func (ap AddressFamilyAndProtocol) IsUnix() bool { return ap&0xF0 == 0x30 } // IsStream returns true if the transport protocol is TCP or STREAM (SOCK_STREAM), false otherwise. 
func (ap AddressFamilyAndProtocol) IsStream() bool { return ap&0x0F == 0x01 } // IsDatagram returns true if the transport protocol is UDP or DGRAM (SOCK_DGRAM), false otherwise. func (ap AddressFamilyAndProtocol) IsDatagram() bool { return ap&0x0F == 0x02 } // IsUnspec returns true if the transport protocol or address family is unspecified, false otherwise. func (ap AddressFamilyAndProtocol) IsUnspec() bool { return (ap&0xF0 == 0x00) || (ap&0x0F == 0x00) } func (ap AddressFamilyAndProtocol) toByte() byte { if ap.IsIPv4() && ap.IsStream() { return byte(TCPv4) } else if ap.IsIPv4() && ap.IsDatagram() { return byte(UDPv4) } else if ap.IsIPv6() && ap.IsStream() { return byte(TCPv6) } else if ap.IsIPv6() && ap.IsDatagram() { return byte(UDPv6) } else if ap.IsUnix() && ap.IsStream() { return byte(UnixStream) } else if ap.IsUnix() && ap.IsDatagram() { return byte(UnixDatagram) } return byte(UNSPEC) } pires-go-proxyproto-04c9ad1/addr_proto_test.go000066400000000000000000000032311514137054000216330ustar00rootroot00000000000000package proxyproto import ( "testing" ) func TestTCPoverIPv4(t *testing.T) { b := byte(TCPv4) if !AddressFamilyAndProtocol(b).IsIPv4() { t.Fail() } if !AddressFamilyAndProtocol(b).IsStream() { t.Fail() } if AddressFamilyAndProtocol(b).toByte() != b { t.Fail() } } func TestTCPoverIPv6(t *testing.T) { b := byte(TCPv6) if !AddressFamilyAndProtocol(b).IsIPv6() { t.Fail() } if !AddressFamilyAndProtocol(b).IsStream() { t.Fail() } if AddressFamilyAndProtocol(b).toByte() != b { t.Fail() } } func TestUDPoverIPv4(t *testing.T) { b := byte(UDPv4) if !AddressFamilyAndProtocol(b).IsIPv4() { t.Fail() } if !AddressFamilyAndProtocol(b).IsDatagram() { t.Fail() } if AddressFamilyAndProtocol(b).toByte() != b { t.Fail() } } func TestUDPoverIPv6(t *testing.T) { b := byte(UDPv6) if !AddressFamilyAndProtocol(b).IsIPv6() { t.Fail() } if !AddressFamilyAndProtocol(b).IsDatagram() { t.Fail() } if AddressFamilyAndProtocol(b).toByte() != b { t.Fail() } } func TestUnixStream(t 
*testing.T) { b := byte(UnixStream) if !AddressFamilyAndProtocol(b).IsUnix() { t.Fail() } if !AddressFamilyAndProtocol(b).IsStream() { t.Fail() } if AddressFamilyAndProtocol(b).toByte() != b { t.Fail() } } func TestUnixDatagram(t *testing.T) { b := byte(UnixDatagram) if !AddressFamilyAndProtocol(b).IsUnix() { t.Fail() } if !AddressFamilyAndProtocol(b).IsDatagram() { t.Fail() } if AddressFamilyAndProtocol(b).toByte() != b { t.Fail() } } func TestInvalidAddressFamilyAndProtocol(t *testing.T) { b := byte(UNSPEC) if !AddressFamilyAndProtocol(b).IsUnspec() { t.Fail() } if AddressFamilyAndProtocol(b).toByte() != b { t.Fail() } } pires-go-proxyproto-04c9ad1/example_conn_test.go000066400000000000000000000040361514137054000221520ustar00rootroot00000000000000package proxyproto_test import ( "net" "time" "github.com/pires/go-proxyproto" ) func ExampleNewConn_default() { serverConn, clientConn := net.Pipe() defer func() { _ = serverConn.Close() }() defer func() { _ = clientConn.Close() }() go func() { _, _ = clientConn.Write([]byte("x")) _ = clientConn.Close() }() conn := proxyproto.NewConn(serverConn) buf := make([]byte, 1) _, _ = conn.Read(buf) // Output: } func ExampleNewConn_withBufferSize() { serverConn, clientConn := net.Pipe() defer func() { _ = serverConn.Close() }() defer func() { _ = clientConn.Close() }() go func() { _, _ = clientConn.Write([]byte("y")) _ = clientConn.Close() }() conn := proxyproto.NewConn(serverConn, proxyproto.WithBufferSize(4096)) buf := make([]byte, 1) _, _ = conn.Read(buf) // Output: } func ExampleNewConn_withReadHeaderTimeout() { serverConn, clientConn := net.Pipe() defer func() { _ = serverConn.Close() }() defer func() { _ = clientConn.Close() }() go func() { _, _ = clientConn.Write([]byte("z")) _ = clientConn.Close() }() conn := proxyproto.NewConn(serverConn, proxyproto.SetReadHeaderTimeout(time.Second)) buf := make([]byte, 1) _, _ = conn.Read(buf) // Output: } func ExampleNewConn_withPolicy() { serverConn, clientConn := net.Pipe() defer 
func() { _ = serverConn.Close() }() defer func() { _ = clientConn.Close() }() go func() { _, _ = clientConn.Write([]byte(proxyV1Line)) _, _ = clientConn.Write([]byte("p")) _ = clientConn.Close() }() conn := proxyproto.NewConn(serverConn, proxyproto.WithPolicy(proxyproto.REQUIRE)) buf := make([]byte, 1) _, _ = conn.Read(buf) // Output: } func ExampleNewConn_combined() { serverConn, clientConn := net.Pipe() defer func() { _ = serverConn.Close() }() defer func() { _ = clientConn.Close() }() go func() { _, _ = clientConn.Write([]byte("c")) _ = clientConn.Close() }() conn := proxyproto.NewConn(serverConn, proxyproto.WithBufferSize(2048), proxyproto.SetReadHeaderTimeout(2*time.Second), ) buf := make([]byte, 1) _, _ = conn.Read(buf) // Output: } pires-go-proxyproto-04c9ad1/example_listener_test.go000066400000000000000000000052341514137054000230430ustar00rootroot00000000000000package proxyproto_test import ( "net" "time" "github.com/pires/go-proxyproto" ) // proxyV1Line is a minimal PROXY protocol v1 header for examples. 
const proxyV1Line = "PROXY TCP4 192.168.1.1 192.168.1.2 12345 443\r\n" func ExampleListener_default() { l, _ := net.Listen("tcp", "127.0.0.1:0") pl := &proxyproto.Listener{Listener: l} defer func() { _ = pl.Close() }() go func() { c, _ := net.Dial("tcp", pl.Addr().String()) if c != nil { _, _ = c.Write([]byte("x")) _ = c.Close() } }() conn, _ := pl.Accept() if conn != nil { buf := make([]byte, 1) _, _ = conn.Read(buf) _ = conn.Close() } // Output: } func ExampleListener_readHeaderTimeout() { l, _ := net.Listen("tcp", "127.0.0.1:0") pl := &proxyproto.Listener{ Listener: l, ReadHeaderTimeout: 2 * time.Second, } defer func() { _ = pl.Close() }() go func() { c, _ := net.Dial("tcp", pl.Addr().String()) if c != nil { _, _ = c.Write([]byte("a")) _ = c.Close() } }() conn, _ := pl.Accept() if conn != nil { _ = conn.SetReadDeadline(time.Now().Add(time.Second)) buf := make([]byte, 1) _, _ = conn.Read(buf) _ = conn.Close() } // Output: } func ExampleListener_readBufferSize() { l, _ := net.Listen("tcp", "127.0.0.1:0") pl := &proxyproto.Listener{ Listener: l, ReadBufferSize: 4096, } defer func() { _ = pl.Close() }() go func() { c, _ := net.Dial("tcp", pl.Addr().String()) if c != nil { _, _ = c.Write([]byte("b")) _ = c.Close() } }() conn, _ := pl.Accept() if conn != nil { buf := make([]byte, 1) _, _ = conn.Read(buf) _ = conn.Close() } // Output: } func ExampleListener_policyRequire() { l, _ := net.Listen("tcp", "127.0.0.1:0") pl := &proxyproto.Listener{ Listener: l, Policy: func(net.Addr) (proxyproto.Policy, error) { return proxyproto.REQUIRE, nil }, } defer func() { _ = pl.Close() }() go func() { c, _ := net.Dial("tcp", pl.Addr().String()) if c != nil { _, _ = c.Write([]byte(proxyV1Line)) _, _ = c.Write([]byte("p")) _ = c.Close() } }() conn, _ := pl.Accept() if conn != nil { buf := make([]byte, 1) _, _ = conn.Read(buf) _ = conn.Close() } // Output: } func ExampleListener_validateHeader() { l, _ := net.Listen("tcp", "127.0.0.1:0") pl := &proxyproto.Listener{ Listener: l, 
ValidateHeader: func(*proxyproto.Header) error { return nil }, } defer func() { _ = pl.Close() }() go func() { c, _ := net.Dial("tcp", pl.Addr().String()) if c != nil { _, _ = c.Write([]byte(proxyV1Line)) _, _ = c.Write([]byte("v")) _ = c.Close() } }() conn, _ := pl.Accept() if conn != nil { buf := make([]byte, 1) _, _ = conn.Read(buf) _ = conn.Close() } // Output: } pires-go-proxyproto-04c9ad1/examples/000077500000000000000000000000001514137054000177275ustar00rootroot00000000000000pires-go-proxyproto-04c9ad1/examples/client/000077500000000000000000000000001514137054000212055ustar00rootroot00000000000000pires-go-proxyproto-04c9ad1/examples/client/client.go000066400000000000000000000020731514137054000230140ustar00rootroot00000000000000// Package main provides a proxyproto client example. package main import ( "io" "log" "net" proxyproto "github.com/pires/go-proxyproto" ) func chkErr(err error) { if err != nil { log.Fatalf("Error: %s", err.Error()) } } func main() { // Dial some proxy listener e.g. https://github.com/mailgun/proxyproto target, err := net.ResolveTCPAddr("tcp", "127.0.0.1:9876") chkErr(err) conn, err := net.DialTCP("tcp", nil, target) chkErr(err) defer func() { _ = conn.Close() }() // Create a proxyprotocol header or use HeaderProxyFromAddrs() if you // have two conn's header := &proxyproto.Header{ Version: 1, Command: proxyproto.PROXY, TransportProtocol: proxyproto.TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, } // After the connection was created write the proxy headers first _, err = header.WriteTo(conn) chkErr(err) // Then your data... 
e.g.: _, err = io.WriteString(conn, "HELO") chkErr(err) } pires-go-proxyproto-04c9ad1/examples/httpserver/000077500000000000000000000000001514137054000221355ustar00rootroot00000000000000pires-go-proxyproto-04c9ad1/examples/httpserver/httpserver.go000066400000000000000000000024251514137054000246750ustar00rootroot00000000000000// Package main provides a proxyproto HTTP server example. package main import ( "log" "net" "net/http" "time" "github.com/pires/go-proxyproto" h2proxy "github.com/pires/go-proxyproto/helper/http2" ) // TODO: add httpclient example func main() { server := http.Server{ Addr: ":8080", ReadHeaderTimeout: 5 * time.Second, ConnState: func(c net.Conn, s http.ConnState) { if s == http.StateNew { log.Printf("[ConnState] %s -> %s", c.LocalAddr().String(), c.RemoteAddr().String()) } }, Handler: http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { log.Printf("[Handler] remote ip %q", r.RemoteAddr) }), } ln, err := net.Listen("tcp", server.Addr) if err != nil { panic(err) } proxyListener := &proxyproto.Listener{ Listener: ln, ReadHeaderTimeout: 10 * time.Second, } defer func() { if err := proxyListener.Close(); err != nil { log.Printf("failed to close proxy listener: %v", err) } }() // Create an HTTP server which can handle proxied incoming connections for // both HTTP/1 and HTTP/2. HTTP/2 support relies on TLS ALPN, the reverse // proxy needs to be configured to accept "h2". if err := h2proxy.NewServer(&server, nil).Serve(proxyListener); err != nil { log.Fatalf("failed to serve: %v", err) } } pires-go-proxyproto-04c9ad1/examples/server/000077500000000000000000000000001514137054000212355ustar00rootroot00000000000000pires-go-proxyproto-04c9ad1/examples/server/server.go000066400000000000000000000021611514137054000230720ustar00rootroot00000000000000// Package main provides a proxyproto server example. 
package main

import (
	"log"
	"net"

	proxyproto "github.com/pires/go-proxyproto"
)

// main listens on localhost:9876, wraps the listener in a proxyproto.Listener,
// accepts exactly one connection and logs its local and remote addresses.
//
// The log.Fatal* calls exit the process immediately, so the deferred Close
// calls registered below do not run on those paths.
func main() {
	// Create a listener
	addr := "localhost:9876"
	list, err := net.Listen("tcp", addr)
	if err != nil {
		log.Fatalf("couldn't listen to %q: %q\n", addr, err.Error())
	}

	// Wrap listener in a proxyproto listener
	proxyListener := &proxyproto.Listener{Listener: list}
	defer func() {
		if err := proxyListener.Close(); err != nil {
			log.Printf("failed to close proxy listener: %v", err)
		}
	}()

	// Wait for a connection and accept it
	conn, err := proxyListener.Accept()
	if err != nil {
		log.Fatalf("failed to accept connection: %v", err)
	}
	defer func() {
		if err := conn.Close(); err != nil {
			log.Printf("failed to close connection: %v", err)
		}
	}()

	// Print connection details
	if conn.LocalAddr() == nil {
		log.Fatal("couldn't retrieve local address")
	}
	log.Printf("local address: %q", conn.LocalAddr().String())
	if conn.RemoteAddr() == nil {
		log.Fatal("couldn't retrieve remote address")
	}
	log.Printf("remote address: %q", conn.RemoteAddr().String())
}
pires-go-proxyproto-04c9ad1/go.mod000066400000000000000000000002001514137054000172070ustar00rootroot00000000000000module github.com/pires/go-proxyproto

go 1.24

require golang.org/x/net v0.39.0

require golang.org/x/text v0.24.0 // indirect
pires-go-proxyproto-04c9ad1/go.sum000066400000000000000000000004641514137054000172500ustar00rootroot00000000000000golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY=
golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E=
golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0=
golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU=
pires-go-proxyproto-04c9ad1/header.go000066400000000000000000000247461514137054000177050ustar00rootroot00000000000000// Package proxyproto implements Proxy Protocol (v1 and v2) parser and writer, as per specification:
// https://www.haproxy.org/download/2.3/doc/proxy-protocol.txt
package proxyproto

import (
	"bufio"
	"bytes"
	"errors"
	"io"
	"net"
	"time"
)

var (
	// SIGV1 is the signature for PROXY protocol v1: the ASCII bytes "PROXY".
	SIGV1 = []byte{'\x50', '\x52', '\x4F', '\x58', '\x59'}
	// SIGV2 is the signature for PROXY protocol v2: the 12 bytes
	// "\r\n\r\n\x00\r\nQUIT\n".
	SIGV2 = []byte{'\x0D', '\x0A', '\x0D', '\x0A', '\x00', '\x0D', '\x0A', '\x51', '\x55', '\x49', '\x54', '\x0A'}

	// ErrCantReadVersion1Header indicates a v1 header could not be read.
	ErrCantReadVersion1Header = errors.New("proxyproto: can't read version 1 header")
	// ErrVersion1HeaderTooLong indicates a v1 header is too long.
	ErrVersion1HeaderTooLong = errors.New("proxyproto: version 1 header must be 107 bytes or less")
	// ErrLineMustEndWithCrlf indicates a v1 header is invalid, must end with \r\n.
	ErrLineMustEndWithCrlf = errors.New("proxyproto: version 1 header is invalid, must end with \\r\\n")
	// ErrCantReadProtocolVersionAndCommand indicates a protocol version and command could not be read.
	ErrCantReadProtocolVersionAndCommand = errors.New("proxyproto: can't read proxy protocol version and command")
	// ErrCantReadAddressFamilyAndProtocol indicates an address family and protocol could not be read.
	ErrCantReadAddressFamilyAndProtocol = errors.New("proxyproto: can't read address family or protocol")
	// ErrCantReadLength indicates a length could not be read.
	ErrCantReadLength = errors.New("proxyproto: can't read length")
	// ErrCantResolveSourceUnixAddress indicates a source Unix address could not be resolved.
	ErrCantResolveSourceUnixAddress = errors.New("proxyproto: can't resolve source Unix address")
	// ErrCantResolveDestinationUnixAddress indicates a destination Unix address could not be resolved.
	ErrCantResolveDestinationUnixAddress = errors.New("proxyproto: can't resolve destination Unix address")
	// ErrNoProxyProtocol indicates a proxy protocol signature is not present.
	ErrNoProxyProtocol = errors.New("proxyproto: proxy protocol signature not present")
	// ErrUnknownProxyProtocolVersion indicates an unknown proxy protocol version.
	ErrUnknownProxyProtocolVersion = errors.New("proxyproto: unknown proxy protocol version")
	// ErrUnsupportedProtocolVersionAndCommand indicates an unsupported protocol version and command.
	ErrUnsupportedProtocolVersionAndCommand = errors.New("proxyproto: unsupported proxy protocol version and command")
	// ErrUnsupportedAddressFamilyAndProtocol indicates an unsupported address family and protocol.
	ErrUnsupportedAddressFamilyAndProtocol = errors.New("proxyproto: unsupported address family and protocol")
	// ErrInvalidLength indicates an invalid length.
	ErrInvalidLength = errors.New("proxyproto: invalid length")
	// ErrInvalidAddress indicates an invalid address.
	ErrInvalidAddress = errors.New("proxyproto: invalid address")
	// ErrInvalidPortNumber indicates an invalid port number.
	ErrInvalidPortNumber = errors.New("proxyproto: invalid port number")
	// ErrSuperfluousProxyHeader indicates an upstream connection sent a PROXY header but isn't allowed to send one.
	ErrSuperfluousProxyHeader = errors.New("proxyproto: upstream connection sent PROXY header but isn't allowed to send one")
)

// Header is the placeholder for proxy protocol header.
type Header struct {
	// Version is the protocol version; Format accepts only 1 or 2 and
	// returns ErrUnknownProxyProtocolVersion otherwise.
	Version byte
	// Command is PROXY or LOCAL; LOCAL headers carry no address information.
	Command ProtocolVersionAndCommand
	// TransportProtocol is the address family and transport of the addresses.
	TransportProtocol AddressFamilyAndProtocol
	SourceAddr        net.Addr
	DestinationAddr   net.Addr
	// rawTLVs holds the encoded TLVs: written by SetTLVs (via JoinTLVs) and
	// decoded by TLVs (via SplitTLVs). EqualsTo compares it for version 2 only.
	rawTLVs []byte
}

// HeaderProxyFromAddrs creates a new PROXY header from a source and a
// destination address. If version is not 1 or 2 (for example zero), the
// latest protocol version, 2, is used.
//
// The header is filled on a best-effort basis: if hints cannot be inferred
// from the provided addresses, the header will be left unspecified, i.e. a
// LOCAL command with an UNSPEC transport and no addresses. This is the case
// when the two addresses are of different concrete types, when a TCP/UDP
// source IP is neither a 4-byte nor a 16-byte address, or when a Unix
// address' Net is neither "unix" nor "unixgram".
func HeaderProxyFromAddrs(version byte, sourceAddr, destAddr net.Addr) *Header { if version < 1 || version > 2 { version = 2 } h := &Header{ Version: version, Command: LOCAL, TransportProtocol: UNSPEC, } switch sourceAddr := sourceAddr.(type) { case *net.TCPAddr: if _, ok := destAddr.(*net.TCPAddr); !ok { break } if len(sourceAddr.IP.To4()) == net.IPv4len { h.TransportProtocol = TCPv4 } else if len(sourceAddr.IP) == net.IPv6len { h.TransportProtocol = TCPv6 } case *net.UDPAddr: if _, ok := destAddr.(*net.UDPAddr); !ok { break } if len(sourceAddr.IP.To4()) == net.IPv4len { h.TransportProtocol = UDPv4 } else if len(sourceAddr.IP) == net.IPv6len { h.TransportProtocol = UDPv6 } case *net.UnixAddr: if _, ok := destAddr.(*net.UnixAddr); !ok { break } switch sourceAddr.Net { case "unix": h.TransportProtocol = UnixStream case "unixgram": h.TransportProtocol = UnixDatagram } } if h.TransportProtocol != UNSPEC { h.Command = PROXY h.SourceAddr = sourceAddr h.DestinationAddr = destAddr } return h } // TCPAddrs returns TCP source/destination addresses if the header is stream-based. func (header *Header) TCPAddrs() (sourceAddr, destAddr *net.TCPAddr, ok bool) { if !header.TransportProtocol.IsStream() { return nil, nil, false } sourceAddr, sourceOK := header.SourceAddr.(*net.TCPAddr) destAddr, destOK := header.DestinationAddr.(*net.TCPAddr) return sourceAddr, destAddr, sourceOK && destOK } // UDPAddrs returns UDP source/destination addresses if the header is datagram-based. func (header *Header) UDPAddrs() (sourceAddr, destAddr *net.UDPAddr, ok bool) { if !header.TransportProtocol.IsDatagram() { return nil, nil, false } sourceAddr, sourceOK := header.SourceAddr.(*net.UDPAddr) destAddr, destOK := header.DestinationAddr.(*net.UDPAddr) return sourceAddr, destAddr, sourceOK && destOK } // UnixAddrs returns UNIX source/destination addresses if the header is UNIX-based. 
func (header *Header) UnixAddrs() (sourceAddr, destAddr *net.UnixAddr, ok bool) { if !header.TransportProtocol.IsUnix() { return nil, nil, false } sourceAddr, sourceOK := header.SourceAddr.(*net.UnixAddr) destAddr, destOK := header.DestinationAddr.(*net.UnixAddr) return sourceAddr, destAddr, sourceOK && destOK } // IPs returns source/destination IPs for TCP/UDP headers. func (header *Header) IPs() (sourceIP, destIP net.IP, ok bool) { if sourceAddr, destAddr, ok := header.TCPAddrs(); ok { return sourceAddr.IP, destAddr.IP, true } if sourceAddr, destAddr, ok := header.UDPAddrs(); ok { return sourceAddr.IP, destAddr.IP, true } return nil, nil, false } // Ports returns source/destination ports for TCP/UDP headers. func (header *Header) Ports() (sourcePort, destPort int, ok bool) { if sourceAddr, destAddr, ok := header.TCPAddrs(); ok { return sourceAddr.Port, destAddr.Port, true } if sourceAddr, destAddr, ok := header.UDPAddrs(); ok { return sourceAddr.Port, destAddr.Port, true } return 0, 0, false } // EqualTo returns true if headers are equivalent, false otherwise. // Deprecated: use EqualsTo instead. This method will eventually be removed. func (header *Header) EqualTo(otherHeader *Header) bool { return header.EqualsTo(otherHeader) } // EqualsTo returns true if headers are equivalent, false otherwise. 
func (header *Header) EqualsTo(otherHeader *Header) bool { if otherHeader == nil { return false } if header.Version != otherHeader.Version || header.Command != otherHeader.Command || header.TransportProtocol != otherHeader.TransportProtocol { return false } // TLVs only exist for version 2 if header.Version == 2 && !bytes.Equal(header.rawTLVs, otherHeader.rawTLVs) { return false } // Return early for header with LOCAL command, which contains no address information if header.Command == LOCAL { return true } return header.SourceAddr.String() == otherHeader.SourceAddr.String() && header.DestinationAddr.String() == otherHeader.DestinationAddr.String() } // WriteTo renders a proxy protocol header in a format and writes it to an io.Writer. func (header *Header) WriteTo(w io.Writer) (int64, error) { buf, err := header.Format() if err != nil { return 0, err } return bytes.NewBuffer(buf).WriteTo(w) } // Format renders a proxy protocol header in a format to write over the wire. func (header *Header) Format() ([]byte, error) { switch header.Version { case 1: return header.formatVersion1() case 2: return header.formatVersion2() default: return nil, ErrUnknownProxyProtocolVersion } } // TLVs returns the TLVs stored into this header, if they exist. TLVs are optional for v2 of the protocol. func (header *Header) TLVs() ([]TLV, error) { return SplitTLVs(header.rawTLVs) } // SetTLVs sets the TLVs stored in this header. This method replaces any // previous TLV. func (header *Header) SetTLVs(tlvs []TLV) error { raw, err := JoinTLVs(tlvs) if err != nil { return err } header.rawTLVs = raw return nil } // Read identifies the proxy protocol version and reads the remaining of // the header, accordingly. // // If proxy protocol header signature is not present, the reader buffer remains untouched // and is safe for reading outside of this code. 
//
// If proxy protocol header signature is present but an error is raised while processing
// the remaining header, assume the reader buffer to be in a corrupt state.
// Also, this operation will block until enough bytes are available for peeking.
//
// Signature detection only peeks (1, then 5, then 12 bytes); an io.EOF while
// peeking is reported as ErrNoProxyProtocol, any other read error is returned
// as-is. Parsing proper is delegated to parseVersion1 / parseVersion2.
func Read(reader *bufio.Reader) (*Header, error) {
	// In order to improve speed for small non-PROXYed packets, take a peek at the first byte alone.
	b1, err := reader.Peek(1)
	if err != nil {
		if err == io.EOF {
			return nil, ErrNoProxyProtocol
		}
		return nil, err
	}

	// SIGV1 starts with 'P' and SIGV2 with '\r'; any other first byte cannot
	// start a PROXY header.
	if bytes.Equal(b1[:1], SIGV1[:1]) || bytes.Equal(b1[:1], SIGV2[:1]) {
		signature, err := reader.Peek(5)
		if err != nil {
			if err == io.EOF {
				return nil, ErrNoProxyProtocol
			}
			return nil, err
		}
		if bytes.Equal(signature[:5], SIGV1) {
			return parseVersion1(reader)
		}

		// Not v1: widen the peek to the full 12-byte v2 signature.
		signature, err = reader.Peek(12)
		if err != nil {
			if err == io.EOF {
				return nil, ErrNoProxyProtocol
			}
			return nil, err
		}
		if bytes.Equal(signature[:12], SIGV2) {
			return parseVersion2(reader)
		}
	}

	return nil, ErrNoProxyProtocol
}

// ReadTimeout acts as Read but takes a timeout. If that timeout is reached, it's assumed
// there's no proxy protocol header.
//
// NOTE(review): on timeout the goroutine running Read is not cancelled; it
// keeps consuming from reader in the background. Callers must not read from
// the same reader after a timeout, or they will race with that goroutine.
func ReadTimeout(reader *bufio.Reader, timeout time.Duration) (*Header, error) {
	type header struct {
		h *Header
		e error
	}
	// Buffered so the reading goroutine can always deliver its result and
	// exit, even after the timeout branch below has already returned.
	read := make(chan *header, 1)

	go func() {
		h := &header{}
		h.h, h.e = Read(reader)
		read <- h
	}()

	timer := time.NewTimer(timeout)
	select {
	case result := <-read:
		timer.Stop()
		return result.h, result.e
	case <-timer.C:
		return nil, ErrNoProxyProtocol
	}
}
pires-go-proxyproto-04c9ad1/header_test.go000066400000000000000000000436231514137054000207370ustar00rootroot00000000000000package proxyproto

import (
	"bufio"
	"bytes"
	"errors"
	"net"
	"reflect"
	"testing"
	"time"
)

// Fixtures shared by the tests of both protocol versions.
const (
	testNoProtocol          = "There is no spoon"
	testLocalhostIP4Addr    = "127.0.0.1"
	testLocalhostIP4In6Addr = "::ffff:127.0.0.1"
	testLocalhostIP6Addr    = "::1"
	testIP6LongAddr         = "1234:5678:9abc:def0:cafe:babe:dead:2bad"
	testValidPort           = 65533
	testInvalidPort         = 99999 // deliberately above the 16-bit port range
)

var (
	v4ip = net.ParseIP(testLocalhostIP4Addr).To4()
	v6ip = net.ParseIP(testLocalhostIP6Addr).To16()

	v4addr           net.Addr = &net.TCPAddr{IP: v4ip, Port: testValidPort}
	v6addr           net.Addr = &net.TCPAddr{IP: v6ip, Port: testValidPort}
	v4UDPAddr        net.Addr = &net.UDPAddr{IP: v4ip, Port: testValidPort}
	v6UDPAddr        net.Addr = &net.UDPAddr{IP: v6ip, Port: testValidPort}
	unixStreamAddr   net.Addr = &net.UnixAddr{Net: "unix", Name: "socket"}
	unixDatagramAddr net.Addr = &net.UnixAddr{Net: "unixgram", Name: "socket"}

	errReadIntentionallyBroken = errors.New("read is intentionally broken")
)

// timeoutReader is a reader whose Read sleeps 500ms and then reports no data
// and no error, so a short ReadTimeout expires before any byte arrives.
type timeoutReader []byte

func (t *timeoutReader) Read([]byte) (int, error) {
	time.Sleep(500 * time.Millisecond)
	return 0, nil
}

// errorReader is a reader whose Read always fails with
// errReadIntentionallyBroken.
type errorReader []byte

func (e *errorReader) Read([]byte) (int, error) {
	return 0, errReadIntentionallyBroken
}

// TestReadTimeoutV1Invalid checks that ReadTimeout reports ErrNoProxyProtocol
// when no data arrives before the timeout.
func TestReadTimeoutV1Invalid(t *testing.T) {
	var b timeoutReader
	reader := bufio.NewReader(&b)
	_, err := ReadTimeout(reader, 50*time.Millisecond)
	if err == nil {
		t.Fatalf("expected error %s", ErrNoProxyProtocol)
	} else if err != ErrNoProxyProtocol {
		t.Fatalf("expected %s, actual %s", ErrNoProxyProtocol, err)
	}
}

// TestReadTimeoutPropagatesReadError checks that a non-EOF read error is
// returned unchanged instead of being mapped to ErrNoProxyProtocol.
func TestReadTimeoutPropagatesReadError(t *testing.T) {
	var e errorReader
	reader := bufio.NewReader(&e)
	_, err := ReadTimeout(reader, 50*time.Millisecond)
	if err == nil {
		t.Fatalf("expected error %s", errReadIntentionallyBroken)
	} else if err != errReadIntentionallyBroken {
		t.Fatalf("expected error %s, actual %s", errReadIntentionallyBroken, err)
	}
}

// TestEqualsTo covers a nil comparand, a version mismatch and an identical
// pair of headers.
func TestEqualsTo(t *testing.T) {
	var headersEqual = []struct {
		this, that *Header
		expected   bool
	}{
		{
			&Header{
				Version:           1,
				Command:           PROXY,
				TransportProtocol: TCPv4,
				SourceAddr: &net.TCPAddr{
					IP:   net.ParseIP(testSourceIPv4Addr),
					Port: 1000,
				},
				DestinationAddr: &net.TCPAddr{
					IP:   net.ParseIP(testDestinationIPv4Addr),
					Port: 2000,
				},
			},
			nil,
			false,
		},
		{
			&Header{
				Version:           1,
				Command:           PROXY,
				TransportProtocol: TCPv4,
				SourceAddr: &net.TCPAddr{
					IP:   net.ParseIP(testSourceIPv4Addr),
					Port: 1000,
				},
				DestinationAddr: &net.TCPAddr{
					IP:   net.ParseIP(testDestinationIPv4Addr),
					Port: 2000,
				},
			},
			&Header{
				Version:           2,
				Command:           PROXY,
				TransportProtocol: TCPv4,
				SourceAddr: &net.TCPAddr{
					IP:   net.ParseIP(testSourceIPv4Addr),
					Port: 1000,
				},
				DestinationAddr: &net.TCPAddr{
					IP:   net.ParseIP(testDestinationIPv4Addr),
					Port: 2000,
				},
			},
			false,
		},
		{
			&Header{
				Version:           1,
				Command:           PROXY,
				TransportProtocol: TCPv4,
				SourceAddr: &net.TCPAddr{
					IP:   net.ParseIP(testSourceIPv4Addr),
					Port: 1000,
				},
				DestinationAddr: &net.TCPAddr{
					IP:   net.ParseIP(testDestinationIPv4Addr),
					Port: 2000,
				},
			},
			&Header{
				Version:           1,
				Command:           PROXY,
				TransportProtocol: TCPv4,
				SourceAddr: &net.TCPAddr{
					IP:   net.ParseIP(testSourceIPv4Addr),
					Port: 1000,
				},
				DestinationAddr: &net.TCPAddr{
					IP:   net.ParseIP(testDestinationIPv4Addr),
					Port: 2000,
				},
			},
			true,
		},
	}

	for _, tt := range headersEqual {
		if actual := tt.this.EqualsTo(tt.that); actual != tt.expected {
			t.Fatalf("expected %t, actual %t", tt.expected, actual)
		}
	}
}

// This is here just because of coveralls (coverage of the deprecated EqualTo alias).
func TestEqualTo(t *testing.T) { TestEqualsTo(t) } func TestGetters(t *testing.T) { var tests = []struct { name string header *Header tcpSourceAddr, tcpDestAddr *net.TCPAddr udpSourceAddr, udpDestAddr *net.UDPAddr unixSourceAddr, unixDestAddr *net.UnixAddr ipSource, ipDest net.IP portSource, portDest int }{ { name: "TCPv4", header: &Header{ Version: 1, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP(testSourceIPv4Addr), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP(testDestinationIPv4Addr), Port: 2000, }, }, tcpSourceAddr: &net.TCPAddr{ IP: net.ParseIP(testSourceIPv4Addr), Port: 1000, }, tcpDestAddr: &net.TCPAddr{ IP: net.ParseIP(testDestinationIPv4Addr), Port: 2000, }, ipSource: net.ParseIP(testSourceIPv4Addr), ipDest: net.ParseIP(testDestinationIPv4Addr), portSource: 1000, portDest: 2000, }, { name: "UDPv4", header: &Header{ Version: 2, Command: PROXY, TransportProtocol: UDPv6, SourceAddr: &net.UDPAddr{ IP: net.ParseIP(testSourceIPv4Addr), Port: 1000, }, DestinationAddr: &net.UDPAddr{ IP: net.ParseIP(testDestinationIPv4Addr), Port: 2000, }, }, udpSourceAddr: &net.UDPAddr{ IP: net.ParseIP(testSourceIPv4Addr), Port: 1000, }, udpDestAddr: &net.UDPAddr{ IP: net.ParseIP(testDestinationIPv4Addr), Port: 2000, }, ipSource: net.ParseIP(testSourceIPv4Addr), ipDest: net.ParseIP(testDestinationIPv4Addr), portSource: 1000, portDest: 2000, }, { name: "UnixStream", header: &Header{ Version: 2, Command: PROXY, TransportProtocol: UnixStream, SourceAddr: &net.UnixAddr{ Net: "unix", Name: "src", }, DestinationAddr: &net.UnixAddr{ Net: "unix", Name: "dst", }, }, unixSourceAddr: &net.UnixAddr{ Net: "unix", Name: "src", }, unixDestAddr: &net.UnixAddr{ Net: "unix", Name: "dst", }, }, { name: "UnixDatagram", header: &Header{ Version: 2, Command: PROXY, TransportProtocol: UnixDatagram, SourceAddr: &net.UnixAddr{ Net: "unix", Name: "src", }, DestinationAddr: &net.UnixAddr{ Net: "unix", Name: "dst", }, }, unixSourceAddr: &net.UnixAddr{ 
Net: "unix", Name: "src", }, unixDestAddr: &net.UnixAddr{ Net: "unix", Name: "dst", }, }, { name: "Unspec", header: &Header{ Version: 1, Command: PROXY, TransportProtocol: UNSPEC, }, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { tcpSourceAddr, tcpDestAddr, _ := test.header.TCPAddrs() if test.tcpSourceAddr != nil && !reflect.DeepEqual(tcpSourceAddr, test.tcpSourceAddr) { t.Errorf("TCPAddrs() source = %v, want %v", tcpSourceAddr, test.tcpSourceAddr) } if test.tcpDestAddr != nil && !reflect.DeepEqual(tcpDestAddr, test.tcpDestAddr) { t.Errorf("TCPAddrs() dest = %v, want %v", tcpDestAddr, test.tcpDestAddr) } udpSourceAddr, udpDestAddr, _ := test.header.UDPAddrs() if test.udpSourceAddr != nil && !reflect.DeepEqual(udpSourceAddr, test.udpSourceAddr) { t.Errorf("TCPAddrs() source = %v, want %v", udpSourceAddr, test.udpSourceAddr) } if test.udpDestAddr != nil && !reflect.DeepEqual(udpDestAddr, test.udpDestAddr) { t.Errorf("TCPAddrs() dest = %v, want %v", udpDestAddr, test.udpDestAddr) } unixSourceAddr, unixDestAddr, _ := test.header.UnixAddrs() if test.unixSourceAddr != nil && !reflect.DeepEqual(unixSourceAddr, test.unixSourceAddr) { t.Errorf("UnixAddrs() source = %v, want %v", unixSourceAddr, test.unixSourceAddr) } if test.unixDestAddr != nil && !reflect.DeepEqual(unixDestAddr, test.unixDestAddr) { t.Errorf("UnixAddrs() dest = %v, want %v", unixDestAddr, test.unixDestAddr) } ipSource, ipDest, _ := test.header.IPs() if test.ipSource != nil && !ipSource.Equal(test.ipSource) { t.Errorf("IPs() source = %v, want %v", ipSource, test.ipSource) } if test.ipDest != nil && !ipDest.Equal(test.ipDest) { t.Errorf("IPs() dest = %v, want %v", ipDest, test.ipDest) } portSource, portDest, _ := test.header.Ports() if test.portSource != 0 && portSource != test.portSource { t.Errorf("Ports() source = %v, want %v", portSource, test.portSource) } if test.portDest != 0 && portDest != test.portDest { t.Errorf("Ports() dest = %v, want %v", portDest, test.portDest) } }) } 
} func TestSetTLVs(t *testing.T) { tests := []struct { header *Header name string tlvs []TLV expectErr bool }{ { name: "add authority TLV", header: &Header{ Version: 1, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP(testSourceIPv4Addr), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP(testDestinationIPv4Addr), Port: 2000, }, }, tlvs: []TLV{{ Type: PP2_TYPE_AUTHORITY, Value: []byte("example.org"), }}, }, { name: "add too long TLV", header: &Header{ Version: 1, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP(testSourceIPv4Addr), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP(testDestinationIPv4Addr), Port: 2000, }, }, tlvs: []TLV{{ Type: PP2_TYPE_AUTHORITY, Value: append(bytes.Repeat([]byte("a"), 0xFFFF), []byte(".example.org")...), }}, expectErr: true, }, } for _, tt := range tests { err := tt.header.SetTLVs(tt.tlvs) if err != nil && !tt.expectErr { t.Fatalf("shouldn't have thrown error %q", err.Error()) } } } func TestWriteTo(t *testing.T) { var buf bytes.Buffer validHeader := &Header{ Version: 1, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP(testSourceIPv4Addr), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP(testDestinationIPv4Addr), Port: 2000, }, } if _, err := validHeader.WriteTo(&buf); err != nil { t.Fatalf("shouldn't have thrown error %q", err.Error()) } invalidHeader := &Header{ SourceAddr: &net.TCPAddr{ IP: net.ParseIP(testSourceIPv4Addr), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP(testDestinationIPv4Addr), Port: 2000, }, } if _, err := invalidHeader.WriteTo(&buf); err == nil { t.Fatalf("should have thrown error %q", err.Error()) } } func TestFormat(t *testing.T) { validHeader := &Header{ Version: 1, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP(testSourceIPv4Addr), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: 
net.ParseIP(testDestinationIPv4Addr), Port: 2000, }, } if _, err := validHeader.Format(); err != nil { t.Fatalf("shouldn't have thrown error %q", err.Error()) } } func TestFormatInvalid(t *testing.T) { tests := []struct { name string header *Header err error }{ { name: "invalidVersion", header: &Header{ Version: 3, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: v4addr, DestinationAddr: v4addr, }, err: ErrUnknownProxyProtocolVersion, }, { name: "v2MismatchTCPv4_UDPv4", header: &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: v4UDPAddr, DestinationAddr: v4addr, }, err: ErrInvalidAddress, }, { name: "v2MismatchTCPv4_TCPv6", header: &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: v4addr, DestinationAddr: v6addr, }, err: ErrInvalidAddress, }, { name: "v2MismatchUnixStream_TCPv4", header: &Header{ Version: 2, Command: PROXY, TransportProtocol: UnixStream, SourceAddr: v4addr, DestinationAddr: unixStreamAddr, }, err: ErrInvalidAddress, }, { name: "v1MismatchTCPv4_TCPv6", header: &Header{ Version: 1, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: v6addr, DestinationAddr: v4addr, }, err: ErrInvalidAddress, }, { name: "v1MismatchTCPv4_UDPv4", header: &Header{ Version: 1, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: v4UDPAddr, DestinationAddr: v4addr, }, err: ErrInvalidAddress, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { if _, err := test.header.Format(); err == nil { t.Errorf("Header.Format() succeeded, want an error") } else if err != test.err { t.Errorf("Header.Format() = %q, want %q", err, test.err) } }) } } func TestHeaderProxyFromAddrs(t *testing.T) { unspec := &Header{ Version: 2, Command: LOCAL, TransportProtocol: UNSPEC, } tests := []struct { name string version byte sourceAddr, destAddr net.Addr expected *Header }{ { name: "TCPv4", sourceAddr: &net.TCPAddr{ IP: net.ParseIP(testSourceIPv4Addr), Port: 1000, }, destAddr: &net.TCPAddr{ IP: 
net.ParseIP(testDestinationIPv4Addr), Port: 2000, }, expected: &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP(testSourceIPv4Addr), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP(testDestinationIPv4Addr), Port: 2000, }, }, }, { name: "TCPv6", sourceAddr: &net.TCPAddr{ IP: net.ParseIP("fde7::372"), Port: 1000, }, destAddr: &net.TCPAddr{ IP: net.ParseIP("fde7::1"), Port: 2000, }, expected: &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv6, SourceAddr: &net.TCPAddr{ IP: net.ParseIP("fde7::372"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("fde7::1"), Port: 2000, }, }, }, { name: "UDPv4", sourceAddr: &net.UDPAddr{ IP: net.ParseIP(testSourceIPv4Addr), Port: 1000, }, destAddr: &net.UDPAddr{ IP: net.ParseIP(testDestinationIPv4Addr), Port: 2000, }, expected: &Header{ Version: 2, Command: PROXY, TransportProtocol: UDPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP(testSourceIPv4Addr), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP(testDestinationIPv4Addr), Port: 2000, }, }, }, { name: "UDPv6", sourceAddr: &net.UDPAddr{ IP: net.ParseIP("fde7::372"), Port: 1000, }, destAddr: &net.UDPAddr{ IP: net.ParseIP("fde7::1"), Port: 2000, }, expected: &Header{ Version: 2, Command: PROXY, TransportProtocol: UDPv6, SourceAddr: &net.TCPAddr{ IP: net.ParseIP("fde7::372"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("fde7::1"), Port: 2000, }, }, }, { name: "UnixStream", sourceAddr: &net.UnixAddr{ Net: "unix", Name: "src", }, destAddr: &net.UnixAddr{ Net: "unix", Name: "dst", }, expected: &Header{ Version: 2, Command: PROXY, TransportProtocol: UnixStream, SourceAddr: &net.UnixAddr{ Net: "unix", Name: "src", }, DestinationAddr: &net.UnixAddr{ Net: "unix", Name: "dst", }, }, }, { name: "UnixDatagram", sourceAddr: &net.UnixAddr{ Net: "unixgram", Name: "src", }, destAddr: &net.UnixAddr{ Net: "unixgram", Name: "dst", }, expected: &Header{ Version: 2, Command: PROXY, 
TransportProtocol: UnixDatagram, SourceAddr: &net.UnixAddr{ Net: "unixgram", Name: "src", }, DestinationAddr: &net.UnixAddr{ Net: "unixgram", Name: "dst", }, }, }, { name: "Version1", version: 1, sourceAddr: &net.TCPAddr{ IP: net.ParseIP(testSourceIPv4Addr), Port: 1000, }, destAddr: &net.TCPAddr{ IP: net.ParseIP(testDestinationIPv4Addr), Port: 2000, }, expected: &Header{ Version: 1, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP(testSourceIPv4Addr), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP(testDestinationIPv4Addr), Port: 2000, }, }, }, { name: "TCPInvalidIP", sourceAddr: &net.TCPAddr{ IP: nil, Port: 1000, }, destAddr: &net.TCPAddr{ IP: nil, Port: 2000, }, expected: unspec, }, { name: "UDPInvalidIP", sourceAddr: &net.UDPAddr{ IP: nil, Port: 1000, }, destAddr: &net.UDPAddr{ IP: nil, Port: 2000, }, expected: unspec, }, { name: "TCPAddrTypeMismatch", sourceAddr: &net.TCPAddr{ IP: net.ParseIP(testSourceIPv4Addr), Port: 1000, }, destAddr: &net.UDPAddr{ IP: net.ParseIP(testDestinationIPv4Addr), Port: 2000, }, expected: unspec, }, { name: "UDPAddrTypeMismatch", sourceAddr: &net.UDPAddr{ IP: net.ParseIP(testSourceIPv4Addr), Port: 1000, }, destAddr: &net.TCPAddr{ IP: net.ParseIP(testDestinationIPv4Addr), Port: 2000, }, expected: unspec, }, { name: "UnixAddrTypeMismatch", sourceAddr: &net.UnixAddr{ Net: "unix", }, destAddr: &net.TCPAddr{ IP: net.ParseIP(testDestinationIPv4Addr), Port: 2000, }, expected: unspec, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { h := HeaderProxyFromAddrs(tt.version, tt.sourceAddr, tt.destAddr) if !h.EqualsTo(tt.expected) { t.Errorf("expected %+v, actual %+v for source %+v and destination %+v", tt.expected, h, tt.sourceAddr, tt.destAddr) } }) } } 
pires-go-proxyproto-04c9ad1/helper/000077500000000000000000000000001514137054000173705ustar00rootroot00000000000000pires-go-proxyproto-04c9ad1/helper/http2/000077500000000000000000000000001514137054000204315ustar00rootroot00000000000000pires-go-proxyproto-04c9ad1/helper/http2/http2.go000066400000000000000000000124601514137054000220240ustar00rootroot00000000000000// Package http2 provides helpers for HTTP/2.
package http2

import (
	"context"
	"crypto/tls"
	"fmt"
	"log"
	"net"
	"net/http"
	"sync"
	"time"

	"github.com/pires/go-proxyproto"
	"golang.org/x/net/http2"
)

// listenerRetryBaseDelay is the first back-off delay Serve waits after a
// timeout error from Accept; later delays double, capped at one second.
const listenerRetryBaseDelay = 5 * time.Millisecond

// Server is an HTTP server accepting both regular and proxied, both HTTP/1 and
// HTTP/2 connections.
//
// HTTP/2 is negotiated using TLS ALPN, either directly via a tls.Conn or
// indirectly via the PROXY protocol. When the PROXY protocol is used, the
// TLS-terminating proxy in front of the server must be configured to accept
// the "h2" TLS ALPN protocol.
//
// The server is closed when the http.Server is.
type Server struct {
	h1         *http.Server  // regular HTTP/1 server
	h2         *http2.Server // HTTP/2 server
	h2Err      error         // HTTP/2 server setup error, if any
	h1Listener h1Listener    // pipe listener for the HTTP/1 server

	// The following fields are protected by the mutex
	mu        sync.Mutex
	closed    bool                      // set by closeListeners; Serve then reports http.ErrServerClosed
	listeners map[net.Listener]struct{} // listeners currently passed to Serve
}

// NewServer creates a new HTTP server.
//
// A nil h2 is equivalent to a zero http2.Server. h2 is registered on h1 via
// http2.ConfigureServer; a setup error is not returned here but by Serve.
// NewServer also starts h1.Serve on an internal pipe listener in a background
// goroutine, which is how HTTP/1 connections are handed to h1.
func NewServer(h1 *http.Server, h2 *http2.Server) *Server {
	if h2 == nil {
		h2 = new(http2.Server)
	}
	srv := &Server{
		h1:        h1,
		h2:        h2,
		h2Err:     http2.ConfigureServer(h1, h2),
		listeners: make(map[net.Listener]struct{}),
	}
	srv.h1Listener = h1Listener{newPipeListener(), srv}
	go func() {
		// The error from h1.Serve is deliberately discarded; presumably the
		// pipe listener's Accept only fails once it has been closed through
		// h1.Close/Shutdown. NOTE(review): confirm against pipeListener.
		_ = h1.Serve(srv.h1Listener)
	}()
	return srv
}

// errorLog returns h1.ErrorLog when set and the standard library's default
// logger otherwise.
func (srv *Server) errorLog() *log.Logger {
	if srv.h1.ErrorLog != nil {
		return srv.h1.ErrorLog
	}
	return log.Default()
}

// Serve accepts incoming connections on the listener ln.
func (srv *Server) Serve(ln net.Listener) error { if srv.h2Err != nil { return srv.h2Err } srv.mu.Lock() ok := !srv.closed if ok { srv.listeners[ln] = struct{}{} } srv.mu.Unlock() if !ok { return http.ErrServerClosed } defer func() { srv.mu.Lock() delete(srv.listeners, ln) srv.mu.Unlock() }() // net.Listener.Accept can fail for temporary failures, e.g. too many open // files or other timeout conditions. In that case, wait and retry later. // This mirrors what the net/http package does. var delay time.Duration for { conn, err := ln.Accept() if ne, ok := err.(net.Error); ok && ne.Timeout() { if delay == 0 { delay = listenerRetryBaseDelay } else { delay *= 2 } if maxDelay := 1 * time.Second; delay > maxDelay { delay = maxDelay } srv.errorLog().Printf("listener %q: accept error (retrying in %v): %v", ln.Addr(), delay, err) time.Sleep(delay) } else if err != nil { if srv.isClosed() { return http.ErrServerClosed } return fmt.Errorf("failed to accept connection: %w", err) } delay = 0 baseCtx := context.Background() if srv.h1.BaseContext != nil { baseCtx = srv.h1.BaseContext(ln) } go func(baseCtx context.Context, conn net.Conn) { if err := srv.serveConn(baseCtx, conn); err != nil { srv.errorLog().Printf("listener %q: %v", ln.Addr(), err) } }(baseCtx, conn) } } func (srv *Server) serveConn(baseCtx context.Context, conn net.Conn) error { var proto string switch conn := conn.(type) { case *tls.Conn: if err := conn.Handshake(); err != nil { if closeErr := conn.Close(); closeErr != nil { srv.errorLog().Printf("failed to close connection: %v", closeErr) } return err } proto = conn.ConnectionState().NegotiatedProtocol case *proxyproto.Conn: if proxyHeader := conn.ProxyHeader(); proxyHeader != nil { tlvs, err := proxyHeader.TLVs() if err != nil { if closeErr := conn.Close(); closeErr != nil { srv.errorLog().Printf("failed to close connection: %v", closeErr) } return err } for _, tlv := range tlvs { if tlv.Type == proxyproto.PP2_TYPE_ALPN { proto = string(tlv.Value) break } } } } 
// See https://www.iana.org/assignments/tls-extensiontype-values/tls-extensiontype-values.xhtml#alpn-protocol-ids switch proto { case http2.NextProtoTLS, "h2c": defer func() { if closeErr := conn.Close(); closeErr != nil { srv.errorLog().Printf("failed to close connection: %v", closeErr) } }() ctx := baseCtx // Mirror net/http.Server ConnContext behavior. if cc := srv.h1.ConnContext; cc != nil { ctx = cc(ctx, conn) if ctx == nil { panic("ConnContext returned nil") } } opts := http2.ServeConnOpts{Context: ctx, BaseConfig: srv.h1} srv.h2.ServeConn(conn, &opts) return nil case "", "http/1.0", "http/1.1": return srv.h1Listener.ServeConn(conn) default: if closeErr := conn.Close(); closeErr != nil { srv.errorLog().Printf("failed to close connection: %v", closeErr) } return fmt.Errorf("unsupported protocol %q", proto) } } func (srv *Server) closeListeners() error { srv.mu.Lock() defer srv.mu.Unlock() srv.closed = true var err error for ln := range srv.listeners { if cerr := ln.Close(); cerr != nil { err = cerr } } return err } func (srv *Server) isClosed() bool { srv.mu.Lock() defer srv.mu.Unlock() return srv.closed } // h1Listener is used to signal back http.Server's Close and Shutdown to the // HTTP/2 server. type h1Listener struct { *pipeListener srv *Server } func (ln h1Listener) Close() error { // pipeListener.Close never fails _ = ln.pipeListener.Close() return ln.srv.closeListeners() } pires-go-proxyproto-04c9ad1/helper/http2/http2_internal_test.go000066400000000000000000000033131514137054000247540ustar00rootroot00000000000000package http2 import ( "context" "net" "net/http" "testing" "time" "github.com/pires/go-proxyproto" ) // TestServeConn_ConnContextReturnsNil lives in package http2 (not http2_test) so // it can call the unexported serveConn method directly and recover the panic in // the same goroutine, which is not possible through the public Serve API because // Serve spawns a new goroutine per connection. 
func TestServeConn_ConnContextReturnsNil(t *testing.T) { srv := NewServer(&http.Server{ ReadHeaderTimeout: 5 * time.Second, Handler: http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {}), ConnContext: func(_ context.Context, _ net.Conn) context.Context { return nil }, }, nil) // Create a pipe and write a PROXY header with h2 ALPN to trigger the H2 path. clientConn, serverConn := net.Pipe() defer func() { _ = clientConn.Close() }() defer func() { _ = serverConn.Close() }() header := proxyproto.Header{ Version: 2, Command: proxyproto.LOCAL, TransportProtocol: proxyproto.UNSPEC, } if err := header.SetTLVs([]proxyproto.TLV{{ Type: proxyproto.PP2_TYPE_ALPN, Value: []byte("h2"), }}); err != nil { t.Fatalf("failed to set TLVs: %v", err) } // Write the header in a goroutine because net.Pipe is synchronous. go func() { _, _ = header.WriteTo(clientConn) _ = clientConn.Close() }() pConn := proxyproto.NewConn(serverConn) defer func() { r := recover() if r == nil { t.Fatal("expected panic from ConnContext returning nil") } msg, ok := r.(string) if !ok || msg != "ConnContext returned nil" { t.Fatalf("expected panic message 'ConnContext returned nil', got: %v", r) } }() _ = srv.serveConn(context.Background(), pConn) } pires-go-proxyproto-04c9ad1/helper/http2/http2_test.go000066400000000000000000000216651514137054000230720ustar00rootroot00000000000000package http2_test import ( "context" "crypto/rand" "crypto/rsa" "crypto/tls" "crypto/x509" "crypto/x509/pkix" "errors" "log" "math/big" "net" "net/http" "testing" "time" "github.com/pires/go-proxyproto" h2proxy "github.com/pires/go-proxyproto/helper/http2" "golang.org/x/net/http2" ) func ExampleServer() { ln, err := net.Listen("tcp", "localhost:80") if err != nil { log.Fatalf("failed to listen: %v", err) } proxyLn := &proxyproto.Listener{ Listener: ln, } server := h2proxy.NewServer(&http.Server{ ReadHeaderTimeout: 5 * time.Second, Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { _, _ = 
w.Write([]byte("Hello world!\n")) }), }, nil) if err := server.Serve(proxyLn); err != nil { log.Fatalf("failed to serve: %v", err) } } type contextKey string const ( connContextKey = contextKey("conn") baseContextKey = contextKey("base") ) func TestServer_h1(t *testing.T) { addr, server := newTestServer(t) t.Cleanup(func() { if err := server.Close(); err != nil { t.Errorf("failed to close server: %v", err) } }) resp, err := http.Get("http://" + addr) if err != nil { t.Fatalf("failed to perform HTTP request: %v", err) } if err := resp.Body.Close(); err != nil { t.Fatalf("failed to close response body: %v", err) } } func TestServer_h2(t *testing.T) { addr, server := newTestServer(t) t.Cleanup(func() { if err := server.Close(); err != nil { t.Errorf("failed to close server: %v", err) } }) conn, err := net.Dial("tcp", addr) if err != nil { t.Fatalf("failed to dial: %v", err) } defer func() { if err := conn.Close(); err != nil { t.Errorf("failed to close connection: %v", err) } }() proxyHeader := proxyproto.Header{ Version: 2, Command: proxyproto.LOCAL, TransportProtocol: proxyproto.UNSPEC, } tlvs := []proxyproto.TLV{{ Type: proxyproto.PP2_TYPE_ALPN, Value: []byte("h2"), }} if err := proxyHeader.SetTLVs(tlvs); err != nil { t.Fatalf("failed to set TLVs: %v", err) } if _, err := proxyHeader.WriteTo(conn); err != nil { t.Fatalf("failed to write PROXY header: %v", err) } h2Conn, err := new(http2.Transport).NewClientConn(conn) if err != nil { t.Fatalf("failed to create HTTP connection: %v", err) } req, err := http.NewRequest(http.MethodGet, "http://"+addr, nil) if err != nil { t.Fatalf("failed to create HTTP request: %v", err) } resp, err := h2Conn.RoundTrip(req) if err != nil { t.Fatalf("failed to perform HTTP request: %v", err) } if err := resp.Body.Close(); err != nil { t.Fatalf("failed to close response body: %v", err) } } func TestServer_h2_tls(t *testing.T) { addr, server := newTLSTestServer(t) t.Cleanup(func() { if err := server.Close(); err != nil { t.Errorf("failed 
to close server: %v", err) } }) conn, err := tls.Dial("tcp", addr, &tls.Config{ InsecureSkipVerify: true, //nolint:gosec // skipping certificate verification for testing. NextProtos: []string{http2.NextProtoTLS}, }) if err != nil { t.Fatalf("failed to dial: %v", err) } defer func() { if err := conn.Close(); err != nil { t.Errorf("failed to close connection: %v", err) } }() h2Conn, err := new(http2.Transport).NewClientConn(conn) if err != nil { t.Fatalf("failed to create HTTP connection: %v", err) } req, err := http.NewRequest(http.MethodGet, "https://"+addr, nil) if err != nil { t.Fatalf("failed to create HTTP request: %v", err) } resp, err := h2Conn.RoundTrip(req) if err != nil { t.Fatalf("failed to perform HTTP request: %v", err) } if err := resp.Body.Close(); err != nil { t.Errorf("failed to close response body: %v", err) } } func TestServer_h1_nil_ConnContext(t *testing.T) { addr, server := newTestServerWithoutConnContext(t) t.Cleanup(func() { if err := server.Close(); err != nil { t.Errorf("failed to close server: %v", err) } }) resp, err := http.Get("http://" + addr) if err != nil { t.Fatalf("failed to perform HTTP request: %v", err) } if err := resp.Body.Close(); err != nil { t.Fatalf("failed to close response body: %v", err) } } func TestServer_h2_nil_ConnContext(t *testing.T) { addr, server := newTestServerWithoutConnContext(t) t.Cleanup(func() { if err := server.Close(); err != nil { t.Errorf("failed to close server: %v", err) } }) conn, err := net.Dial("tcp", addr) if err != nil { t.Fatalf("failed to dial: %v", err) } defer func() { if err := conn.Close(); err != nil { t.Errorf("failed to close connection: %v", err) } }() proxyHeader := proxyproto.Header{ Version: 2, Command: proxyproto.LOCAL, TransportProtocol: proxyproto.UNSPEC, } tlvs := []proxyproto.TLV{{ Type: proxyproto.PP2_TYPE_ALPN, Value: []byte("h2"), }} if err := proxyHeader.SetTLVs(tlvs); err != nil { t.Fatalf("failed to set TLVs: %v", err) } if _, err := proxyHeader.WriteTo(conn); err != nil 
{ t.Fatalf("failed to write PROXY header: %v", err) } h2Conn, err := new(http2.Transport).NewClientConn(conn) if err != nil { t.Fatalf("failed to create HTTP connection: %v", err) } req, err := http.NewRequest(http.MethodGet, "http://"+addr, nil) if err != nil { t.Fatalf("failed to create HTTP request: %v", err) } resp, err := h2Conn.RoundTrip(req) if err != nil { t.Fatalf("failed to perform HTTP request: %v", err) } if err := resp.Body.Close(); err != nil { t.Fatalf("failed to close response body: %v", err) } } // startTestServer listens on a random port, wraps the listener with wrapListener // (or a proxyproto.Listener if nil), and starts an h2proxy.Server in the background. // It registers cleanup to wait for the server to finish. func startTestServer(t *testing.T, server *http.Server, wrapListener func(net.Listener) net.Listener) string { t.Helper() ln, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatalf("failed to listen: %v", err) } var serveLn net.Listener if wrapListener != nil { serveLn = wrapListener(ln) } else { serveLn = &proxyproto.Listener{Listener: ln} } h2Server := h2proxy.NewServer(server, nil) done := make(chan error, 1) go func() { done <- h2Server.Serve(serveLn) }() t.Cleanup(func() { err := <-done if err != nil && !errors.Is(err, http.ErrServerClosed) { t.Fatalf("failed to serve: %v", err) } }) return ln.Addr().String() } func newTestServer(t *testing.T) (addr string, server *http.Server) { server = newContextAssertingServer(t) return startTestServer(t, server, nil), server } func newTLSTestServer(t *testing.T) (addr string, server *http.Server) { server = newContextAssertingServer(t) return startTestServer(t, server, func(ln net.Listener) net.Listener { return tls.NewListener(ln, testTLSConfig(t)) }), server } func newTestServerWithoutConnContext(t *testing.T) (addr string, server *http.Server) { server = &http.Server{ ReadHeaderTimeout: 5 * time.Second, Handler: http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {}), } 
return startTestServer(t, server, nil), server } // newContextAssertingServer returns an http.Server that asserts connContextKey // and baseContextKey are present in every request's context. func newContextAssertingServer(t *testing.T) *http.Server { return &http.Server{ ReadHeaderTimeout: 5 * time.Second, Handler: http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { if v := r.Context().Value(connContextKey); v == nil { t.Errorf("http.Request.Context missing connContextKey") } if v := r.Context().Value(baseContextKey); v == nil { t.Errorf("http.Request.Context missing baseContextKey") } }), BaseContext: func(_ net.Listener) context.Context { return context.WithValue(context.Background(), baseContextKey, struct{}{}) }, ConnContext: func(ctx context.Context, _ net.Conn) context.Context { return context.WithValue(ctx, connContextKey, struct{}{}) }, } } func testTLSConfig(t *testing.T) *tls.Config { t.Helper() key, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { t.Fatalf("failed to generate key: %v", err) } serial, err := rand.Int(rand.Reader, big.NewInt(1<<62)) if err != nil { t.Fatalf("failed to generate serial: %v", err) } template := x509.Certificate{ SerialNumber: serial, Subject: pkix.Name{ CommonName: "localhost", }, NotBefore: time.Now().Add(-time.Hour), NotAfter: time.Now().Add(time.Hour), KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, BasicConstraintsValid: true, DNSNames: []string{"localhost"}, IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, } der, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key) if err != nil { t.Fatalf("failed to create cert: %v", err) } cert := tls.Certificate{ Certificate: [][]byte{der}, PrivateKey: key, } return &tls.Config{ MinVersion: tls.VersionTLS12, Certificates: []tls.Certificate{cert}, NextProtos: []string{http2.NextProtoTLS}, } } 
pires-go-proxyproto-04c9ad1/helper/http2/listener.go000066400000000000000000000021261514137054000226060ustar00rootroot00000000000000package http2 import ( "net" "sync" ) // pipeListener is a hack to workaround the lack of http.Server.ServeConn. // See: https://github.com/golang/go/issues/36673 type pipeListener struct { ch chan net.Conn closed bool mu sync.Mutex } func newPipeListener() *pipeListener { return &pipeListener{ ch: make(chan net.Conn, 64), } } func (ln *pipeListener) Accept() (net.Conn, error) { conn, ok := <-ln.ch if !ok { return nil, net.ErrClosed } return conn, nil } func (ln *pipeListener) Close() error { ln.mu.Lock() defer ln.mu.Unlock() if ln.closed { return net.ErrClosed } ln.closed = true close(ln.ch) return nil } // ServeConn enqueues a new connection. The connection will be returned in the // next Accept call. func (ln *pipeListener) ServeConn(conn net.Conn) error { ln.mu.Lock() defer ln.mu.Unlock() if ln.closed { return net.ErrClosed } ln.ch <- conn return nil } func (ln *pipeListener) Addr() net.Addr { return pipeAddr{} } type pipeAddr struct{} func (pipeAddr) Network() string { return "pipe" } func (pipeAddr) String() string { return "pipe" } pires-go-proxyproto-04c9ad1/policy.go000066400000000000000000000237541514137054000177520ustar00rootroot00000000000000package proxyproto import ( "fmt" "net" "strings" ) // PolicyFunc can be used to decide whether to trust the PROXY info from // upstream. If set, the connecting address is passed in as an argument. // // See below for the different policies. // // In case an error is returned the connection is denied. // // Deprecated: use ConnPolicyFunc instead. type PolicyFunc func(upstream net.Addr) (Policy, error) // ConnPolicyFunc can be used to decide whether to trust the PROXY info // based on connection policy options. If set, the connecting addresses // (remote and local) are passed in as argument. // // See below for the different policies. 
// // In case an error is returned the connection is denied. type ConnPolicyFunc func(connPolicyOptions ConnPolicyOptions) (Policy, error) // ConnPolicyOptions contains the remote and local addresses of a connection. type ConnPolicyOptions struct { Upstream net.Addr Downstream net.Addr } // Policy defines how a connection with a PROXY header address is treated. type Policy int const ( // USE address from PROXY header. USE Policy = iota // IGNORE address from PROXY header, but accept connection. IGNORE // REJECT connection when PROXY header is sent // Note: even though the first read on the connection returns an error if // a PROXY header is present, subsequent reads do not. It is the task of // the code using the connection to handle that case properly. REJECT // REQUIRE connection to send PROXY header, reject if not present // Note: even though the first read on the connection returns an error if // a PROXY header is not present, subsequent reads do not. It is the task // of the code using the connection to handle that case properly. REQUIRE // SKIP accepts a connection without requiring the PROXY header. // Note: an example usage can be found in the SkipProxyHeaderForCIDR // function. SKIP ) // ConnSkipProxyHeaderForCIDR returns a ConnPolicyFunc which can be used to accept // a connection from a skipHeaderCIDR without requiring a PROXY header, e.g. // Kubernetes pods local traffic. The def is a policy to use when an upstream // address doesn't match the skipHeaderCIDR. func ConnSkipProxyHeaderForCIDR(skipHeaderCIDR *net.IPNet, def Policy) ConnPolicyFunc { return func(connOpts ConnPolicyOptions) (Policy, error) { ip, err := ipFromAddr(connOpts.Upstream) if err != nil { return def, err } if skipHeaderCIDR != nil && skipHeaderCIDR.Contains(ip) { return SKIP, nil } return def, nil } } // SkipProxyHeaderForCIDR returns a PolicyFunc which can be used to accept a // connection from a skipHeaderCIDR without requiring a PROXY header, e.g. // Kubernetes pods local traffic. 
The def is a policy to use when an upstream // address doesn't match the skipHeaderCIDR. // // Deprecated: use ConnSkipProxyHeaderForCIDR instead. func SkipProxyHeaderForCIDR(skipHeaderCIDR *net.IPNet, def Policy) PolicyFunc { connPolicy := ConnSkipProxyHeaderForCIDR(skipHeaderCIDR, def) return func(upstream net.Addr) (Policy, error) { return connPolicy(ConnPolicyOptions{Upstream: upstream}) } } // WithPolicy adds given policy to a connection when passed as option to NewConn(). func WithPolicy(p Policy) func(*Conn) { return func(c *Conn) { c.ProxyHeaderPolicy = p } } // ConnLaxWhiteListPolicy returns a ConnPolicyFunc which decides whether the // upstream ip is allowed to send a proxy header based on a list of allowed // IP addresses and IP ranges. In case upstream IP is not in list the proxy // header will be ignored. If one of the provided IP addresses or IP ranges // is invalid it will return an error instead of a ConnPolicyFunc. func ConnLaxWhiteListPolicy(allowed []string) (ConnPolicyFunc, error) { allowFrom, err := parse(allowed) if err != nil { return nil, err } return connWhitelistPolicy(allowFrom, IGNORE), nil } // LaxWhiteListPolicy returns a PolicyFunc which decides whether the // upstream ip is allowed to send a proxy header based on a list of allowed // IP addresses and IP ranges. In case upstream IP is not in list the proxy // header will be ignored. If one of the provided IP addresses or IP ranges // is invalid it will return an error instead of a PolicyFunc. // // Deprecated: use ConnLaxWhiteListPolicy instead. func LaxWhiteListPolicy(allowed []string) (PolicyFunc, error) { connPolicy, err := ConnLaxWhiteListPolicy(allowed) if err != nil { return nil, err } return func(upstream net.Addr) (Policy, error) { return connPolicy(ConnPolicyOptions{Upstream: upstream}) }, nil } // ConnMustLaxWhiteListPolicy returns a ConnLaxWhiteListPolicy but will panic // if one of the provided IP addresses or IP ranges is invalid. 
func ConnMustLaxWhiteListPolicy(allowed []string) ConnPolicyFunc { pfunc, err := ConnLaxWhiteListPolicy(allowed) if err != nil { panic(err) } return pfunc } // MustLaxWhiteListPolicy returns a LaxWhiteListPolicy but will panic if one // of the provided IP addresses or IP ranges is invalid. // // Deprecated: use ConnMustLaxWhiteListPolicy instead. func MustLaxWhiteListPolicy(allowed []string) PolicyFunc { connPolicy := ConnMustLaxWhiteListPolicy(allowed) return func(upstream net.Addr) (Policy, error) { return connPolicy(ConnPolicyOptions{Upstream: upstream}) } } // ConnStrictWhiteListPolicy returns a ConnPolicyFunc which decides whether the // upstream ip is allowed to send a proxy header based on a list of allowed // IP addresses and IP ranges. In case upstream IP is not in list reading on // the connection will be refused on the first read. Please note: subsequent // reads do not error. It is the task of the code using the connection to // handle that case properly. If one of the provided IP addresses or IP // ranges is invalid it will return an error instead of a ConnPolicyFunc. func ConnStrictWhiteListPolicy(allowed []string) (ConnPolicyFunc, error) { allowFrom, err := parse(allowed) if err != nil { return nil, err } return connWhitelistPolicy(allowFrom, REJECT), nil } // StrictWhiteListPolicy returns a PolicyFunc which decides whether the // upstream ip is allowed to send a proxy header based on a list of allowed // IP addresses and IP ranges. In case upstream IP is not in list reading on // the connection will be refused on the first read. Please note: subsequent // reads do not error. It is the task of the code using the connection to // handle that case properly. If one of the provided IP addresses or IP // ranges is invalid it will return an error instead of a PolicyFunc. // // Deprecated: use ConnStrictWhiteListPolicy instead. 
func StrictWhiteListPolicy(allowed []string) (PolicyFunc, error) { connPolicy, err := ConnStrictWhiteListPolicy(allowed) if err != nil { return nil, err } return func(upstream net.Addr) (Policy, error) { return connPolicy(ConnPolicyOptions{Upstream: upstream}) }, nil } // ConnMustStrictWhiteListPolicy returns a ConnStrictWhiteListPolicy but will panic // if one of the provided IP addresses or IP ranges is invalid. func ConnMustStrictWhiteListPolicy(allowed []string) ConnPolicyFunc { pfunc, err := ConnStrictWhiteListPolicy(allowed) if err != nil { panic(err) } return pfunc } // MustStrictWhiteListPolicy returns a StrictWhiteListPolicy but will panic // if one of the provided IP addresses or IP ranges is invalid. // // Deprecated: use ConnMustStrictWhiteListPolicy instead. func MustStrictWhiteListPolicy(allowed []string) PolicyFunc { connPolicy := ConnMustStrictWhiteListPolicy(allowed) return func(upstream net.Addr) (Policy, error) { return connPolicy(ConnPolicyOptions{Upstream: upstream}) } } func connWhitelistPolicy(allowed []func(net.IP) bool, def Policy) ConnPolicyFunc { return func(connOpts ConnPolicyOptions) (Policy, error) { upstreamIP, err := ipFromAddr(connOpts.Upstream) if err != nil { // something is wrong with the source IP, better reject the connection return REJECT, err } for _, allowFrom := range allowed { if allowFrom(upstreamIP) { return USE, nil } } return def, nil } } func parse(allowed []string) ([]func(net.IP) bool, error) { a := make([]func(net.IP) bool, len(allowed)) for i, allowFrom := range allowed { if strings.LastIndex(allowFrom, "/") > 0 { _, ipRange, err := net.ParseCIDR(allowFrom) if err != nil { return nil, fmt.Errorf("proxyproto: given string %q is not a valid IP range: %v", allowFrom, err) } a[i] = ipRange.Contains } else { allowed := net.ParseIP(allowFrom) if allowed == nil { return nil, fmt.Errorf("proxyproto: given string %q is not a valid IP address", allowFrom) } a[i] = allowed.Equal } } return a, nil } func ipFromAddr(upstream 
net.Addr) (net.IP, error) { upstreamString, _, err := net.SplitHostPort(upstream.String()) if err != nil { return nil, err } upstreamIP := net.ParseIP(upstreamString) if nil == upstreamIP { return nil, fmt.Errorf("proxyproto: invalid IP address") } return upstreamIP, nil } // TrustProxyHeaderFrom returns a ConnPolicyFunc which can be used to decide // whether to use or reject PROXY headers based on the source IP of the // connection. This policy ensures that only trusted sources can set the PROXY // header. Connections from IPs not in the trusted list will be rejected. func TrustProxyHeaderFrom(trustedIPs ...net.IP) ConnPolicyFunc { return func(connOpts ConnPolicyOptions) (Policy, error) { ip, err := ipFromAddr(connOpts.Upstream) if err != nil { return REJECT, err } for _, trustedIP := range trustedIPs { if trustedIP.Equal(ip) { return USE, nil } } return REJECT, nil } } // IgnoreProxyHeaderNotOnInterface returns a ConnPolicyFunc which can be used to // decide whether to use or ignore PROXY headers depending on the connection // being made on specific interfaces. This policy can be used when the server // is bound to multiple interfaces but wants to allow on one or more interfaces. 
func IgnoreProxyHeaderNotOnInterface(allowedIP net.IP) ConnPolicyFunc { return func(connOpts ConnPolicyOptions) (Policy, error) { ip, err := ipFromAddr(connOpts.Downstream) if err != nil { return REJECT, err } if allowedIP.Equal(ip) { return USE, nil } return IGNORE, nil } } pires-go-proxyproto-04c9ad1/policy_test.go000066400000000000000000000273271514137054000210110ustar00rootroot00000000000000package proxyproto import ( "net" "testing" ) type failingAddr struct{} func (f failingAddr) Network() string { return "failing" } func (f failingAddr) String() string { return "failing" } type invalidIPAddr struct{} func (i invalidIPAddr) Network() string { return "tcp" } func (i invalidIPAddr) String() string { return "999.999.999.999:1234" } func TestWhitelistPolicyReturnsErrorOnInvalidAddress(t *testing.T) { var cases = []struct { name string policy PolicyFunc }{ {"strict whitelist policy", MustStrictWhiteListPolicy([]string{"10.0.0.2", "10.0.0.3", "10.0.0.4", "10.0.0.0/30"})}, {"lax whitelist policy", MustLaxWhiteListPolicy([]string{"10.0.0.2", "10.0.0.3", "10.0.0.4", "10.0.0.0/30"})}, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { _, err := tc.policy(failingAddr{}) if err == nil { t.Fatal("Expected error, got none") } }) } } func TestWhitelistPolicyReturnsErrorOnInvalidIP(t *testing.T) { policies := []struct { name string policy ConnPolicyFunc }{ {"conn strict whitelist policy", ConnMustStrictWhiteListPolicy([]string{"10.0.0.3"})}, {"conn lax whitelist policy", ConnMustLaxWhiteListPolicy([]string{"10.0.0.3"})}, } for _, tc := range policies { t.Run(tc.name, func(t *testing.T) { _, err := tc.policy(ConnPolicyOptions{Upstream: invalidIPAddr{}}) if err == nil { t.Fatal("Expected error, got none") } }) } } func TestStrictWhitelistPolicyReturnsRejectWhenUpstreamIpAddrNotInWhitelist(t *testing.T) { var cases = []struct { name string policy PolicyFunc }{ {"strict whitelist policy", MustStrictWhiteListPolicy([]string{"10.0.0.2", "10.0.0.3", "10.0.0.4", 
"10.0.0.0/30"})}, } upstream, err := net.ResolveTCPAddr("tcp", "10.0.0.5:45738") if err != nil { t.Fatalf("err: %v", err) } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { policy, err := tc.policy(upstream) if err != nil { t.Fatalf("err: %v", err) } if policy != REJECT { t.Fatalf("Expected policy REJECT, got %v", policy) } }) } } func TestLaxWhitelistPolicyReturnsIgnoreWhenUpstreamIpAddrNotInWhitelist(t *testing.T) { var cases = []struct { name string policy PolicyFunc }{ {"lax whitelist policy", MustLaxWhiteListPolicy([]string{"10.0.0.2", "10.0.0.3", "10.0.0.4", "10.0.0.0/30"})}, } upstream, err := net.ResolveTCPAddr("tcp", "10.0.0.5:45738") if err != nil { t.Fatalf("err: %v", err) } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { policy, err := tc.policy(upstream) if err != nil { t.Fatalf("err: %v", err) } if policy != IGNORE { t.Fatalf("Expected policy IGNORE, got %v", policy) } }) } } func TestWhitelistPolicyReturnsUseWhenUpstreamIpAddrInWhitelist(t *testing.T) { var cases = []struct { name string policy PolicyFunc }{ {"strict whitelist policy", MustStrictWhiteListPolicy([]string{"10.0.0.2", "10.0.0.3", "10.0.0.4"})}, {"lax whitelist policy", MustLaxWhiteListPolicy([]string{"10.0.0.2", "10.0.0.3", "10.0.0.4"})}, } upstream, err := net.ResolveTCPAddr("tcp", "10.0.0.3:45738") if err != nil { t.Fatalf("err: %v", err) } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { policy, err := tc.policy(upstream) if err != nil { t.Fatalf("err: %v", err) } if policy != USE { t.Fatalf("Expected policy USE, got %v", policy) } }) } } func TestWhitelistPolicyReturnsUseWhenUpstreamIpAddrInWhitelistRange(t *testing.T) { var cases = []struct { name string policy PolicyFunc }{ {"strict whitelist policy", MustStrictWhiteListPolicy([]string{"10.0.0.0/29"})}, {"lax whitelist policy", MustLaxWhiteListPolicy([]string{"10.0.0.0/29"})}, } upstream, err := net.ResolveTCPAddr("tcp", "10.0.0.3:45738") if err != nil { t.Fatalf("err: %v", err) } for _, 
tc := range cases { t.Run(tc.name, func(t *testing.T) { policy, err := tc.policy(upstream) if err != nil { t.Fatalf("err: %v", err) } if policy != USE { t.Fatalf("Expected policy USE, got %v", policy) } }) } } func Test_CreateWhitelistPolicyWithInvalidCidrReturnsError(t *testing.T) { var cases = []struct { name string fn func() error }{ {"strict whitelist policy", func() error { _, err := StrictWhiteListPolicy([]string{"20/80"}) return err }}, {"lax whitelist policy", func() error { _, err := LaxWhiteListPolicy([]string{"20/80"}) return err }}, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { if err := tc.fn(); err == nil { t.Error("Expected error, got none") } }) } } func Test_CreateWhitelistPolicyWithInvalidIpAddressReturnsError(t *testing.T) { var cases = []struct { name string fn func() error }{ {"strict whitelist policy", func() error { _, err := StrictWhiteListPolicy([]string{"855.222.233.11"}) return err }}, {"lax whitelist policy", func() error { _, err := LaxWhiteListPolicy([]string{"855.222.233.11"}) return err }}, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { if err := tc.fn(); err == nil { t.Error("Expected error, got none") } }) } } func Test_MustLaxWhiteListPolicyPanicsWithInvalidIpAddress(t *testing.T) { defer func() { if r := recover(); r == nil { t.Error("Expected a panic, but got none") } }() MustLaxWhiteListPolicy([]string{"855.222.233.11"}) } func Test_MustLaxWhiteListPolicyPanicsWithInvalidIpRange(t *testing.T) { defer func() { if r := recover(); r == nil { t.Error("Expected a panic, but got none") } }() MustLaxWhiteListPolicy([]string{"20/80"}) } func Test_MustStrictWhiteListPolicyPanicsWithInvalidIpAddress(t *testing.T) { defer func() { if r := recover(); r == nil { t.Error("Expected a panic, but got none") } }() MustStrictWhiteListPolicy([]string{"855.222.233.11"}) } func Test_MustStrictWhiteListPolicyPanicsWithInvalidIpRange(t *testing.T) { defer func() { if r := recover(); r == nil { t.Error("Expected a 
panic, but got none") } }() MustStrictWhiteListPolicy([]string{"20/80"}) } func TestWhiteListPolicyFuncsReturnPolicies(t *testing.T) { strictPolicy, err := StrictWhiteListPolicy([]string{"10.0.0.3"}) if err != nil { t.Fatalf("err: %v", err) } laxPolicy, err := LaxWhiteListPolicy([]string{"10.0.0.3"}) if err != nil { t.Fatalf("err: %v", err) } upstream, err := net.ResolveTCPAddr("tcp", "10.0.0.3:45738") if err != nil { t.Fatalf("err: %v", err) } policy, err := strictPolicy(upstream) if err != nil { t.Fatalf("err: %v", err) } if policy != USE { t.Fatalf("Expected policy USE, got %v", policy) } policy, err = laxPolicy(upstream) if err != nil { t.Fatalf("err: %v", err) } if policy != USE { t.Fatalf("Expected policy USE, got %v", policy) } } func TestSkipProxyHeaderForCIDR(t *testing.T) { _, cidr, _ := net.ParseCIDR("192.0.2.1/24") f := SkipProxyHeaderForCIDR(cidr, REJECT) upstream, _ := net.ResolveTCPAddr("tcp", "192.0.2.255:12345") policy, err := f(upstream) if err != nil { t.Fatalf("err: %v", err) } if policy != SKIP { t.Errorf("Expected a SKIP policy for the %s address", upstream) } upstream, _ = net.ResolveTCPAddr("tcp", "8.8.8.8:12345") policy, err = f(upstream) if err != nil { t.Fatalf("err: %v", err) } if policy != REJECT { t.Errorf("Expected a REJECT policy for the %s address", upstream) } } func TestConnSkipProxyHeaderForCIDRReturnsErrorOnInvalidAddress(t *testing.T) { _, cidr, _ := net.ParseCIDR("192.0.2.1/24") policy := ConnSkipProxyHeaderForCIDR(cidr, IGNORE) result, err := policy(ConnPolicyOptions{Upstream: failingAddr{}}) if err == nil { t.Fatal("Expected error, got none") } if result != IGNORE { t.Fatalf("Expected policy IGNORE, got %v", result) } } func TestConnSkipProxyHeaderForCIDR(t *testing.T) { _, cidr, _ := net.ParseCIDR("192.0.2.1/24") policy := ConnSkipProxyHeaderForCIDR(cidr, REJECT) upstream, _ := net.ResolveTCPAddr("tcp", "192.0.2.255:12345") result, err := policy(ConnPolicyOptions{Upstream: upstream}) if err != nil { t.Fatalf("err: %v", err) 
} if result != SKIP { t.Errorf("Expected a SKIP policy for the %s address", upstream) } upstream, _ = net.ResolveTCPAddr("tcp", "8.8.8.8:12345") result, err = policy(ConnPolicyOptions{Upstream: upstream}) if err != nil { t.Fatalf("err: %v", err) } if result != REJECT { t.Errorf("Expected a REJECT policy for the %s address", upstream) } } func TestConnWhitelistPolicies(t *testing.T) { var cases = []struct { name string policy ConnPolicyFunc expectedUse Policy expectedReject Policy }{ {"conn strict whitelist policy", ConnMustStrictWhiteListPolicy([]string{"10.0.0.3"}), USE, REJECT}, {"conn lax whitelist policy", ConnMustLaxWhiteListPolicy([]string{"10.0.0.3"}), USE, IGNORE}, } allowed, err := net.ResolveTCPAddr("tcp", "10.0.0.3:45738") if err != nil { t.Fatalf("err: %v", err) } denied, err := net.ResolveTCPAddr("tcp", "10.0.0.4:45738") if err != nil { t.Fatalf("err: %v", err) } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { policy, err := tc.policy(ConnPolicyOptions{Upstream: allowed}) if err != nil { t.Fatalf("err: %v", err) } if policy != tc.expectedUse { t.Fatalf("Expected policy %v, got %v", tc.expectedUse, policy) } policy, err = tc.policy(ConnPolicyOptions{Upstream: denied}) if err != nil { t.Fatalf("err: %v", err) } if policy != tc.expectedReject { t.Fatalf("Expected policy %v, got %v", tc.expectedReject, policy) } }) } } func TestTrustProxyHeaderFrom(t *testing.T) { upstream, err := net.ResolveTCPAddr("tcp", "10.0.0.3:45738") if err != nil { t.Fatalf("err: %v", err) } var cases = []struct { name string policy ConnPolicyFunc upstreamAddr net.Addr expectedPolicy Policy expectError bool }{ {"reject header from untrusted source", TrustProxyHeaderFrom(net.ParseIP("192.0.2.1")), upstream, REJECT, false}, {"use header from trusted load balancer", TrustProxyHeaderFrom(net.ParseIP("10.0.0.3")), upstream, USE, false}, {"use header when source matches any trusted IP", TrustProxyHeaderFrom(net.ParseIP("192.0.2.1"), net.ParseIP("10.0.0.3")), upstream, USE, 
false}, {"invalid address should return error", TrustProxyHeaderFrom(net.ParseIP("10.0.0.3")), failingAddr{}, REJECT, true}, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { policy, err := tc.policy(ConnPolicyOptions{ Upstream: tc.upstreamAddr, }) if !tc.expectError && err != nil { t.Fatalf("err: %v", err) } if tc.expectError && err == nil { t.Fatal("Expected error, got none") } if policy != tc.expectedPolicy { t.Fatalf("Expected policy %v, got %v", tc.expectedPolicy, policy) } }) } } func TestIgnoreProxyHeaderNotOnInterface(t *testing.T) { downstream, err := net.ResolveTCPAddr("tcp", "10.0.0.3:45738") if err != nil { t.Fatalf("err: %v", err) } var cases = []struct { name string policy ConnPolicyFunc downstreamAddress net.Addr expectedPolicy Policy expectError bool }{ {"ignore header for requests not on interface", IgnoreProxyHeaderNotOnInterface(net.ParseIP("192.0.2.1")), downstream, IGNORE, false}, {"use headers for requests on interface", IgnoreProxyHeaderNotOnInterface(net.ParseIP("10.0.0.3")), downstream, USE, false}, {"invalid address should return error", IgnoreProxyHeaderNotOnInterface(net.ParseIP("10.0.0.3")), failingAddr{}, REJECT, true}, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { policy, err := tc.policy(ConnPolicyOptions{ Downstream: tc.downstreamAddress, }) if !tc.expectError && err != nil { t.Fatalf("err: %v", err) } if tc.expectError && err == nil { t.Fatal("Expected error, got none") } if policy != tc.expectedPolicy { t.Fatalf("Expected policy %v, got %v", tc.expectedPolicy, policy) } }) } } pires-go-proxyproto-04c9ad1/protocol.go000066400000000000000000000351031514137054000203030ustar00rootroot00000000000000package proxyproto import ( "bufio" "errors" "fmt" "io" "net" "sync" "sync/atomic" "time" ) // readBufferSize is the size used for bufio.Reader's internal buffer. // // This is kept low to reduce per-connection memory overhead. 
If the header is // larger than readBufferSize, the header will be decoded with multiple Read // calls. For v1 the header length is at most 108 bytes. For v2 the header // length is at most 52 bytes plus the length of the TLVs. We use 256 bytes to // accommodate for the most common cases. const readBufferSize = 256 var ( // DefaultReadHeaderTimeout is how long header processing waits for header to // be read from the wire, if Listener.ReaderHeaderTimeout is not set. // It's kept as a global variable so to make it easier to find and override, // e.g. go build -ldflags -X "github.com/pires/go-proxyproto.DefaultReadHeaderTimeout=1s". DefaultReadHeaderTimeout = 10 * time.Second // ErrInvalidUpstream should be returned when an upstream connection address // is not trusted, and therefore is invalid. ErrInvalidUpstream = fmt.Errorf("proxyproto: upstream connection address not trusted for PROXY information") ) // Listener is used to wrap an underlying listener, // whose connections may be using the HAProxy Proxy Protocol. // If the connection is using the protocol, the RemoteAddr() will return // the correct client address. ReadHeaderTimeout will be applied to all // connections in order to prevent blocking operations. If no ReadHeaderTimeout // is set, a default of 10s will be used. This can be disabled by setting the // timeout to < 0. // // Only one of Policy or ConnPolicy should be provided. If both are provided then // a panic would occur during accept. type Listener struct { // Listener is the underlying listener. Listener net.Listener // Deprecated: use ConnPolicyFunc instead. This will be removed in future release. Policy PolicyFunc // ConnPolicy is the policy function for accepted connections. ConnPolicy ConnPolicyFunc // ValidateHeader is the validator function for the proxy header. ValidateHeader Validator // ReadHeaderTimeout is the timeout for reading the proxy header. 
ReadHeaderTimeout time.Duration // ReadBufferSize is the read buffer size for accepted connections. When > 0, // each accepted connection uses this size for proxy header detection; 0 means default. ReadBufferSize int } // Conn is used to wrap and underlying connection which // may be speaking the Proxy Protocol. If it is, the RemoteAddr() will // return the address of the client instead of the proxy address. Each connection // will have its own readHeaderTimeout and readDeadline set by the Accept() call. type Conn struct { readDeadline atomic.Value // time.Time once sync.Once readErr error conn net.Conn bufReader *bufio.Reader // bufferSize is set when the client overrides via WithBufferSize; nil means use default. bufferSize *int header *Header ProxyHeaderPolicy Policy Validate Validator readHeaderTimeout time.Duration } // Validator receives a header and decides whether it is a valid one // In case the header is not deemed valid it should return an error. type Validator func(*Header) error // ValidateHeader adds given validator for proxy headers to a connection when passed as option to NewConn(). func ValidateHeader(v Validator) func(*Conn) { return func(c *Conn) { if v != nil { c.Validate = v } } } // SetReadHeaderTimeout sets the readHeaderTimeout for a connection when passed as option to NewConn(). func SetReadHeaderTimeout(t time.Duration) func(*Conn) { return func(c *Conn) { if t >= 0 { c.readHeaderTimeout = t } } } // WithBufferSize sets the size of the read buffer used for proxy header detection. // Values <= 0 are ignored and the default (256 bytes) is used. Values < 16 are // effectively 16 due to bufio's minimum. The default is tuned for typical proxy // protocol header lengths. func WithBufferSize(length int) func(*Conn) { return func(c *Conn) { if length <= 0 { return } p := new(int) *p = length c.bufferSize = p c.bufReader = bufio.NewReaderSize(c.conn, length) } } // Accept waits for and returns the next valid connection to the listener. 
func (p *Listener) Accept() (net.Conn, error) { for { // Get the underlying connection. conn, err := p.Listener.Accept() if err != nil { return nil, err } proxyHeaderPolicy := USE if p.Policy != nil && p.ConnPolicy != nil { panic("only one of policy or connpolicy must be provided.") } if p.Policy != nil || p.ConnPolicy != nil { if p.Policy != nil { proxyHeaderPolicy, err = p.Policy(conn.RemoteAddr()) } else { proxyHeaderPolicy, err = p.ConnPolicy(ConnPolicyOptions{ Upstream: conn.RemoteAddr(), Downstream: conn.LocalAddr(), }) } if err != nil { // can't decide the policy, we can't accept the connection. if closeErr := conn.Close(); closeErr != nil { return nil, closeErr } if errors.Is(err, ErrInvalidUpstream) { // keep listening for other connections. continue } return nil, err } // Handle a connection as a regular one. if proxyHeaderPolicy == SKIP { return conn, nil } } opts := []func(*Conn){ WithPolicy(proxyHeaderPolicy), ValidateHeader(p.ValidateHeader), } if p.ReadBufferSize > 0 { opts = append(opts, WithBufferSize(p.ReadBufferSize)) } newConn := NewConn(conn, opts...) // If the ReadHeaderTimeout for the listener is unset, use the default timeout. if p.ReadHeaderTimeout == 0 { p.ReadHeaderTimeout = DefaultReadHeaderTimeout } // Set the readHeaderTimeout of the new conn to the value of the listener newConn.readHeaderTimeout = p.ReadHeaderTimeout return newConn, nil } } // Close closes the underlying listener. func (p *Listener) Close() error { return p.Listener.Close() } // Addr returns the underlying listener's network address. func (p *Listener) Addr() net.Addr { return p.Listener.Addr() } // NewConn is used to wrap a net.Conn that may be speaking the PROXY protocol // into a proxyproto.Conn. // // NOTE: NewConn may interfere with previously set ReadDeadline on the provided net.Conn, // because it sets a temporary deadline when detecting and reading the PROXY protocol header. 
// If you need to enforce a specific ReadDeadline on the connection, be sure to call Conn.SetReadDeadline // again after NewConn returns, to restore your desired deadline. func NewConn(conn net.Conn, opts ...func(*Conn)) *Conn { br := bufio.NewReaderSize(conn, readBufferSize) pConn := &Conn{ bufReader: br, conn: conn, } for _, opt := range opts { opt(pConn) } return pConn } // Read is check for the proxy protocol header when doing // the initial scan. If there is an error parsing the header, // it is returned and the socket is closed. func (p *Conn) Read(b []byte) (int, error) { // Ensure header processing runs at most once and surface any errors. if err := p.ensureHeaderProcessed(); err != nil { return 0, err } // Drain the buffer if it exists and has data. if p.bufReader != nil { if p.bufReader.Buffered() > 0 { n, err := p.bufReader.Read(b) // Did we empty the buffer? // Buffering a net.Conn means the buffer doesn't return io.EOF until the connection returns io.EOF. // Therefore, we use Buffered() == 0 to detect if we are done with the buffer. if p.bufReader.Buffered() == 0 { // Garbage collect the buffer. p.bufReader = nil } // Return immediately. Do not touch p.conn. // If err is EOF here, it means the connection is actually closed, // so we should return that error to the user anyway. return n, err } // If buffer was empty to begin with (shouldn't happen with the >0 check // but good for safety), clear it. p.bufReader = nil } // From now on, read directly from the underlying connection. return p.conn.Read(b) } // Write wraps original conn.Write. func (p *Conn) Write(b []byte) (int, error) { // Ensure header processing has completed before writing. if err := p.ensureHeaderProcessed(); err != nil { return 0, err } return p.conn.Write(b) } // Close wraps original conn.Close. func (p *Conn) Close() error { return p.conn.Close() } // ProxyHeader returns the proxy protocol header, if any. If an error occurs // while reading the proxy header, nil is returned. 
func (p *Conn) ProxyHeader() *Header { // Ensure header processing runs at most once. _ = p.ensureHeaderProcessed() return p.header } // LocalAddr returns the address of the server if the proxy // protocol is being used, otherwise just returns the address of // the socket server. In case an error happens on reading the // proxy header the original LocalAddr is returned, not the one // from the proxy header even if the proxy header itself is // syntactically correct. func (p *Conn) LocalAddr() net.Addr { // Ensure header processing runs at most once. _ = p.ensureHeaderProcessed() if p.header == nil || p.header.Command.IsLocal() || p.readErr != nil { return p.conn.LocalAddr() } return p.header.DestinationAddr } // RemoteAddr returns the address of the client if the proxy // protocol is being used, otherwise just returns the address of // the socket peer. In case an error happens on reading the // proxy header the original RemoteAddr is returned, not the one // from the proxy header even if the proxy header itself is // syntactically correct. func (p *Conn) RemoteAddr() net.Addr { // Ensure header processing runs at most once. _ = p.ensureHeaderProcessed() if p.header == nil || p.header.Command.IsLocal() || p.readErr != nil { return p.conn.RemoteAddr() } return p.header.SourceAddr } // Raw returns the underlying connection which can be casted to // a concrete type, allowing access to specialized functions. // // Use this ONLY if you know exactly what you are doing. func (p *Conn) Raw() net.Conn { return p.conn } // TCPConn returns the underlying TCP connection, // allowing access to specialized functions. // // Use this ONLY if you know exactly what you are doing. func (p *Conn) TCPConn() (conn *net.TCPConn, ok bool) { conn, ok = p.conn.(*net.TCPConn) return } // UnixConn returns the underlying Unix socket connection, // allowing access to specialized functions. // // Use this ONLY if you know exactly what you are doing. 
func (p *Conn) UnixConn() (conn *net.UnixConn, ok bool) { conn, ok = p.conn.(*net.UnixConn) return } // UDPConn returns the underlying UDP connection, // allowing access to specialized functions. // // Use this ONLY if you know exactly what you are doing. func (p *Conn) UDPConn() (conn *net.UDPConn, ok bool) { conn, ok = p.conn.(*net.UDPConn) return } // SetDeadline wraps original conn.SetDeadline. func (p *Conn) SetDeadline(t time.Time) error { p.readDeadline.Store(t) return p.conn.SetDeadline(t) } // SetReadDeadline wraps original conn.SetReadDeadline. func (p *Conn) SetReadDeadline(t time.Time) error { // Set a local var that tells us the desired deadline. This is // needed in order to reset the read deadline to the one that is // desired by the user, rather than an empty deadline. p.readDeadline.Store(t) return p.conn.SetReadDeadline(t) } // SetWriteDeadline wraps original conn.SetWriteDeadline. func (p *Conn) SetWriteDeadline(t time.Time) error { return p.conn.SetWriteDeadline(t) } // readHeader reads the proxy protocol header from the connection. func (p *Conn) readHeader() error { // If the connection's readHeaderTimeout is more than 0, // apply a temporary deadline without extending a user-configured // deadline. If the user has no deadline, we use now + timeout. if p.readHeaderTimeout > 0 { var ( storedDeadline time.Time hasDeadline bool ) if t := p.readDeadline.Load(); t != nil { storedDeadline = t.(time.Time) hasDeadline = !storedDeadline.IsZero() } headerDeadline := time.Now().Add(p.readHeaderTimeout) if hasDeadline && storedDeadline.Before(headerDeadline) { // Clamp to the user's earlier deadline to avoid extending it. headerDeadline = storedDeadline } if err := p.conn.SetReadDeadline(headerDeadline); err != nil { return err } } header, err := Read(p.bufReader) // If the connection's readHeaderTimeout is more than 0, undo the change to the // deadline that we made above. 
Because we retain the readDeadline as part of our // SetReadDeadline override, we can restore the user's deadline (if any). // Therefore, we check whether the error is a net.Timeout and if it is, we decide // the proxy proto does not exist and set the error accordingly. if p.readHeaderTimeout > 0 { t := p.readDeadline.Load() if t == nil { t = time.Time{} } if err := p.conn.SetReadDeadline(t.(time.Time)); err != nil { return err } if netErr, ok := err.(net.Error); ok && netErr.Timeout() { err = ErrNoProxyProtocol } } // For the purpose of this wrapper shamefully stolen from armon/go-proxyproto // let's act as if there was no error when PROXY protocol is not present. if err == ErrNoProxyProtocol { // but not if it is required that the connection has one if p.ProxyHeaderPolicy == REQUIRE { return err } return nil } // proxy protocol header was found if err == nil && header != nil { switch p.ProxyHeaderPolicy { case REJECT: // this connection is not allowed to send one return ErrSuperfluousProxyHeader case USE, REQUIRE: if p.Validate != nil { err = p.Validate(header) if err != nil { return err } } p.header = header } } return err } // ensureHeaderProcessed runs header processing once. func (p *Conn) ensureHeaderProcessed() error { p.once.Do(func() { p.readErr = p.readHeader() }) if p.readErr != nil { return p.readErr } return nil } // ReadFrom implements the io.ReaderFrom ReadFrom method. func (p *Conn) ReadFrom(r io.Reader) (int64, error) { // Ensure header processing has completed before reading/writing. if err := p.ensureHeaderProcessed(); err != nil { return 0, err } if rf, ok := p.conn.(io.ReaderFrom); ok { return rf.ReadFrom(r) } return io.Copy(p.conn, r) } // WriteTo implements io.WriterTo. func (p *Conn) WriteTo(w io.Writer) (int64, error) { // Ensure header processing has completed before reading/writing. if err := p.ensureHeaderProcessed(); err != nil { return 0, err } // If the buffer has been drained (or cleared), copy directly from conn. 
if p.bufReader == nil { return io.Copy(w, p.conn) } b := make([]byte, p.bufReader.Buffered()) if _, err := p.bufReader.Read(b); err != nil { return 0, err // this should never happen as we read buffered data. } var n int64 { nn, err := w.Write(b) n += int64(nn) if err != nil { return n, err } } { nn, err := io.Copy(w, p.conn) n += nn if err != nil { return n, err } } return n, nil } pires-go-proxyproto-04c9ad1/protocol_test.go000066400000000000000000002067501514137054000213520ustar00rootroot00000000000000// This file was shamefully stolen from github.com/armon/go-proxyproto. // It has been heavily edited to conform to this lib. // // Thanks @armon package proxyproto import ( "bytes" "crypto/tls" "crypto/x509" "errors" "fmt" "io" "net" "net/http" "path/filepath" "sync/atomic" "testing" "time" ) // testSourceIPv4Addr is a source IPv4 address used in tests. const testSourceIPv4Addr string = "10.1.1.1" // testDestinationIPv4Addr is the destination IPv4 address used in tests. const testDestinationIPv4Addr string = "20.2.2.2" // testLocalhostRandomPort is a localhost random port used in tests. 
const testLocalhostRandomPort string = "127.0.0.1:0" func TestPassthrough(t *testing.T) { l, err := net.Listen("tcp", testLocalhostRandomPort) if err != nil { t.Fatalf("err: %v", err) } pl := &Listener{Listener: l} cliResult := make(chan error) go func() { conn, err := net.Dial("tcp", pl.Addr().String()) if err != nil { cliResult <- err return } t.Cleanup(func() { if closeErr := conn.Close(); closeErr != nil { t.Errorf("failed to close connection: %v", closeErr) } }) if _, err := conn.Write([]byte("ping")); err != nil { cliResult <- err return } recv := make([]byte, 4) if _, err = conn.Read(recv); err != nil { cliResult <- err return } if !bytes.Equal(recv, []byte("pong")) { cliResult <- fmt.Errorf("bad: %v", recv) return } close(cliResult) }() conn, err := pl.Accept() if err != nil { t.Fatalf("err: %v", err) } t.Cleanup(func() { if closeErr := conn.Close(); closeErr != nil { t.Errorf("failed to close connection: %v", closeErr) } }) recv := make([]byte, 4) _, err = conn.Read(recv) if err != nil { t.Fatalf("err: %v", err) } if !bytes.Equal(recv, []byte("ping")) { t.Fatalf("bad: %v", recv) } if _, err := conn.Write([]byte("pong")); err != nil { t.Fatalf("err: %v", err) } err = <-cliResult if err != nil { t.Fatalf("client error: %v", err) } } // TestRequiredWithReadHeaderTimeout will iterate through 3 different timeouts to see // whether using a REQUIRE policy for a listener would cause an error if the timeout // is triggerred without a proxy protocol header being defined. 
func TestRequiredWithReadHeaderTimeout(t *testing.T) { for _, duration := range []int{100, 200, 400} { t.Run(fmt.Sprint(duration), func(t *testing.T) { start := time.Now() l, err := net.Listen("tcp", testLocalhostRandomPort) if err != nil { t.Fatalf("err: %v", err) } pl := &Listener{ Listener: l, ReadHeaderTimeout: time.Millisecond * time.Duration(duration), Policy: func(_ net.Addr) (Policy, error) { return REQUIRE, nil }, } cliResult := make(chan error) go func() { conn, err := net.Dial("tcp", pl.Addr().String()) if err != nil { cliResult <- err return } t.Cleanup(func() { if closeErr := conn.Close(); closeErr != nil { t.Errorf("failed to close connection: %v", closeErr) } }) close(cliResult) }() conn, err := pl.Accept() if err != nil { t.Fatalf("err: %v", err) } t.Cleanup(func() { if closeErr := conn.Close(); closeErr != nil { t.Errorf("failed to close connection: %v", closeErr) } }) // Read blocks forever if there is no ReadHeaderTimeout and the policy is not REQUIRE recv := make([]byte, 4) _, err = conn.Read(recv) if err != nil && !errors.Is(err, ErrNoProxyProtocol) && time.Since(start)-pl.ReadHeaderTimeout > 10*time.Millisecond { t.Fatal("proxy proto should not be found and time should be close to read timeout") } err = <-cliResult if err != nil { t.Fatalf("client error: %v", err) } }) } } // TestUseWithReadHeaderTimeout will iterate through 3 different timeouts to see // whether using a USE policy for a listener would not cause an error if the timeout // is triggerred without a proxy protocol header being defined. 
func TestUseWithReadHeaderTimeout(t *testing.T) { for _, duration := range []int{100, 200, 400} { t.Run(fmt.Sprint(duration), func(t *testing.T) { start := time.Now() l, err := net.Listen("tcp", testLocalhostRandomPort) if err != nil { t.Fatalf("err: %v", err) } pl := &Listener{ Listener: l, ReadHeaderTimeout: time.Millisecond * time.Duration(duration), Policy: func(_ net.Addr) (Policy, error) { return USE, nil }, } cliResult := make(chan error) go func() { conn, err := net.Dial("tcp", pl.Addr().String()) if err != nil { cliResult <- err return } t.Cleanup(func() { if closeErr := conn.Close(); closeErr != nil { t.Errorf("failed to close connection: %v", closeErr) } }) close(cliResult) }() conn, err := pl.Accept() if err != nil { t.Fatalf("err: %v", err) } t.Cleanup(func() { if closeErr := conn.Close(); closeErr != nil { t.Errorf("failed to close connection: %v", closeErr) } }) // 2 times the ReadHeaderTimeout because the first timeout // should occur (the one set on the listener) and allow for the second to follow up if err := conn.SetDeadline(time.Now().Add(pl.ReadHeaderTimeout * 2)); err != nil { t.Fatalf("err: %v", err) } // Read blocks forever if there is no ReadHeaderTimeout recv := make([]byte, 4) _, err = conn.Read(recv) if err != nil && !errors.Is(err, ErrNoProxyProtocol) && (time.Since(start)-(pl.ReadHeaderTimeout*2)) > 10*time.Millisecond { t.Fatal("proxy proto should not be found and time should be close to read timeout") } err = <-cliResult if err != nil { t.Fatalf("client error: %v", err) } }) } } func TestNewConnSetReadHeaderTimeoutOption(t *testing.T) { conn, peer := net.Pipe() t.Cleanup(func() { if closeErr := conn.Close(); closeErr != nil { t.Errorf("failed to close connection: %v", closeErr) } }) t.Cleanup(func() { if closeErr := peer.Close(); closeErr != nil { t.Errorf("failed to close peer connection: %v", closeErr) } }) // Ensure SetReadHeaderTimeout sets the connection-specific timeout. 
timeout := 150 * time.Millisecond proxyConn := NewConn(conn, SetReadHeaderTimeout(timeout)) if proxyConn.readHeaderTimeout != timeout { t.Fatalf("expected readHeaderTimeout %v, got %v", timeout, proxyConn.readHeaderTimeout) } } func TestNewConnSetReadHeaderTimeoutIgnoresNegative(t *testing.T) { conn, peer := net.Pipe() t.Cleanup(func() { if closeErr := conn.Close(); closeErr != nil { t.Errorf("failed to close connection: %v", closeErr) } }) t.Cleanup(func() { if closeErr := peer.Close(); closeErr != nil { t.Errorf("failed to close peer connection: %v", closeErr) } }) // Negative values should be ignored, leaving the timeout unset. proxyConn := NewConn(conn, SetReadHeaderTimeout(-1)) if proxyConn.readHeaderTimeout != 0 { t.Fatalf("expected readHeaderTimeout to remain 0, got %v", proxyConn.readHeaderTimeout) } } func TestWithBufferSizePositive(t *testing.T) { conn, peer := net.Pipe() t.Cleanup(func() { _ = conn.Close() _ = peer.Close() }) proxyConn := NewConn(conn, WithBufferSize(4096)) if proxyConn.bufferSize == nil { t.Fatalf("expected bufferSize to be set") } if *proxyConn.bufferSize != 4096 { t.Fatalf("expected bufferSize 4096, got %d", *proxyConn.bufferSize) } go func() { _, _ = peer.Write([]byte("x")) }() buf := make([]byte, 1) if _, err := proxyConn.Read(buf); err != nil { t.Fatalf("read failed: %v", err) } if string(buf) != "x" { t.Fatalf("unexpected read: %q", buf) } } func TestWithBufferSizeZeroOrNegative(t *testing.T) { for _, length := range []int{0, -1} { t.Run(fmt.Sprint(length), func(t *testing.T) { conn, peer := net.Pipe() t.Cleanup(func() { _ = conn.Close() _ = peer.Close() }) proxyConn := NewConn(conn, WithBufferSize(length)) if proxyConn.bufferSize != nil { t.Fatalf("expected bufferSize to be nil for length %d", length) } go func() { _, _ = peer.Write([]byte("y")) }() buf := make([]byte, 1) if _, err := proxyConn.Read(buf); err != nil { t.Fatalf("read failed: %v", err) } if string(buf) != "y" { t.Fatalf("unexpected read: %q", buf) } }) } } func 
TestListenerReadBufferSizeApplied(t *testing.T) { l, err := net.Listen("tcp", testLocalhostRandomPort) if err != nil { t.Fatalf("err: %v", err) } t.Cleanup(func() { _ = l.Close() }) pl := &Listener{Listener: l, ReadBufferSize: 4096} go func() { c, _ := net.Dial("tcp", pl.Addr().String()) if c != nil { _ = c.Close() } }() conn, err := pl.Accept() if err != nil { t.Fatalf("Accept: %v", err) } t.Cleanup(func() { _ = conn.Close() }) proxyConn := conn.(*Conn) if proxyConn.bufferSize == nil { t.Fatalf("expected bufferSize to be set when Listener.ReadBufferSize > 0") } if *proxyConn.bufferSize != 4096 { t.Fatalf("expected bufferSize 4096, got %d", *proxyConn.bufferSize) } } func TestListenerReadBufferSizeZeroUsesDefault(t *testing.T) { l, err := net.Listen("tcp", testLocalhostRandomPort) if err != nil { t.Fatalf("err: %v", err) } t.Cleanup(func() { _ = l.Close() }) pl := &Listener{Listener: l, ReadBufferSize: 0} go func() { c, _ := net.Dial("tcp", pl.Addr().String()) if c != nil { _ = c.Close() } }() conn, err := pl.Accept() if err != nil { t.Fatalf("Accept: %v", err) } t.Cleanup(func() { _ = conn.Close() }) proxyConn := conn.(*Conn) if proxyConn.bufferSize != nil { t.Fatalf("expected bufferSize to be nil when Listener.ReadBufferSize is 0") } } func TestReadHeaderTimeoutRespectsEarlierDeadline(t *testing.T) { const ( headerTimeout = 200 * time.Millisecond userTimeout = 60 * time.Millisecond tolerance = 100 * time.Millisecond ) l, err := net.Listen("tcp", testLocalhostRandomPort) if err != nil { t.Fatalf("err: %v", err) } pl := &Listener{ Listener: l, ReadHeaderTimeout: headerTimeout, Policy: func(_ net.Addr) (Policy, error) { // Use REQUIRE so a timeout is surfaced as ErrNoProxyProtocol. 
return REQUIRE, nil }, } type dialResult struct { conn net.Conn err error } dialResultCh := make(chan dialResult, 1) go func() { conn, err := net.Dial("tcp", pl.Addr().String()) dialResultCh <- dialResult{conn: conn, err: err} }() conn, err := pl.Accept() if err != nil { t.Fatalf("err: %v", err) } t.Cleanup(func() { if closeErr := conn.Close(); closeErr != nil { t.Errorf("failed to close connection: %v", closeErr) } }) result := <-dialResultCh if result.err != nil { t.Fatalf("client error: %v", result.err) } t.Cleanup(func() { if closeErr := result.conn.Close(); closeErr != nil { t.Errorf("failed to close client connection: %v", closeErr) } }) // Set a shorter user deadline than the readHeaderTimeout and do not send data. if err := conn.SetReadDeadline(time.Now().Add(userTimeout)); err != nil { t.Fatalf("err: %v", err) } start := time.Now() recv := make([]byte, 1) _, err = conn.Read(recv) elapsed := time.Since(start) // The read should honor the earlier user deadline instead of waiting // for the longer readHeaderTimeout. if !errors.Is(err, ErrNoProxyProtocol) { t.Fatalf("expected ErrNoProxyProtocol, got: %v", err) } if elapsed > userTimeout+tolerance { t.Fatalf("read exceeded user deadline: elapsed=%v timeout=%v", elapsed, userTimeout) } } func TestDeadlineSettersAfterHeaderProcessed(t *testing.T) { conn, peer := net.Pipe() t.Cleanup(func() { if closeErr := conn.Close(); closeErr != nil { t.Errorf("failed to close connection: %v", closeErr) } }) t.Cleanup(func() { if closeErr := peer.Close(); closeErr != nil { t.Errorf("failed to close peer connection: %v", closeErr) } }) proxyConn := NewConn(conn) // Ensure header processing completes by sending a non-PROXY byte // and reading it through the proxy connection. 
go func() { if _, err := peer.Write([]byte("x")); err != nil { t.Errorf("failed to write peer data: %v", err) } }() buf := make([]byte, 1) if _, err := proxyConn.Read(buf); err != nil { t.Fatalf("read failed: %v", err) } deadline := time.Now().Add(time.Second) if err := proxyConn.SetDeadline(deadline); err != nil { t.Fatalf("unexpected SetDeadline error: %v", err) } if err := proxyConn.SetReadDeadline(deadline); err != nil { t.Fatalf("unexpected SetReadDeadline error: %v", err) } if err := proxyConn.SetWriteDeadline(deadline); err != nil { t.Fatalf("unexpected SetWriteDeadline error: %v", err) } } func TestReadHeaderTimeoutIsReset(t *testing.T) { const timeout = time.Millisecond * 250 l, err := net.Listen("tcp", testLocalhostRandomPort) if err != nil { t.Fatalf("err: %v", err) } pl := &Listener{ Listener: l, ReadHeaderTimeout: timeout, } header := &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP(testSourceIPv4Addr), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP(testDestinationIPv4Addr), Port: 2000, }, } cliResult := make(chan error) go func() { conn, err := net.Dial("tcp", pl.Addr().String()) if err != nil { cliResult <- err return } t.Cleanup(func() { if closeErr := conn.Close(); closeErr != nil { t.Errorf("failed to close connection: %v", closeErr) } }) // Write out the header! if _, err := header.WriteTo(conn); err != nil { cliResult <- err return } // Sleep here longer than the configured timeout. 
time.Sleep(timeout * 2) if _, err := conn.Write([]byte("ping")); err != nil { cliResult <- err return } recv := make([]byte, 4) if _, err := conn.Read(recv); err != nil { cliResult <- err return } if !bytes.Equal(recv, []byte("pong")) { cliResult <- fmt.Errorf("bad: %v", recv) return } close(cliResult) }() conn, err := pl.Accept() if err != nil { t.Fatalf("err: %v", err) } t.Cleanup(func() { if closeErr := conn.Close(); closeErr != nil { t.Errorf("failed to close connection: %v", closeErr) } }) // Set our deadlines higher than our ReadHeaderTimeout if err := conn.SetReadDeadline(time.Now().Add(timeout * 3)); err != nil { t.Fatalf("err: %v", err) } if err := conn.SetWriteDeadline(time.Now().Add(timeout * 3)); err != nil { t.Fatalf("err: %v", err) } recv := make([]byte, 4) if _, err = conn.Read(recv); err != nil { t.Fatalf("err: %v", err) } if !bytes.Equal(recv, []byte("ping")) { t.Fatalf("bad: %v", recv) } if _, err := conn.Write([]byte("pong")); err != nil { t.Fatalf("err: %v", err) } // Check the remote addr addr := conn.RemoteAddr().(*net.TCPAddr) if addr.IP.String() != testSourceIPv4Addr { t.Fatalf("bad: %v", addr) } if addr.Port != 1000 { t.Fatalf("bad: %v", addr) } h := conn.(*Conn).ProxyHeader() if !h.EqualsTo(header) { t.Errorf("bad: %v", h) } err = <-cliResult if err != nil { t.Fatalf("client error: %v", err) } } // TestReadHeaderTimeoutIsEmpty ensures the default is set if it is empty. // The default is 10s, but we delay sending a message, so use 200ms in this test. // We expect the actual address and port to be returned, // rather than the ProxyHeader we defined. 
func TestReadHeaderTimeoutIsEmpty(t *testing.T) { DefaultReadHeaderTimeout = 200 * time.Millisecond l, err := net.Listen("tcp", testLocalhostRandomPort) if err != nil { t.Fatalf("err: %v", err) } pl := &Listener{ Listener: l, } header := &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP(testSourceIPv4Addr), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP(testDestinationIPv4Addr), Port: 2000, }, } cliResult := make(chan error) go func() { conn, err := net.Dial("tcp", pl.Addr().String()) if err != nil { cliResult <- err return } t.Cleanup(func() { if closeErr := conn.Close(); closeErr != nil { t.Errorf("failed to close connection: %v", closeErr) } }) // Sleep here longer than the configured timeout. time.Sleep(250 * time.Millisecond) // Write out the header! if _, err := header.WriteTo(conn); err != nil { cliResult <- err return } if _, err := conn.Write([]byte("ping")); err != nil { cliResult <- err return } close(cliResult) }() conn, err := pl.Accept() if err != nil { t.Fatalf("err: %v", err) } t.Cleanup(func() { if closeErr := conn.Close(); closeErr != nil { t.Errorf("failed to close connection: %v", closeErr) } }) recv := make([]byte, 4) if _, err = conn.Read(recv); err != nil { t.Fatalf("err: %v", err) } // Check the remote addr addr := conn.RemoteAddr().(*net.TCPAddr) if addr.IP.String() == testSourceIPv4Addr { t.Fatalf("bad: %v", addr) } if addr.Port == 1000 { t.Fatalf("bad: %v", addr) } err = <-cliResult if err != nil { t.Fatalf("client error: %v", err) } } // TestReadHeaderTimeoutIsNegative does the same as above except // with a negative timeout. Therefore, we expect the right ProxyHeader // to be returned. 
func TestReadHeaderTimeoutIsNegative(t *testing.T) { l, err := net.Listen("tcp", testLocalhostRandomPort) if err != nil { t.Fatalf("err: %v", err) } pl := &Listener{ Listener: l, ReadHeaderTimeout: -1, } header := &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP(testSourceIPv4Addr), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP(testDestinationIPv4Addr), Port: 2000, }, } cliResult := make(chan error) go func() { conn, err := net.Dial("tcp", pl.Addr().String()) if err != nil { cliResult <- err return } t.Cleanup(func() { if closeErr := conn.Close(); closeErr != nil { t.Errorf("failed to close connection: %v", closeErr) } }) // Sleep here longer than the configured timeout. time.Sleep(250 * time.Millisecond) // Write out the header! if _, err := header.WriteTo(conn); err != nil { cliResult <- err return } if _, err := conn.Write([]byte("ping")); err != nil { cliResult <- err return } close(cliResult) }() conn, err := pl.Accept() if err != nil { t.Fatalf("err: %v", err) } t.Cleanup(func() { if closeErr := conn.Close(); closeErr != nil { t.Errorf("failed to close connection: %v", closeErr) } }) recv := make([]byte, 4) if _, err = conn.Read(recv); err != nil { t.Fatalf("err: %v", err) } // Check the remote addr addr := conn.RemoteAddr().(*net.TCPAddr) if addr.IP.String() != testSourceIPv4Addr { t.Fatalf("bad: %v", addr) } if addr.Port != 1000 { t.Fatalf("bad: %v", addr) } err = <-cliResult if err != nil { t.Fatalf("client error: %v", err) } } func TestParse_ipv4(t *testing.T) { l, err := net.Listen("tcp", testLocalhostRandomPort) if err != nil { t.Fatalf("err: %v", err) } pl := &Listener{Listener: l} header := &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP(testSourceIPv4Addr), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP(testDestinationIPv4Addr), Port: 2000, }, } cliResult := make(chan error) go func() { conn, err := 
net.Dial("tcp", pl.Addr().String()) if err != nil { cliResult <- err return } t.Cleanup(func() { if closeErr := conn.Close(); closeErr != nil { t.Errorf("failed to close connection: %v", closeErr) } }) // Write out the header! if _, err := header.WriteTo(conn); err != nil { cliResult <- err return } if _, err := conn.Write([]byte("ping")); err != nil { cliResult <- err return } recv := make([]byte, 4) if _, err = conn.Read(recv); err != nil { cliResult <- err return } if !bytes.Equal(recv, []byte("pong")) { cliResult <- fmt.Errorf("bad: %v", recv) return } close(cliResult) }() conn, err := pl.Accept() if err != nil { t.Fatalf("err: %v", err) } t.Cleanup(func() { if closeErr := conn.Close(); closeErr != nil { t.Errorf("failed to close connection: %v", closeErr) } }) recv := make([]byte, 4) if _, err = conn.Read(recv); err != nil { t.Fatalf("err: %v", err) } if !bytes.Equal(recv, []byte("ping")) { t.Fatalf("bad: %v", recv) } if _, err := conn.Write([]byte("pong")); err != nil { t.Fatalf("err: %v", err) } // Check the remote addr addr := conn.RemoteAddr().(*net.TCPAddr) if addr.IP.String() != testSourceIPv4Addr { t.Fatalf("bad: %v", addr) } if addr.Port != 1000 { t.Fatalf("bad: %v", addr) } h := conn.(*Conn).ProxyHeader() if !h.EqualsTo(header) { t.Errorf("bad: %v", h) } err = <-cliResult if err != nil { t.Fatalf("client error: %v", err) } } func TestParse_unixStream(t *testing.T) { socketDir := t.TempDir() socketPath := filepath.Join(socketDir, "proxy.sock") l, err := net.Listen("unix", socketPath) if err != nil { t.Fatalf("err: %v", err) } t.Cleanup(func() { if closeErr := l.Close(); closeErr != nil { t.Errorf("failed to close listener: %v", closeErr) } }) pl := &Listener{Listener: l} header := &Header{ Version: 2, Command: PROXY, TransportProtocol: UnixStream, SourceAddr: &net.UnixAddr{ Net: "unix", Name: "source.sock", }, DestinationAddr: &net.UnixAddr{ Net: "unix", Name: "dest.sock", }, } cliResult := make(chan error) go func() { conn, err := net.Dial("unix", 
socketPath) if err != nil { cliResult <- err return } defer func() { if closeErr := conn.Close(); closeErr != nil { t.Errorf("failed to close connection: %v", closeErr) } }() // Write out the header! if _, err := header.WriteTo(conn); err != nil { cliResult <- err return } if _, err := conn.Write([]byte("ping")); err != nil { cliResult <- err return } recv := make([]byte, 4) if _, err = conn.Read(recv); err != nil { cliResult <- err return } if !bytes.Equal(recv, []byte("pong")) { cliResult <- fmt.Errorf("bad: %v", recv) return } close(cliResult) }() conn, err := pl.Accept() if err != nil { t.Fatalf("err: %v", err) } t.Cleanup(func() { if closeErr := conn.Close(); closeErr != nil { t.Errorf("failed to close connection: %v", closeErr) } }) recv := make([]byte, 4) if _, err = conn.Read(recv); err != nil { t.Fatalf("err: %v", err) } if !bytes.Equal(recv, []byte("ping")) { t.Fatalf("bad: %v", recv) } if _, err := conn.Write([]byte("pong")); err != nil { t.Fatalf("err: %v", err) } addr := conn.RemoteAddr().(*net.UnixAddr) if addr.Name != header.SourceAddr.(*net.UnixAddr).Name { t.Fatalf("bad: %v", addr) } h := conn.(*Conn).ProxyHeader() if !h.EqualsTo(header) { t.Errorf("bad: %v", h) } err = <-cliResult if err != nil { t.Fatalf("client error: %v", err) } } func TestParse_unixDatagram(t *testing.T) { server, client := net.Pipe() t.Cleanup(func() { if closeErr := client.Close(); closeErr != nil { t.Errorf("failed to close client: %v", closeErr) } }) t.Cleanup(func() { if closeErr := server.Close(); closeErr != nil { t.Errorf("failed to close server: %v", closeErr) } }) header := &Header{ Version: 2, Command: PROXY, TransportProtocol: UnixDatagram, SourceAddr: &net.UnixAddr{ Net: "unixgram", Name: "source.sock", }, DestinationAddr: &net.UnixAddr{ Net: "unixgram", Name: "dest.sock", }, } go func() { defer func() { if closeErr := client.Close(); closeErr != nil { t.Errorf("failed to close client: %v", closeErr) } }() if _, err := header.WriteTo(client); err != nil { 
t.Errorf("failed to write header: %v", err) } if _, err := client.Write([]byte("ping")); err != nil { t.Errorf("failed to write ping: %v", err) } }() conn := NewConn(server) recv := make([]byte, 4) if _, err := conn.Read(recv); err != nil { t.Fatalf("err: %v", err) } if !bytes.Equal(recv, []byte("ping")) { t.Fatalf("bad: %v", recv) } remoteAddr := conn.RemoteAddr().(*net.UnixAddr) if remoteAddr.Name != header.SourceAddr.(*net.UnixAddr).Name { t.Fatalf("bad: %v", remoteAddr) } localAddr := conn.LocalAddr().(*net.UnixAddr) if localAddr.Name != header.DestinationAddr.(*net.UnixAddr).Name { t.Fatalf("bad: %v", localAddr) } h := conn.ProxyHeader() if !h.EqualsTo(header) { t.Errorf("bad: %v", h) } } func TestParse_ipv6(t *testing.T) { l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("err: %v", err) } pl := &Listener{Listener: l} header := &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv6, SourceAddr: &net.TCPAddr{ IP: net.ParseIP("ffff::ffff"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("ffff::ffff"), Port: 2000, }, } cliResult := make(chan error) go func() { conn, err := net.Dial("tcp", pl.Addr().String()) if err != nil { cliResult <- err return } t.Cleanup(func() { if closeErr := conn.Close(); closeErr != nil { t.Errorf("failed to close connection: %v", closeErr) } }) // Write out the header! 
if _, err := header.WriteTo(conn); err != nil { cliResult <- err return } if _, err := conn.Write([]byte("ping")); err != nil { cliResult <- err return } recv := make([]byte, 4) if _, err = conn.Read(recv); err != nil { cliResult <- err return } if !bytes.Equal(recv, []byte("pong")) { cliResult <- fmt.Errorf("bad: %v", recv) return } close(cliResult) }() conn, err := pl.Accept() if err != nil { t.Fatalf("err: %v", err) } t.Cleanup(func() { if closeErr := conn.Close(); closeErr != nil { t.Errorf("failed to close connection: %v", closeErr) } }) recv := make([]byte, 4) if _, err = conn.Read(recv); err != nil { t.Fatalf("err: %v", err) } if !bytes.Equal(recv, []byte("ping")) { t.Fatalf("bad: %v", recv) } if _, err := conn.Write([]byte("pong")); err != nil { t.Fatalf("err: %v", err) } // Check the remote addr addr := conn.RemoteAddr().(*net.TCPAddr) if addr.IP.String() != "ffff::ffff" { t.Fatalf("bad: %v", addr) } if addr.Port != 1000 { t.Fatalf("bad: %v", addr) } h := conn.(*Conn).ProxyHeader() if !h.EqualsTo(header) { t.Errorf("bad: %v", h) } err = <-cliResult if err != nil { t.Fatalf("client error: %v", err) } } func TestAcceptReturnsErrorWhenPolicyFuncErrors(t *testing.T) { l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("err: %v", err) } expectedErr := fmt.Errorf("failure") policyFunc := func(_ net.Addr) (Policy, error) { return USE, expectedErr } pl := &Listener{Listener: l, Policy: policyFunc} cliResult := make(chan error) go func() { conn, err := net.Dial("tcp", pl.Addr().String()) if err != nil { cliResult <- err return } t.Cleanup(func() { if closeErr := conn.Close(); closeErr != nil { t.Errorf("failed to close connection: %v", closeErr) } }) close(cliResult) }() conn, err := pl.Accept() if err != expectedErr { t.Fatalf("Expected error %v, got %v", expectedErr, err) } if conn != nil { t.Fatalf("Expected no connection, got %v", conn) } err = <-cliResult if err != nil { t.Fatalf("client error: %v", err) } } func 
TestPanicIfPolicyAndConnPolicySet(t *testing.T) { l, err := net.Listen("tcp", testLocalhostRandomPort) if err != nil { t.Fatalf("err: %v", err) } connPolicyFunc := func(_ ConnPolicyOptions) (Policy, error) { return USE, nil } policyFunc := func(_ net.Addr) (Policy, error) { return USE, nil } pl := &Listener{Listener: l, ConnPolicy: connPolicyFunc, Policy: policyFunc} cliResult := make(chan error) go func() { conn, err := net.Dial("tcp", pl.Addr().String()) if err != nil { cliResult <- err return } t.Cleanup(func() { if closeErr := conn.Close(); closeErr != nil { t.Errorf("failed to close connection: %v", closeErr) } }) close(cliResult) }() defer func() { if r := recover(); r != nil { fmt.Printf("accept did panic as expected with error, %v", r) } }() conn, err := pl.Accept() if err != nil { t.Fatalf("expected the accept to panic but did not and error is returned, got %v", err) } if conn != nil { t.Fatalf("expected the accept to panic but did not, got %v", conn) } t.Fatalf("expected the accept to panic but did not") } func TestAcceptReturnsErrorWhenConnPolicyFuncErrors(t *testing.T) { l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("err: %v", err) } expectedErr := fmt.Errorf("failure") connPolicyFunc := func(_ ConnPolicyOptions) (Policy, error) { return USE, expectedErr } pl := &Listener{Listener: l, ConnPolicy: connPolicyFunc} cliResult := make(chan error) go func() { conn, err := net.Dial("tcp", pl.Addr().String()) if err != nil { cliResult <- err return } t.Cleanup(func() { if closeErr := conn.Close(); closeErr != nil { t.Errorf("failed to close connection: %v", closeErr) } }) close(cliResult) }() conn, err := pl.Accept() if err != expectedErr { t.Fatalf("Expected error %v, got %v", expectedErr, err) } if conn != nil { t.Fatalf("Expected no connection, got %v", conn) } err = <-cliResult if err != nil { t.Fatalf("client error: %v", err) } } func TestReadingIsRefusedWhenProxyHeaderRequiredButMissing(t *testing.T) { l, err := net.Listen("tcp", 
testLocalhostRandomPort) if err != nil { t.Fatalf("err: %v", err) } policyFunc := func(_ net.Addr) (Policy, error) { return REQUIRE, nil } pl := &Listener{Listener: l, Policy: policyFunc} cliResult := make(chan error) go func() { conn, err := net.Dial("tcp", pl.Addr().String()) if err != nil { cliResult <- err return } t.Cleanup(func() { if closeErr := conn.Close(); closeErr != nil { t.Errorf("failed to close connection: %v", closeErr) } }) if _, err := conn.Write([]byte("ping")); err != nil { cliResult <- err return } close(cliResult) }() conn, err := pl.Accept() if err != nil { t.Fatalf("err: %v", err) } t.Cleanup(func() { if closeErr := conn.Close(); closeErr != nil { t.Errorf("failed to close connection: %v", closeErr) } }) recv := make([]byte, 4) if _, err = conn.Read(recv); err != ErrNoProxyProtocol { t.Fatalf("Expected error %v, received %v", ErrNoProxyProtocol, err) } err = <-cliResult if err != nil { t.Fatalf("client error: %v", err) } } func TestReadingIsRefusedWhenProxyHeaderPresentButNotAllowed(t *testing.T) { l, err := net.Listen("tcp", testLocalhostRandomPort) if err != nil { t.Fatalf("err: %v", err) } policyFunc := func(_ net.Addr) (Policy, error) { return REJECT, nil } pl := &Listener{Listener: l, Policy: policyFunc} cliResult := make(chan error) go func() { conn, err := net.Dial("tcp", pl.Addr().String()) if err != nil { cliResult <- err return } t.Cleanup(func() { if closeErr := conn.Close(); closeErr != nil { t.Errorf("failed to close connection: %v", closeErr) } }) header := &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP(testSourceIPv4Addr), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP(testDestinationIPv4Addr), Port: 2000, }, } if _, err := header.WriteTo(conn); err != nil { cliResult <- err return } close(cliResult) }() conn, err := pl.Accept() if err != nil { t.Fatalf("err: %v", err) } t.Cleanup(func() { if closeErr := conn.Close(); closeErr != nil { t.Errorf("failed to 
close connection: %v", closeErr) } }) recv := make([]byte, 4) if _, err = conn.Read(recv); err != ErrSuperfluousProxyHeader { t.Fatalf("Expected error %v, received %v", ErrSuperfluousProxyHeader, err) } err = <-cliResult if err != nil { t.Fatalf("client error: %v", err) } } func TestIgnorePolicyIgnoresIpFromProxyHeader(t *testing.T) { l, err := net.Listen("tcp", testLocalhostRandomPort) if err != nil { t.Fatalf("err: %v", err) } policyFunc := func(_ net.Addr) (Policy, error) { return IGNORE, nil } pl := &Listener{Listener: l, Policy: policyFunc} cliResult := make(chan error) go func() { conn, err := net.Dial("tcp", pl.Addr().String()) if err != nil { cliResult <- err return } t.Cleanup(func() { if closeErr := conn.Close(); closeErr != nil { t.Errorf("failed to close connection: %v", closeErr) } }) // Write out the header! header := &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP(testSourceIPv4Addr), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP(testDestinationIPv4Addr), Port: 2000, }, } if _, err := header.WriteTo(conn); err != nil { cliResult <- err return } if _, err := conn.Write([]byte("ping")); err != nil { cliResult <- err return } recv := make([]byte, 4) if _, err = conn.Read(recv); err != nil { cliResult <- err return } if !bytes.Equal(recv, []byte("pong")) { cliResult <- fmt.Errorf("bad: %v", recv) return } close(cliResult) }() conn, err := pl.Accept() if err != nil { t.Fatalf("err: %v", err) } t.Cleanup(func() { if closeErr := conn.Close(); closeErr != nil { t.Errorf("failed to close connection: %v", closeErr) } }) recv := make([]byte, 4) if _, err = conn.Read(recv); err != nil { t.Fatalf("err: %v", err) } if !bytes.Equal(recv, []byte("ping")) { t.Fatalf("bad: %v", recv) } if _, err := conn.Write([]byte("pong")); err != nil { t.Fatalf("err: %v", err) } // Check the remote addr addr := conn.RemoteAddr().(*net.TCPAddr) if addr.IP.String() != "127.0.0.1" { t.Fatalf("bad: %v", addr) } 
err = <-cliResult if err != nil { t.Fatalf("client error: %v", err) } } func Test_AllOptionsAreRecognized(t *testing.T) { recognizedOpt1 := false opt1 := func(_ *Conn) { recognizedOpt1 = true } recognizedOpt2 := false opt2 := func(_ *Conn) { recognizedOpt2 = true } server, client := net.Pipe() t.Cleanup(func() { if closeErr := client.Close(); closeErr != nil { t.Errorf("failed to close connection: %v", closeErr) } }) c := NewConn(server, opt1, opt2) if !recognizedOpt1 { t.Error("Expected option 1 recognized") } if !recognizedOpt2 { t.Error("Expected option 2 recognized") } t.Cleanup(func() { if closeErr := c.Close(); closeErr != nil { t.Errorf("failed to close connection: %v", closeErr) } }) } func TestReadingIsRefusedOnErrorWhenRemoteAddrRequestedFirst(t *testing.T) { l, err := net.Listen("tcp", testLocalhostRandomPort) if err != nil { t.Fatalf("err: %v", err) } policyFunc := func(_ net.Addr) (Policy, error) { return REQUIRE, nil } pl := &Listener{Listener: l, Policy: policyFunc} cliResult := make(chan error) go func() { conn, err := net.Dial("tcp", pl.Addr().String()) if err != nil { cliResult <- err return } t.Cleanup(func() { if closeErr := conn.Close(); closeErr != nil { t.Errorf("failed to close connection: %v", closeErr) } }) if _, err := conn.Write([]byte("ping")); err != nil { cliResult <- err return } close(cliResult) }() conn, err := pl.Accept() if err != nil { t.Fatalf("err: %v", err) } t.Cleanup(func() { if closeErr := conn.Close(); closeErr != nil { t.Errorf("failed to close connection: %v", closeErr) } }) _ = conn.RemoteAddr() recv := make([]byte, 4) if _, err = conn.Read(recv); err != ErrNoProxyProtocol { t.Fatalf("Expected error %v, received %v", ErrNoProxyProtocol, err) } err = <-cliResult if err != nil { t.Fatalf("client error: %v", err) } } func TestReadingIsRefusedOnErrorWhenLocalAddrRequestedFirst(t *testing.T) { l, err := net.Listen("tcp", testLocalhostRandomPort) if err != nil { t.Fatalf("err: %v", err) } policyFunc := func(_ net.Addr) 
(Policy, error) { return REQUIRE, nil } pl := &Listener{Listener: l, Policy: policyFunc} cliResult := make(chan error) go func() { conn, err := net.Dial("tcp", pl.Addr().String()) if err != nil { cliResult <- err return } t.Cleanup(func() { if closeErr := conn.Close(); closeErr != nil { t.Errorf("failed to close connection: %v", closeErr) } }) if _, err := conn.Write([]byte("ping")); err != nil { cliResult <- err return } close(cliResult) }() conn, err := pl.Accept() if err != nil { t.Fatalf("err: %v", err) } t.Cleanup(func() { if closeErr := conn.Close(); closeErr != nil { t.Errorf("failed to close connection: %v", closeErr) } }) _ = conn.LocalAddr() recv := make([]byte, 4) if _, err = conn.Read(recv); err != ErrNoProxyProtocol { t.Fatalf("Expected error %v, received %v", ErrNoProxyProtocol, err) } err = <-cliResult if err != nil { t.Fatalf("client error: %v", err) } } func TestSkipProxyProtocolPolicy(t *testing.T) { l, err := net.Listen("tcp", testLocalhostRandomPort) if err != nil { t.Fatalf("err: %v", err) } connPolicyFunc := func(_ ConnPolicyOptions) (Policy, error) { return SKIP, nil } pl := &Listener{ Listener: l, ConnPolicy: connPolicyFunc, } cliResult := make(chan error) ping := []byte("ping") go func() { conn, err := net.Dial("tcp", pl.Addr().String()) if err != nil { cliResult <- err return } t.Cleanup(func() { if closeErr := conn.Close(); closeErr != nil { t.Errorf("failed to close connection: %v", closeErr) } }) if _, err := conn.Write(ping); err != nil { cliResult <- err return } close(cliResult) }() conn, err := pl.Accept() if err != nil { t.Fatalf("err: %v", err) } t.Cleanup(func() { if closeErr := conn.Close(); closeErr != nil { t.Errorf("failed to close connection: %v", closeErr) } }) _, ok := conn.(*net.TCPConn) if !ok { t.Fatal("err: should be a tcp connection") } _ = conn.LocalAddr() recv := make([]byte, 4) if _, err = conn.Read(recv); err != nil { t.Fatalf("Unexpected read error: %v", err) } if !bytes.Equal(ping, recv) { t.Fatalf("Unexpected 
%s data while expected %s", recv, ping) } err = <-cliResult if err != nil { t.Fatalf("client error: %v", err) } } func TestSkipProxyProtocolConnPolicy(t *testing.T) { l, err := net.Listen("tcp", testLocalhostRandomPort) if err != nil { t.Fatalf("err: %v", err) } policyFunc := func(_ net.Addr) (Policy, error) { return SKIP, nil } pl := &Listener{ Listener: l, Policy: policyFunc, } cliResult := make(chan error) ping := []byte("ping") go func() { conn, err := net.Dial("tcp", pl.Addr().String()) if err != nil { cliResult <- err return } t.Cleanup(func() { if closeErr := conn.Close(); closeErr != nil { t.Errorf("failed to close connection: %v", closeErr) } }) if _, err := conn.Write(ping); err != nil { cliResult <- err return } close(cliResult) }() conn, err := pl.Accept() if err != nil { t.Fatalf("err: %v", err) } t.Cleanup(func() { if closeErr := conn.Close(); closeErr != nil { t.Errorf("failed to close connection: %v", closeErr) } }) _, ok := conn.(*net.TCPConn) if !ok { t.Fatal("err: should be a tcp connection") } _ = conn.LocalAddr() recv := make([]byte, 4) if _, err = conn.Read(recv); err != nil { t.Fatalf("Unexpected read error: %v", err) } if !bytes.Equal(ping, recv) { t.Fatalf("Unexpected %s data while expected %s", recv, ping) } err = <-cliResult if err != nil { t.Fatalf("client error: %v", err) } } func TestLocalCommandUsesUnderlyingAddrs(t *testing.T) { l, err := net.Listen("tcp", testLocalhostRandomPort) if err != nil { t.Fatalf("err: %v", err) } pl := &Listener{Listener: l} header := &Header{ Version: 2, Command: LOCAL, TransportProtocol: UNSPEC, } cliResult := make(chan error) go func() { conn, err := net.Dial("tcp", pl.Addr().String()) if err != nil { cliResult <- err return } // Write a LOCAL header with no address information. if _, err := header.WriteTo(conn); err != nil { cliResult <- err return } if _, err := conn.Write([]byte("ping")); err != nil { cliResult <- err return } // Close client side to avoid leaving the connection open. 
if err := conn.Close(); err != nil { cliResult <- err return } close(cliResult) }() conn, err := pl.Accept() if err != nil { t.Fatalf("err: %v", err) } t.Cleanup(func() { if closeErr := conn.Close(); closeErr != nil { t.Errorf("failed to close connection: %v", closeErr) } }) proxyConn := conn.(*Conn) // LOCAL should make LocalAddr/RemoteAddr fall back to underlying addresses. if proxyConn.LocalAddr().String() != proxyConn.Raw().LocalAddr().String() { t.Fatalf("LocalAddr should use underlying address for LOCAL command") } if proxyConn.RemoteAddr().String() != proxyConn.Raw().RemoteAddr().String() { t.Fatalf("RemoteAddr should use underlying address for LOCAL command") } err = <-cliResult if err != nil { t.Fatalf("client error: %v", err) } } func Test_ConnectionCasts(t *testing.T) { l, err := net.Listen("tcp", testLocalhostRandomPort) if err != nil { t.Fatalf("err: %v", err) } policyFunc := func(_ net.Addr) (Policy, error) { return REQUIRE, nil } pl := &Listener{Listener: l, Policy: policyFunc} cliResult := make(chan error) go func() { conn, err := net.Dial("tcp", pl.Addr().String()) if err != nil { cliResult <- err return } t.Cleanup(func() { if closeErr := conn.Close(); closeErr != nil { t.Errorf("failed to close connection: %v", closeErr) } }) if _, err := conn.Write([]byte("ping")); err != nil { cliResult <- err return } close(cliResult) }() conn, err := pl.Accept() if err != nil { t.Fatalf("err: %v", err) } t.Cleanup(func() { if closeErr := conn.Close(); closeErr != nil { t.Errorf("failed to close connection: %v", closeErr) } }) proxyprotoConn := conn.(*Conn) _, ok := proxyprotoConn.TCPConn() if !ok { t.Fatal("err: should be a tcp connection") } _, ok = proxyprotoConn.UDPConn() if ok { t.Fatal("err: should be a tcp connection not udp") } _, ok = proxyprotoConn.UnixConn() if ok { t.Fatal("err: should be a tcp connection not unix") } _, ok = proxyprotoConn.Raw().(*net.TCPConn) if !ok { t.Fatal("err: should be a tcp connection") } err = <-cliResult if err != nil { 
t.Fatalf("client error: %v", err) } } func Test_ConnectionErrorsWhenHeaderValidationFails(t *testing.T) { l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("err: %v", err) } validationError := fmt.Errorf("failed to validate") pl := &Listener{Listener: l, ValidateHeader: func(*Header) error { return validationError }} cliResult := make(chan error) go func() { conn, err := net.Dial("tcp", pl.Addr().String()) if err != nil { cliResult <- err return } t.Cleanup(func() { if closeErr := conn.Close(); closeErr != nil { t.Errorf("failed to close connection: %v", closeErr) } }) // Write out the header! header := &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP(testSourceIPv4Addr), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP(testDestinationIPv4Addr), Port: 2000, }, } if _, err := header.WriteTo(conn); err != nil { cliResult <- err return } close(cliResult) }() conn, err := pl.Accept() if err != nil { t.Fatalf("err: %v", err) } t.Cleanup(func() { if closeErr := conn.Close(); closeErr != nil { t.Errorf("failed to close connection: %v", closeErr) } }) recv := make([]byte, 4) if _, err = conn.Read(recv); err != validationError { t.Fatalf("expected validation error, got %v", err) } err = <-cliResult if err != nil { t.Fatalf("client error: %v", err) } } func Test_ConnectionHandlesInvalidUpstreamError(t *testing.T) { l, err := net.Listen("tcp", "localhost:8080") if err != nil { t.Fatalf("error creating listener: %v", err) } var connectionCounter atomic.Int32 newLn := &Listener{ Listener: l, ConnPolicy: func(_ ConnPolicyOptions) (Policy, error) { // Return the invalid upstream error on the first call, the listener // should remain open and accepting. times := connectionCounter.Load() if times == 0 { connectionCounter.Store(times + 1) return REJECT, ErrInvalidUpstream } return REJECT, ErrNoProxyProtocol }, } // Kick off the listener and return any error via the chanel. 
errCh := make(chan error) defer close(errCh) go func() { _, err := newLn.Accept() errCh <- err }() client := http.Client{Timeout: 200 * time.Millisecond} // Make two calls to trigger the listener's accept, the first should experience // the ErrInvalidUpstream and keep the listener open, the second should experience // a different error which will cause the listener to close. // First call should experience the ErrInvalidUpstream and keep the listener open. resp, err := client.Get("http://localhost:8080") if resp != nil { if closeErr := resp.Body.Close(); closeErr != nil { t.Fatalf("failed to close response body: %v", closeErr) } } if err != nil && !errors.Is(err, io.EOF) { t.Logf("first request failed as expected: %v", err) } // Ensure the ConnPolicy function was called at least once. deadline := time.Now().Add(2 * time.Second) for time.Now().Before(deadline) { if connectionCounter.Load() >= 1 { break } } if connectionCounter.Load() < 1 { t.Fatalf("expected ConnPolicy to be called at least once") } // Wait a few seconds to ensure we didn't get anything back on our channel. select { case err := <-errCh: if err != nil { t.Fatalf("invalid upstream shouldn't return an error: %v", err) } case <-time.After(2 * time.Second): // No error returned (as expected, we're still listening though) } // Second call should experience a different error and cause the listener to close. resp, err = client.Get("http://localhost:8080") if resp != nil { if closeErr := resp.Body.Close(); closeErr != nil { t.Fatalf("failed to close response body: %v", closeErr) } } if err != nil && !errors.Is(err, io.EOF) { t.Logf("second request failed as expected: %v", err) } // Ensure the listener is closed. 
select { case err := <-errCh: if err != nil && !errors.Is(err, ErrNoProxyProtocol) { t.Fatalf("unexpected error type: %v", err) } case <-time.After(2 * time.Second): t.Fatalf("timed out waiting for listener") } } type TestTLSServer struct { Listener net.Listener // TLS is the optional TLS configuration, populated with a new config // after TLS is started. If set on an unstarted server before StartTLS // is called, existing fields are copied into the new config. TLS *tls.Config TLSClientConfig *tls.Config // certificate is a parsed version of the TLS config certificate, if present. certificate *x509.Certificate } func (s *TestTLSServer) Addr() string { return s.Listener.Addr().String() } func (s *TestTLSServer) Close() error { return s.Listener.Close() } // based on net/http/httptest/Server.StartTLS. func NewTestTLSServer(l net.Listener) *TestTLSServer { s := &TestTLSServer{} cert, err := tls.X509KeyPair(LocalhostCert, LocalhostKey) if err != nil { panic(fmt.Sprintf("httptest: NewTLSServer: %v", err)) } s.TLS = new(tls.Config) if len(s.TLS.Certificates) == 0 { s.TLS.Certificates = []tls.Certificate{cert} } s.certificate, err = x509.ParseCertificate(s.TLS.Certificates[0].Certificate[0]) if err != nil { panic(fmt.Sprintf("NewTestTLSServer: %v", err)) } certpool := x509.NewCertPool() certpool.AddCert(s.certificate) s.TLSClientConfig = &tls.Config{ RootCAs: certpool, MinVersion: tls.VersionTLS12, } s.Listener = tls.NewListener(l, s.TLS) return s } func Test_TLSServer(t *testing.T) { l, err := net.Listen("tcp", testLocalhostRandomPort) if err != nil { t.Fatalf("err: %v", err) } s := NewTestTLSServer(l) s.Listener = &Listener{ Listener: s.Listener, Policy: func(_ net.Addr) (Policy, error) { return REQUIRE, nil }, } defer func() { if err := s.Close(); err != nil { t.Errorf("failed to close TLS server: %v", err) } }() cliResult := make(chan error) go func() { conn, err := tls.Dial("tcp", s.Addr(), s.TLSClientConfig) if err != nil { cliResult <- err return } t.Cleanup(func() 
{ if closeErr := conn.Close(); closeErr != nil { t.Errorf("failed to close connection: %v", closeErr) } }) // Write out the header! header := &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP(testSourceIPv4Addr), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP(testDestinationIPv4Addr), Port: 2000, }, } if _, err := header.WriteTo(conn); err != nil { cliResult <- err return } if _, err := conn.Write([]byte("test")); err != nil { cliResult <- err return } close(cliResult) }() conn, err := s.Listener.Accept() if err != nil { t.Fatalf("err: %v", err) } t.Cleanup(func() { if closeErr := conn.Close(); closeErr != nil { t.Errorf("failed to close connection: %v", closeErr) } }) recv := make([]byte, 1024) n, err := conn.Read(recv) if err != nil { t.Fatalf("expected no error, got %v", err) } if string(recv[:n]) != "test" { t.Fatalf("expected \"test\", got \"%s\" %v", recv[:n], recv[:n]) } err = <-cliResult if err != nil { t.Fatalf("client error: %v", err) } } func Test_MisconfiguredTLSServerRespondsWithUnderlyingError(t *testing.T) { l, err := net.Listen("tcp", testLocalhostRandomPort) if err != nil { t.Fatalf("err: %v", err) } s := NewTestTLSServer(l) s.Listener = &Listener{ Listener: s.Listener, Policy: func(_ net.Addr) (Policy, error) { return REQUIRE, nil }, } defer func() { if err := s.Close(); err != nil { t.Errorf("failed to close TLS server: %v", err) } }() cliResult := make(chan error) go func() { // this is not a valid TLS connection, we are // connecting to the TLS endpoint via plain TCP. // // it's an example of a configuration error: // client: HTTP -> PROXY // server: PROXY -> TLS -> HTTP // // we want to bubble up the underlying error, // in this case a tls handshake error, instead // of responding with a non-descript // > "Proxy protocol signature not present". 
conn, err := net.Dial("tcp", s.Addr()) if err != nil { cliResult <- err return } t.Cleanup(func() { if closeErr := conn.Close(); closeErr != nil { t.Errorf("failed to close connection: %v", closeErr) } }) // Write out the header! header := &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP(testSourceIPv4Addr), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP(testDestinationIPv4Addr), Port: 2000, }, } if _, err := header.WriteTo(conn); err != nil { cliResult <- err return } if _, err := conn.Write([]byte("GET /foo/bar HTTP/1.1")); err != nil { cliResult <- err return } close(cliResult) }() conn, err := s.Listener.Accept() if err != nil { t.Fatalf("err: %v", err) } t.Cleanup(func() { if closeErr := conn.Close(); closeErr != nil { t.Errorf("failed to close connection: %v", closeErr) } }) recv := make([]byte, 1024) if _, err = conn.Read(recv); err.Error() != "tls: first record does not look like a TLS handshake" { t.Fatalf("expected tls handshake error, got %s", err) } err = <-cliResult if err != nil { t.Fatalf("client error: %v", err) } } type testConn struct { readFromCalledWith io.Reader reads int net.Conn // nil; crash on any unexpected use } type deadlineConn struct { deadline time.Time readDeadline time.Time writeDeadline time.Time } func (c *deadlineConn) Read(_ []byte) (int, error) { return 0, io.EOF } func (c *deadlineConn) Write(p []byte) (int, error) { return len(p), nil } func (c *deadlineConn) Close() error { return nil } func (c *deadlineConn) LocalAddr() net.Addr { return dummyAddr("local") } func (c *deadlineConn) RemoteAddr() net.Addr { return dummyAddr("remote") } func (c *deadlineConn) SetDeadline(t time.Time) error { c.deadline = t return nil } func (c *deadlineConn) SetReadDeadline(t time.Time) error { c.readDeadline = t return nil } func (c *deadlineConn) SetWriteDeadline(t time.Time) error { c.writeDeadline = t return nil } type noReadFromConn struct { written bytes.Buffer } func 
(c *noReadFromConn) Read(_ []byte) (int, error) { return 0, io.EOF } func (c *noReadFromConn) Write(p []byte) (int, error) { return c.written.Write(p) } func (c *noReadFromConn) Close() error { return nil } func (c *noReadFromConn) LocalAddr() net.Addr { return dummyAddr("local") } func (c *noReadFromConn) RemoteAddr() net.Addr { return dummyAddr("remote") } func (c *noReadFromConn) SetDeadline(time.Time) error { return nil } func (c *noReadFromConn) SetReadDeadline(time.Time) error { return nil } func (c *noReadFromConn) SetWriteDeadline(time.Time) error { return nil } type dummyAddr string func (a dummyAddr) Network() string { return "dummy" } func (a dummyAddr) String() string { return string(a) } func (c *testConn) ReadFrom(r io.Reader) (int64, error) { c.readFromCalledWith = r b, err := io.ReadAll(r) return int64(len(b)), err } func (c *testConn) Write(p []byte) (int, error) { return len(p), nil } func (c *testConn) Read(_ []byte) (int, error) { if c.reads == 0 { return 0, io.EOF } c.reads-- return 1, nil } func TestCopyToWrappedConnection(t *testing.T) { innerConn := &testConn{} wrappedConn := NewConn(innerConn) dummySrc := &testConn{reads: 1} if _, err := io.Copy(wrappedConn, dummySrc); err != nil { t.Fatalf("err: %v", err) } if innerConn.readFromCalledWith != dummySrc { t.Error("Expected io.Copy to delegate to ReadFrom function of inner destination connection") } } func TestCopyFromWrappedConnection(t *testing.T) { wrappedConn := NewConn(&testConn{reads: 1}) dummyDst := &testConn{} if _, err := io.Copy(dummyDst, wrappedConn); err != nil { t.Fatalf("err: %v", err) } if dummyDst.readFromCalledWith != wrappedConn.conn { t.Errorf("Expected io.Copy to pass inner source connection to ReadFrom method of destination") } } func TestCopyFromWrappedConnectionToWrappedConnection(t *testing.T) { innerConn1 := &testConn{reads: 1} wrappedConn1 := NewConn(innerConn1) innerConn2 := &testConn{} wrappedConn2 := NewConn(innerConn2) if _, err := io.Copy(wrappedConn1, 
wrappedConn2); err != nil { t.Fatalf("err: %v", err) } if innerConn1.readFromCalledWith != innerConn2 { t.Errorf("Expected io.Copy to pass inner source connection to ReadFrom of inner destination connection") } } func TestDeadlineWrappersDelegate(t *testing.T) { conn := &deadlineConn{} proxyConn := NewConn(conn) deadline := time.Now().Add(2 * time.Second) readDeadline := time.Now().Add(3 * time.Second) writeDeadline := time.Now().Add(4 * time.Second) // Ensure deadline setters pass through to the underlying connection. if err := proxyConn.SetDeadline(deadline); err != nil { t.Fatalf("unexpected SetDeadline error: %v", err) } if err := proxyConn.SetReadDeadline(readDeadline); err != nil { t.Fatalf("unexpected SetReadDeadline error: %v", err) } if err := proxyConn.SetWriteDeadline(writeDeadline); err != nil { t.Fatalf("unexpected SetWriteDeadline error: %v", err) } if !conn.deadline.Equal(deadline) { t.Fatalf("SetDeadline did not pass through value") } if !conn.readDeadline.Equal(readDeadline) { t.Fatalf("SetReadDeadline did not pass through value") } if !conn.writeDeadline.Equal(writeDeadline) { t.Fatalf("SetWriteDeadline did not pass through value") } } func TestReadFromFallbackCopiesToConn(t *testing.T) { conn := &noReadFromConn{} proxyConn := NewConn(conn) payload := []byte("payload") if _, err := proxyConn.ReadFrom(bytes.NewReader(payload)); err != nil { t.Fatalf("unexpected ReadFrom error: %v", err) } // When the inner connection does not implement io.ReaderFrom, // ReadFrom should fall back to io.Copy and write the payload. 
if !bytes.Equal(conn.written.Bytes(), payload) { t.Fatalf("unexpected write content: %q", conn.written.String()) } } func TestWriteToDrainsBufferedData(t *testing.T) { l, err := net.Listen("tcp", testLocalhostRandomPort) if err != nil { t.Fatalf("err: %v", err) } pl := &Listener{Listener: l} header := &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP(testSourceIPv4Addr), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP(testDestinationIPv4Addr), Port: 2000, }, } payload := []byte("ping") cliResult := make(chan error) go func() { conn, err := net.Dial("tcp", pl.Addr().String()) if err != nil { cliResult <- err return } // Write the header followed by payload to populate the reader buffer. if _, err := header.WriteTo(conn); err != nil { cliResult <- err return } if _, err := conn.Write(payload); err != nil { cliResult <- err return } // Close the client so WriteTo's io.Copy completes. if err := conn.Close(); err != nil { cliResult <- err return } close(cliResult) }() conn, err := pl.Accept() if err != nil { t.Fatalf("err: %v", err) } t.Cleanup(func() { if closeErr := conn.Close(); closeErr != nil { t.Errorf("failed to close connection: %v", closeErr) } }) var out bytes.Buffer if _, err := conn.(*Conn).WriteTo(&out); err != nil { t.Fatalf("unexpected WriteTo error: %v", err) } if !bytes.Equal(out.Bytes(), payload) { t.Fatalf("unexpected WriteTo output: %q", out.String()) } err = <-cliResult if err != nil { t.Fatalf("client error: %v", err) } } // chunkedConn wraps a net.Conn and limits reads to simulate TCP chunking. type chunkedConn struct { net.Conn maxRead int readCalls int bytesRead int } func (c *chunkedConn) Read(b []byte) (int, error) { if len(b) > c.maxRead { b = b[:c.maxRead] } n, err := c.Conn.Read(b) if n > 0 { c.readCalls++ c.bytesRead += n } return n, err } // TestConnReadHandlesChunkedPayload verifies Conn.Read does not drop data // when the initial TCP read is smaller than the payload. 
func TestConnReadHandlesChunkedPayload(t *testing.T) { const payloadSize = 400 proxyHeader := []byte("PROXY TCP4 192.168.1.1 192.168.1.2 12345 443\r\n") payload := bytes.Repeat([]byte("X"), payloadSize) fullData := append(proxyHeader, payload...) serverConn, clientConn := net.Pipe() defer func() { serverCloseErr := serverConn.Close() clientCloseErr := clientConn.Close() if serverCloseErr != nil || clientCloseErr != nil { t.Errorf("failed to close connection: %v, %v", serverCloseErr, clientCloseErr) } }() go func() { _, _ = clientConn.Write(fullData) _ = clientConn.Close() }() // Simulate TCP delivering only 256 bytes in first read. chunked := &chunkedConn{Conn: serverConn, maxRead: 256} // Create a ProxyProto-wrapped connection. conn := NewConn(chunked) buf := make([]byte, 64) readPayload := make([]byte, 0, payloadSize) for len(readPayload) < payloadSize { _ = conn.SetReadDeadline(time.Now().Add(time.Second)) n, err := conn.Read(buf) if err != nil && err != io.EOF { t.Fatalf("unexpected read error: %v", err) } if n > 0 { readPayload = append(readPayload, buf[:n]...) } if err == io.EOF { break } } t.Logf("Sent: %d bytes payload (after %d byte PROXY header)", payloadSize, len(proxyHeader)) t.Logf("Read: %d bytes", len(readPayload)) if len(readPayload) != payloadSize { t.Fatalf("read %d bytes, expected %d", len(readPayload), payloadSize) } if !bytes.Equal(readPayload, payload) { t.Fatalf("payload mismatch") } // Ensure the proxy connection read from the underlying conn // and drained all bytes, not just buffered reads. 
if chunked.readCalls == 0 { t.Fatalf("expected underlying reads to occur") } if chunked.bytesRead <= len(proxyHeader) { t.Fatalf("expected reads beyond header, got %d bytes", chunked.bytesRead) } if chunked.bytesRead != len(fullData) { t.Fatalf("underlying reads=%d bytes, expected %d", chunked.bytesRead, len(fullData)) } } func TestReadUsesConnWhenBufReaderNil(t *testing.T) { serverConn, clientConn := net.Pipe() t.Cleanup(func() { if closeErr := serverConn.Close(); closeErr != nil { t.Errorf("failed to close server connection: %v", closeErr) } }) t.Cleanup(func() { if closeErr := clientConn.Close(); closeErr != nil { t.Errorf("failed to close client connection: %v", closeErr) } }) proxyConn := NewConn(serverConn) sendSecond := make(chan struct{}) go func() { _, _ = clientConn.Write([]byte("a")) <-sendSecond _, _ = clientConn.Write([]byte("b")) _ = clientConn.Close() }() buf := make([]byte, 1) // First read processes header detection and drains the buffer. if _, err := proxyConn.Read(buf); err != nil { t.Fatalf("first read failed: %v", err) } if proxyConn.bufReader != nil { t.Fatalf("expected bufReader to be nil after draining buffer") } // With bufReader cleared, Read should use the underlying conn. 
close(sendSecond) if _, err := proxyConn.Read(buf); err != nil { t.Fatalf("second read failed: %v", err) } if string(buf) != "b" { t.Fatalf("unexpected second read payload: %q", string(buf)) } } func TestWriteToUsesConnWhenBufReaderNil(t *testing.T) { serverConn, clientConn := net.Pipe() t.Cleanup(func() { if closeErr := serverConn.Close(); closeErr != nil { t.Errorf("failed to close server connection: %v", closeErr) } }) t.Cleanup(func() { if closeErr := clientConn.Close(); closeErr != nil { t.Errorf("failed to close client connection: %v", closeErr) } }) proxyConn := NewConn(serverConn) sendPayload := make(chan struct{}) go func() { _, _ = clientConn.Write([]byte("x")) <-sendPayload _, _ = clientConn.Write([]byte("payload")) _ = clientConn.Close() }() // Process header detection and drain the buffer. buf := make([]byte, 1) if _, err := proxyConn.Read(buf); err != nil { t.Fatalf("initial read failed: %v", err) } if proxyConn.bufReader != nil { t.Fatalf("expected bufReader to be nil after draining buffer") } // With bufReader cleared, WriteTo should copy directly from conn. close(sendPayload) var out bytes.Buffer if _, err := proxyConn.WriteTo(&out); err != nil { t.Fatalf("WriteTo failed: %v", err) } if out.String() != "payload" { t.Fatalf("unexpected WriteTo output: %q", out.String()) } } func benchmarkTCPProxy(size int, b *testing.B) { // create and start the echo backend backend, err := net.Listen("tcp", testLocalhostRandomPort) if err != nil { b.Fatalf("err: %v", err) } b.Cleanup(func() { if closeErr := backend.Close(); closeErr != nil { b.Errorf("failed to close backend: %v", closeErr) } }) go func() { for { conn, err := backend.Accept() if err != nil { break } _, err = io.Copy(conn, conn) if err != nil { b.Errorf("Failed to read entire payload: %v", err) return } // Can't defer since we keep accepting on each for iteration. 
if closeErr := conn.Close(); closeErr != nil { b.Errorf("failed to close connection: %v", closeErr) return } } }() // start the proxyprotocol enabled tcp proxy l, err := net.Listen("tcp", testLocalhostRandomPort) if err != nil { b.Fatalf("err: %v", err) } b.Cleanup(func() { if closeErr := l.Close(); closeErr != nil { b.Errorf("failed to close listener: %v", closeErr) } }) pl := &Listener{Listener: l} go func() { for { conn, err := pl.Accept() if err != nil { break } bConn, err := net.Dial("tcp", backend.Addr().String()) if err != nil { b.Errorf("failed to dial backend: %v", err) return } go func() { _, err = io.Copy(bConn, conn) if err != nil { b.Errorf("Failed to proxy incoming data to backend: %v", err) return } if closeErr := bConn.(*net.TCPConn).CloseWrite(); closeErr != nil { b.Errorf("failed to close write: %v", closeErr) return } }() _, err = io.Copy(conn, bConn) if err != nil { panic(fmt.Sprintf("Failed to proxy data from backend: %v", err)) } if closeErr := conn.Close(); closeErr != nil { b.Errorf("failed to close connection: %v", closeErr) return } if closeErr := bConn.Close(); closeErr != nil { b.Errorf("failed to close connection: %v", closeErr) return } } }() data := make([]byte, size) header := &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP(testSourceIPv4Addr), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP(testDestinationIPv4Addr), Port: 2000, }, } // now for the actual benchmark b.ResetTimer() for n := 0; n < b.N; n++ { conn, err := net.Dial("tcp", pl.Addr().String()) if err != nil { b.Fatalf("err: %v", err) } // Write out the header! 
if _, err := header.WriteTo(conn); err != nil { b.Fatalf("err: %v", err) } // send data go func() { _, err = conn.Write(data) if err != nil { b.Errorf("Failed to write data: %v", err) return } if closeErr := conn.(*net.TCPConn).CloseWrite(); closeErr != nil { b.Errorf("failed to close write: %v", closeErr) return } }() // receive data n, err := io.Copy(io.Discard, conn) if n != int64(len(data)) { b.Fatalf("Expected to receive %d bytes, got %d", len(data), n) } if err != nil { b.Fatalf("Failed to read data: %v", err) } if closeErr := conn.Close(); closeErr != nil { b.Errorf("failed to close connection: %v", closeErr) return } } } func BenchmarkTCPProxy16KB(b *testing.B) { benchmarkTCPProxy(16*1024, b) } func BenchmarkTCPProxy32KB(b *testing.B) { benchmarkTCPProxy(32*1024, b) } func BenchmarkTCPProxy64KB(b *testing.B) { benchmarkTCPProxy(64*1024, b) } func BenchmarkTCPProxy128KB(b *testing.B) { benchmarkTCPProxy(128*1024, b) } func BenchmarkTCPProxy256KB(b *testing.B) { benchmarkTCPProxy(256*1024, b) } func BenchmarkTCPProxy512KB(b *testing.B) { benchmarkTCPProxy(512*1024, b) } func BenchmarkTCPProxy1024KB(b *testing.B) { benchmarkTCPProxy(1024*1024, b) } func BenchmarkTCPProxy2048KB(b *testing.B) { benchmarkTCPProxy(2048*1024, b) } // copied from src/net/http/internal/testcert.go. // Copyright 2015 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. // LocalhostCert is a PEM-encoded TLS cert with SAN IPs "127.0.0.1" and "[::1]", // expiring at Jan 29 16:00:00 2084 GMT. Generated from src/crypto/tls: // go run generate_cert.go --rsa-bits 1024 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h. 
var LocalhostCert = []byte(`-----BEGIN CERTIFICATE----- MIICEzCCAXygAwIBAgIQMIMChMLGrR+QvmQvpwAU6zANBgkqhkiG9w0BAQsFADAS MRAwDgYDVQQKEwdBY21lIENvMCAXDTcwMDEwMTAwMDAwMFoYDzIwODQwMTI5MTYw MDAwWjASMRAwDgYDVQQKEwdBY21lIENvMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCB iQKBgQDuLnQAI3mDgey3VBzWnB2L39JUU4txjeVE6myuDqkM/uGlfjb9SjY1bIw4 iA5sBBZzHi3z0h1YV8QPuxEbi4nW91IJm2gsvvZhIrCHS3l6afab4pZBl2+XsDul rKBxKKtD1rGxlG4LjncdabFn9gvLZad2bSysqz/qTAUStTvqJQIDAQABo2gwZjAO BgNVHQ8BAf8EBAMCAqQwEwYDVR0lBAwwCgYIKwYBBQUHAwEwDwYDVR0TAQH/BAUw AwEB/zAuBgNVHREEJzAlggtleGFtcGxlLmNvbYcEfwAAAYcQAAAAAAAAAAAAAAAA AAAAATANBgkqhkiG9w0BAQsFAAOBgQCEcetwO59EWk7WiJsG4x8SY+UIAA+flUI9 tyC4lNhbcF2Idq9greZwbYCqTTTr2XiRNSMLCOjKyI7ukPoPjo16ocHj+P3vZGfs h1fIw3cSS2OolhloGw/XM6RWPWtPAlGykKLciQrBru5NAPvCMsb/I1DAceTiotQM fblo6RBxUQ== -----END CERTIFICATE-----`) // LocalhostKey is the private key for localhostCert. var LocalhostKey = []byte(`-----BEGIN RSA PRIVATE KEY----- MIICXgIBAAKBgQDuLnQAI3mDgey3VBzWnB2L39JUU4txjeVE6myuDqkM/uGlfjb9 SjY1bIw4iA5sBBZzHi3z0h1YV8QPuxEbi4nW91IJm2gsvvZhIrCHS3l6afab4pZB l2+XsDulrKBxKKtD1rGxlG4LjncdabFn9gvLZad2bSysqz/qTAUStTvqJQIDAQAB AoGAGRzwwir7XvBOAy5tM/uV6e+Zf6anZzus1s1Y1ClbjbE6HXbnWWF/wbZGOpet 3Zm4vD6MXc7jpTLryzTQIvVdfQbRc6+MUVeLKwZatTXtdZrhu+Jk7hx0nTPy8Jcb uJqFk541aEw+mMogY/xEcfbWd6IOkp+4xqjlFLBEDytgbIECQQDvH/E6nk+hgN4H qzzVtxxr397vWrjrIgPbJpQvBsafG7b0dA4AFjwVbFLmQcj2PprIMmPcQrooz8vp jy4SHEg1AkEA/v13/5M47K9vCxmb8QeD/asydfsgS5TeuNi8DoUBEmiSJwma7FXY fFUtxuvL7XvjwjN5B30pNEbc6Iuyt7y4MQJBAIt21su4b3sjXNueLKH85Q+phy2U fQtuUE9txblTu14q3N7gHRZB4ZMhFYyDy8CKrN2cPg/Fvyt0Xlp/DoCzjA0CQQDU y2ptGsuSmgUtWj3NM9xuwYPm+Z/F84K6+ARYiZ6PYj013sovGKUFfYAqVXVlxtIX qyUBnu3X9ps8ZfjLZO7BAkEAlT4R5Yl6cGhaJQYZHOde3JEMhNRcVFMO8dJDaFeo f9Oeos0UUothgiDktdQHxdNEwLjQf7lJJBzV+5OtwswCWA== -----END RSA PRIVATE KEY-----`) pires-go-proxyproto-04c9ad1/tlv.go000066400000000000000000000102161514137054000172450ustar00rootroot00000000000000// Type-Length-Value splitting and parsing for proxy protocol V2 // See spec 
https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt sections 2.2 to 2.7 and package proxyproto import ( "encoding/binary" "errors" "fmt" "math" ) // TLV type constants defined by the PROXY protocol spec. // //nolint:revive // Names follow the spec. const ( // Section 2.2. PP2_TYPE_ALPN PP2Type = 0x01 PP2_TYPE_AUTHORITY PP2Type = 0x02 PP2_TYPE_CRC32C PP2Type = 0x03 PP2_TYPE_NOOP PP2Type = 0x04 PP2_TYPE_UNIQUE_ID PP2Type = 0x05 PP2_TYPE_SSL PP2Type = 0x20 PP2_SUBTYPE_SSL_VERSION PP2Type = 0x21 PP2_SUBTYPE_SSL_CN PP2Type = 0x22 PP2_SUBTYPE_SSL_CIPHER PP2Type = 0x23 PP2_SUBTYPE_SSL_SIG_ALG PP2Type = 0x24 PP2_SUBTYPE_SSL_KEY_ALG PP2Type = 0x25 PP2_SUBTYPE_SSL_GROUP PP2Type = 0x26 PP2_SUBTYPE_SSL_SIG_SCHEME PP2Type = 0x27 PP2_SUBTYPE_SSL_CLIENT_CERT PP2Type = 0x28 PP2_TYPE_NETNS PP2Type = 0x30 // Section 2.2.7, reserved types. PP2_TYPE_MIN_CUSTOM PP2Type = 0xE0 PP2_TYPE_MAX_CUSTOM PP2Type = 0xEF PP2_TYPE_MIN_EXPERIMENT PP2Type = 0xF0 PP2_TYPE_MAX_EXPERIMENT PP2Type = 0xF7 PP2_TYPE_MIN_FUTURE PP2Type = 0xF8 PP2_TYPE_MAX_FUTURE PP2Type = 0xFF ) var ( // ErrTruncatedTLV indicates a TLV was truncated. ErrTruncatedTLV = errors.New("proxyproto: truncated TLV") // ErrMalformedTLV indicates a TLV has malformed data. ErrMalformedTLV = errors.New("proxyproto: malformed TLV Value") // ErrIncompatibleTLV indicates a TLV is of an unexpected type. ErrIncompatibleTLV = errors.New("proxyproto: incompatible TLV type") ) // PP2Type is the proxy protocol v2 type. type PP2Type byte // TLV is a uninterpreted Type-Length-Value for V2 protocol, see section 2.2. type TLV struct { Type PP2Type Value []byte } // SplitTLVs splits the Type-Length-Value vector, returns the vector or an error. 
func SplitTLVs(raw []byte) ([]TLV, error) { var tlvs []TLV for i := 0; i < len(raw); { tlv := TLV{ Type: PP2Type(raw[i]), } if len(raw)-i <= 2 { return nil, ErrTruncatedTLV } tlvLen := int(binary.BigEndian.Uint16(raw[i+1 : i+3])) // Max length = 65K i += 3 if i+tlvLen > len(raw) { return nil, ErrTruncatedTLV } // Ignore no-op padding if tlv.Type != PP2_TYPE_NOOP { tlv.Value = make([]byte, tlvLen) copy(tlv.Value, raw[i:i+tlvLen]) } i += tlvLen tlvs = append(tlvs, tlv) } return tlvs, nil } // JoinTLVs joins multiple Type-Length-Value records. func JoinTLVs(tlvs []TLV) ([]byte, error) { var raw []byte for _, tlv := range tlvs { if len(tlv.Value) > math.MaxUint16 { return nil, fmt.Errorf("proxyproto: cannot format TLV %v with length %d", tlv.Type, len(tlv.Value)) } var length [2]byte //nolint:gosec // lengthValue is validated above. lengthValue := uint16(len(tlv.Value)) binary.BigEndian.PutUint16(length[:], lengthValue) raw = append(raw, byte(tlv.Type)) raw = append(raw, length[:]...) raw = append(raw, tlv.Value...) } return raw, nil } // Registered is true if the type is registered in the spec, see section 2.2. func (p PP2Type) Registered() bool { switch p { case PP2_TYPE_ALPN, PP2_TYPE_AUTHORITY, PP2_TYPE_CRC32C, PP2_TYPE_NOOP, PP2_TYPE_UNIQUE_ID, PP2_TYPE_SSL, PP2_SUBTYPE_SSL_VERSION, PP2_SUBTYPE_SSL_CN, PP2_SUBTYPE_SSL_CIPHER, PP2_SUBTYPE_SSL_SIG_ALG, PP2_SUBTYPE_SSL_KEY_ALG, PP2_TYPE_NETNS: return true } return false } // App is true if the type is reserved for application specific data, see section 2.2.7. func (p PP2Type) App() bool { return p >= PP2_TYPE_MIN_CUSTOM && p <= PP2_TYPE_MAX_CUSTOM } // Experiment is true if the type is reserved for temporary experimental use by application // developers, see section 2.2.7. func (p PP2Type) Experiment() bool { return p >= PP2_TYPE_MIN_EXPERIMENT && p <= PP2_TYPE_MAX_EXPERIMENT } // Future is true is the type is reserved for future use, see section 2.2.7. 
func (p PP2Type) Future() bool { return p >= PP2_TYPE_MIN_FUTURE } // Spec is true if the type is covered by the spec, see section 2.2 and 2.2.7. func (p PP2Type) Spec() bool { return p.Registered() || p.App() || p.Experiment() || p.Future() } pires-go-proxyproto-04c9ad1/tlv_test.go000066400000000000000000000102241514137054000203030ustar00rootroot00000000000000package proxyproto import ( "bufio" "bytes" "testing" ) var ( fixtureOneByteTLV = []byte{byte(PP2_TYPE_MIN_CUSTOM) + 1} fixtureTwoByteTLV = []byte{byte(PP2_TYPE_MIN_CUSTOM) + 2, 0x00} fixtureEmptyLenTLV = []byte{byte(PP2_TYPE_MIN_CUSTOM) + 3, 0x00, 0x01} fixturePartialLenTLV = []byte{byte(PP2_TYPE_MIN_CUSTOM) + 3, 0x00, 0x02, 0x00} ) var invalidTLVTests = []struct { name string reader *bufio.Reader expectedError error }{ { name: "One byte TLV", reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(TCPv4)), fixtureWithTLV(lengthV4Bytes, fixtureIPv4Address, fixtureOneByteTLV)...)), expectedError: ErrTruncatedTLV, }, { name: "Two byte TLV", reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(TCPv4)), fixtureWithTLV(lengthV4Bytes, fixtureIPv4Address, fixtureTwoByteTLV)...)), expectedError: ErrTruncatedTLV, }, { name: "Empty Len TLV", reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(TCPv4)), fixtureWithTLV(lengthV4Bytes, fixtureIPv4Address, fixtureEmptyLenTLV)...)), expectedError: ErrTruncatedTLV, }, { name: "Partial Len TLV", reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(TCPv4)), fixtureWithTLV(lengthV4Bytes, fixtureIPv4Address, fixturePartialLenTLV)...)), expectedError: ErrTruncatedTLV, }, } func TestValid0Length(t *testing.T) { r := bufio.NewReader(bytes.NewReader(append(append(SIGV2, byte(PROXY), byte(TCPv4)), fixtureWithTLV(lengthV4Bytes, fixtureIPv4Address, []byte{byte(PP2_TYPE_MIN_CUSTOM), 0x00, 0x00})...))) h, err := Read(r) if err != nil { t.Fatalf("unexpected error: %v", err) } tlvs, err := h.TLVs() if err != nil { t.Fatalf("unexpected error: %v", err) } if 
len(tlvs) != 1 { t.Fatalf("expected 1 tlv, got %d", len(tlvs)) } if len(tlvs[0].Value) != 0 { t.Fatalf("expected 0 byte tlv value, got %x", tlvs[0].Value) } } func TestInvalidV2TLV(t *testing.T) { for _, tc := range invalidTLVTests { t.Run(tc.name, func(t *testing.T) { if hdr, err := Read(tc.reader); err != nil { t.Fatalf("TestInvalidV2TLV %s: unexpected error reading proxy protocol %#v", tc.name, err) } else if _, err := hdr.TLVs(); err != tc.expectedError { t.Fatalf("TestInvalidV2TLV %s: expected %#v, actual %#v", tc.name, tc.expectedError, err) } }) } } func TestV2TLVPP2Registered(t *testing.T) { pp2RegTypes := []PP2Type{ PP2_TYPE_ALPN, PP2_TYPE_AUTHORITY, PP2_TYPE_CRC32C, PP2_TYPE_NOOP, PP2_TYPE_UNIQUE_ID, PP2_TYPE_SSL, PP2_SUBTYPE_SSL_VERSION, PP2_SUBTYPE_SSL_CN, PP2_SUBTYPE_SSL_CIPHER, PP2_SUBTYPE_SSL_SIG_ALG, PP2_SUBTYPE_SSL_KEY_ALG, PP2_TYPE_NETNS, } pp2RegMap := make(map[PP2Type]bool) for _, p := range pp2RegTypes { pp2RegMap[p] = true if !p.Registered() { t.Fatalf("TestV2TLVPP2Registered: type %x should be registered", p) } if !p.Spec() { t.Fatalf("TestV2TLVPP2Registered: type %x should be in spec", p) } if p.App() { t.Fatalf("TestV2TLVPP2Registered: type %x unexpectedly app", p) } if p.Experiment() { t.Fatalf("TestV2TLVPP2Registered: type %x unexpectedly experiment", p) } if p.Future() { t.Fatalf("TestV2TLVPP2Registered: type %x unexpectedly future", p) } } lastType := PP2Type(0xFF) for i := range int(lastType) { p := PP2Type(i) if !pp2RegMap[p] && p.Registered() { t.Fatalf("TestV2TLVPP2Registered: type %x unexpectedly registered", p) } } if lastType.Registered() { t.Fatalf("TestV2TLVPP2Registered: type %x unexpectedly registered", lastType) } } func TestJoinTLVs(t *testing.T) { tests := []struct { name string raw []byte tlvs []TLV }{ { name: "authority TLV", raw: append([]byte{byte(PP2_TYPE_AUTHORITY), 0x00, 0x0B}, []byte("example.org")...), tlvs: []TLV{{ Type: PP2_TYPE_AUTHORITY, Value: []byte("example.org"), }}, }, { name: "empty TLV", raw: 
[]byte{byte(PP2_TYPE_NOOP), 0x00, 0x00}, tlvs: []TLV{{ Type: PP2_TYPE_NOOP, Value: nil, }}, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { if raw, err := JoinTLVs(tc.tlvs); err != nil { t.Fatalf("unexpected error: %v", err) } else if !bytes.Equal(raw, tc.raw) { t.Errorf("expected %#v, got %#v", tc.raw, raw) } }) } } pires-go-proxyproto-04c9ad1/tlvparse/000077500000000000000000000000001514137054000177515ustar00rootroot00000000000000pires-go-proxyproto-04c9ad1/tlvparse/aws.go000066400000000000000000000026371514137054000211020ustar00rootroot00000000000000// Package tlvparse provides helpers for PROXY protocol TLVs. // // Amazon's application extension to TLVs for NLB VPC endpoint services // https://docs.aws.amazon.com/elasticloadbalancing/latest/network/load-balancer-target-groups.html#proxy-protocol package tlvparse import ( "regexp" "github.com/pires/go-proxyproto" ) //nolint:revive // Names follow the spec. const ( // PP2_TYPE_AWS identifies AWS TLV extensions. PP2_TYPE_AWS = 0xEA // PP2_SUBTYPE_AWS_VPCE_ID identifies the VPC endpoint ID subtype. PP2_SUBTYPE_AWS_VPCE_ID = 0x01 ) var vpceRe = regexp.MustCompile("^[A-Za-z0-9-]*$") // IsAWSVPCEndpointID reports whether tlv contains an AWS VPC endpoint ID. func IsAWSVPCEndpointID(tlv proxyproto.TLV) bool { return tlv.Type == PP2_TYPE_AWS && len(tlv.Value) > 0 && tlv.Value[0] == PP2_SUBTYPE_AWS_VPCE_ID } // AWSVPCEndpointID returns the AWS VPC endpoint ID if present. func AWSVPCEndpointID(tlv proxyproto.TLV) (string, error) { if !IsAWSVPCEndpointID(tlv) { return "", proxyproto.ErrIncompatibleTLV } vpce := string(tlv.Value[1:]) if !vpceRe.MatchString(vpce) { return "", proxyproto.ErrMalformedTLV } return vpce, nil } // FindAWSVPCEndpointID returns the first AWS VPC ID in the TLV if it exists and is well-formed. 
func FindAWSVPCEndpointID(tlvs []proxyproto.TLV) string { for _, tlv := range tlvs { if vpc, err := AWSVPCEndpointID(tlv); err == nil && vpc != "" { return vpc } } return "" } pires-go-proxyproto-04c9ad1/tlvparse/aws_test.go000066400000000000000000000172231514137054000221360ustar00rootroot00000000000000package tlvparse import ( "encoding/binary" "math" "testing" "github.com/pires/go-proxyproto" ) var awsTestCases = []struct { name string raw []byte types []proxyproto.PP2Type valid func(*testing.T, string, []proxyproto.TLV) }{ { name: "VPCE example", // https://github.com/aws/elastic-load-balancing-tools/blob/c8eee30ab991ab4c57dc37d1c58f09f67bd534aa/proprot/tst/com/amazonaws/proprot/Compatibility_AwsNetworkLoadBalancerTest.java#L41..L67 raw: []byte{ 0x0d, 0x0a, 0x0d, 0x0a, /* Start of Sig */ 0x00, 0x0d, 0x0a, 0x51, 0x55, 0x49, 0x54, 0x0a, /* End of Sig */ 0x21, 0x11, 0x00, 0x54, /* ver_cmd, fam and len */ 0xac, 0x1f, 0x07, 0x71, /* Caller src ip */ 0xac, 0x1f, 0x0a, 0x1f, /* Endpoint dst ip */ 0xc8, 0xf2, 0x00, 0x50, /* Proxy src port & dst port */ 0x03, 0x00, 0x04, 0xe8, /* CRC TLV start */ 0xd6, 0x89, 0x2d, 0xea, /* CRC TLV cont, VPCE id TLV start */ 0x00, 0x17, 0x01, 0x76, 0x70, 0x63, 0x65, 0x2d, 0x30, 0x38, 0x64, 0x32, 0x62, 0x66, 0x31, 0x35, 0x66, 0x61, 0x63, 0x35, 0x30, 0x30, 0x31, 0x63, 0x39, 0x04, 0x00, 0x24, /* VPCE id TLV end, NOOP TLV start*/ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, /* NOOP TLV end */ }, types: []proxyproto.PP2Type{proxyproto.PP2_TYPE_CRC32C, PP2_TYPE_AWS, proxyproto.PP2_TYPE_NOOP}, valid: func(t *testing.T, name string, tlvs []proxyproto.TLV) { if !IsAWSVPCEndpointID(tlvs[1]) { t.Fatalf("TestParseV2TLV %s: Expected tlvs[1] to be an AWSVPCEndpointID type", name) } vpce := "vpce-08d2bf15fac5001c9" if vpca, err := AWSVPCEndpointID(tlvs[1]); err != nil { 
t.Fatalf("TestParseV2TLV %s: Unexpected error when parsing AWSVPCEndpointID", name) } else if vpca != vpce { t.Fatalf("TestParseV2TLV %s: Unexpected VPC ID from tlvs[1] expected %#v, actual %#v", name, vpce, vpca) } if vpca := FindAWSVPCEndpointID(tlvs); vpca == "" { t.Fatalf("TestParseV2TLV %s: Expected to find AWSVPCEndpointID %#v in TLVs", name, vpce) } else if vpca != vpce { t.Fatalf("TestParseV2TLV %s: Unexpected AWSVPCEndpointID from header expected %#v, actual %#v", name, vpce, vpca) } }, }, { name: "VPCE capture", raw: []byte{ 0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 0x51, 0x55, 0x49, 0x54, 0x0a, 0x21, 0x11, 0x00, 0x54, 0xc0, 0xa8, 0x2c, 0x0a, 0xc0, 0xa8, 0x2c, 0x07, 0xcc, 0x3e, 0x24, 0x1b, 0x03, 0x00, 0x04, 0xb9, 0x28, 0x6f, 0xa6, 0xea, 0x00, 0x17, 0x01, 0x76, 0x70, 0x63, 0x65, 0x2d, 0x30, 0x30, 0x65, 0x61, 0x66, 0x63, 0x34, 0x35, 0x38, 0x65, 0x63, 0x39, 0x37, 0x62, 0x38, 0x33, 0x33, 0x04, 0x00, 0x24, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, }, types: []proxyproto.PP2Type{proxyproto.PP2_TYPE_CRC32C, PP2_TYPE_AWS, proxyproto.PP2_TYPE_NOOP}, valid: func(t *testing.T, name string, tlvs []proxyproto.TLV) { if !IsAWSVPCEndpointID(tlvs[1]) { t.Fatalf("TestParseV2TLV %s: Expected tlvs[1] to be an AWS VPC endpoint ID type", name) } vpce := "vpce-00eafc458ec97b833" if vpca, err := AWSVPCEndpointID(tlvs[1]); err != nil { t.Fatalf("TestParseV2TLV %s: Unexpected error when parsing AWS VPC ID", name) } else if vpca != vpce { t.Fatalf("TestParseV2TLV %s: Unexpected VPC ID from tlvs[1] expected %#v, actual %#v", name, vpce, vpca) } if vpca := FindAWSVPCEndpointID(tlvs); vpca == "" { t.Fatalf("TestParseV2TLV %s: Expected to find VPC ID %#v in TLVs", name, vpce) } else if vpca != vpce { t.Fatalf("TestParseV2TLV %s: Unexpected VPC ID from header expected %#v, actual %#v", name, vpce, vpca) } }, }, } 
func TestV2TLVAWSVPCEBadChars(t *testing.T) { badVPCE := "vcpe-!?***&&&&&&&" rawTLVs := vpceTLV(badVPCE) tlvs, err := proxyproto.SplitTLVs(rawTLVs) if len(tlvs) != 1 { t.Fatalf("TestV2TLVAWSVPCEBadChars: unexpected TLV length expected: %#v, actual: %#v", 1, tlvs) } if err != nil { t.Fatalf("TestV2TLVAWSVPCEBadChars: unexpected TLV parsing error %#v", err) } _, err = AWSVPCEndpointID(tlvs[0]) if err != proxyproto.ErrMalformedTLV { t.Fatalf("TestV2TLVAWSVPCEBadChars: unexpected error actual: %#v", err) } if FindAWSVPCEndpointID(tlvs) != "" { t.Fatal("TestV2TLVAWSVPCEBadChars: AWSVPCEndpointID unexpectedly found") } rawTLVs = vpceTLV("") tlvs, err = proxyproto.SplitTLVs(rawTLVs) if len(tlvs) != 1 { t.Fatalf("TestV2TLVAWSVPCEBadChars: unexpected TLV length expected: %#v, actual: %#v", 1, tlvs) } if err != nil { t.Fatalf("TestV2TLVAWSVPCEBadChars: unexpected TLV parsing error %#v", err) } parsedVPCE, err := AWSVPCEndpointID(tlvs[0]) if err != nil { t.Fatalf("TestV2TLVAWSVPCEBadChars: unexpected error actual: %#v", err) } if parsedVPCE != "" { t.Fatalf("TestV2TLVAWSVPCEBadChars: found non-empty vpce, actual: %#v", parsedVPCE) } parsedVPCE = FindAWSVPCEndpointID(tlvs) if parsedVPCE != "" { t.Fatal("TestV2TLVAWSVPCEBadChars: AWSVPECID unexpectedly found") } } func TestParseAWSVPCEndpointIDTLVs(t *testing.T) { for _, tc := range awsTestCases { t.Run(tc.name, func(t *testing.T) { tlvs := checkTLVs(t, tc.name, tc.raw, tc.types) tc.valid(t, tc.name, tlvs) }) } } func TestV2TLVAWSUnknownSubtype(t *testing.T) { vpce := "vpce-abc1234" rawTLVs := vpceTLV(vpce) tlvs, err := proxyproto.SplitTLVs(rawTLVs) if len(tlvs) != 1 { t.Fatalf("TestV2TLVAWSUnknownSubtype: unexpected TLV length expected: %#v, actual: %#v", 1, tlvs) } if err != nil { t.Fatalf("TestV2TLVAWSUnknownSubtype: unexpected TLV parsing error %#v", err) } avpce, err := AWSVPCEndpointID(tlvs[0]) if err != nil { t.Fatalf("TestV2TLVAWSUnknownSubtype: unexpected AWSVPCEndpointID error actual: %#v", err) } if avpce != vpce { 
t.Fatalf("TestV2TLVAWSUnknownSubtype: unexpected vpce value expected: %#v, actual: %#v", vpce, avpce) } avpce = FindAWSVPCEndpointID(tlvs) if avpce == "" { t.Fatal("TestV2TLVAWSUnknownSubtype: AWSVPCEndpointID unexpectedly missing") } if avpce != vpce { t.Fatalf("TestV2TLVAWSUnknownSubtype: unexpected AWSVPCEndpointID value expected: %#v, actual: %#v", vpce, avpce) } subtypeIndex := 3 // Sanity check if rawTLVs[subtypeIndex] != PP2_SUBTYPE_AWS_VPCE_ID { t.Fatalf("TestV2TLVAWSUnknownSubtype: unexpected subtype expected %x, actual %x", PP2_SUBTYPE_AWS_VPCE_ID, rawTLVs[subtypeIndex]) } rawTLVs[subtypeIndex] = PP2_SUBTYPE_AWS_VPCE_ID + 1 tlvs, err = proxyproto.SplitTLVs(rawTLVs) if len(tlvs) != 1 { t.Fatalf("TestV2TLVAWSUnknownSubtype: unexpected TLV length expected: %#v, actual: %#v", 1, tlvs) } if err != nil { t.Fatalf("TestV2TLVAWSUnknownSubtype: unexpected TLV parsing error %#v", err) } if IsAWSVPCEndpointID(tlvs[0]) { t.Fatalf("TestV2TLVAWSUnknownSubtype: AWSVPCEType() unexpectedly true after changing subtype") } _, err = AWSVPCEndpointID(tlvs[0]) if err != proxyproto.ErrIncompatibleTLV { t.Fatalf("TestV2TLVAWSUnknownSubtype: unexpected AWSVPCEndpointID error expected %#v, actual: %#v", proxyproto.ErrIncompatibleTLV, err) } if FindAWSVPCEndpointID(tlvs) != "" { t.Fatal("TestV2TLVAWSUnknownSubtype: AWSVPCEndpointID unexpectedly exists despite invalid subtype") } } func vpceTLV(vpce string) []byte { tlv := []byte{ PP2_TYPE_AWS, 0x00, 0x00, PP2_SUBTYPE_AWS_VPCE_ID, } if len(vpce) > math.MaxUint16-1 { panic("vpce too long for TLV") } //nolint:gosec // Length is bounded above. binary.BigEndian.PutUint16(tlv[1:3], uint16(len(vpce)+1)) // +1 for subtype return append(tlv, []byte(vpce)...) 
} pires-go-proxyproto-04c9ad1/tlvparse/azure.go000066400000000000000000000034721514137054000214340ustar00rootroot00000000000000// Azure's application extension to TLVs for Private Link Services // https://docs.microsoft.com/en-us/azure/private-link/private-link-service-overview#getting-connection-information-using-tcp-proxy-v2 package tlvparse import ( "encoding/binary" "github.com/pires/go-proxyproto" ) //nolint:revive // Names follow the spec. const ( // PP2_TYPE_AZURE identifies Azure TLV extensions. PP2_TYPE_AZURE = 0xEE // PP2_SUBTYPE_AZURE_PRIVATEENDPOINT_LINKID identifies the Private Endpoint LinkID subtype. PP2_SUBTYPE_AZURE_PRIVATEENDPOINT_LINKID = 0x01 ) // IsAzurePrivateEndpointLinkID returns true if given TLV matches Azure Private Endpoint LinkID format. func isAzurePrivateEndpointLinkID(tlv proxyproto.TLV) bool { return tlv.Type == PP2_TYPE_AZURE && len(tlv.Value) == 5 && tlv.Value[0] == PP2_SUBTYPE_AZURE_PRIVATEENDPOINT_LINKID } // AzurePrivateEndpointLinkID returns linkID if given TLV matches Azure Private Endpoint LinkID format // // Format description: // // Field Length (Octets) Description // Type 1 PP2_TYPE_AZURE (0xEE) // Length 2 Length of value // Value 1 PP2_SUBTYPE_AZURE_PRIVATEENDPOINT_LINKID (0x01) // 4 UINT32 (4 bytes) representing the LINKID of the private endpoint. Encoded in little endian format. func azurePrivateEndpointLinkID(tlv proxyproto.TLV) (uint32, error) { if !isAzurePrivateEndpointLinkID(tlv) { return 0, proxyproto.ErrIncompatibleTLV } linkID := binary.LittleEndian.Uint32(tlv.Value[1:]) return linkID, nil } // FindAzurePrivateEndpointLinkID returns the first Azure Private Endpoint LinkID if it exists in the TLV collection // and a boolean indicating if it was found. 
func FindAzurePrivateEndpointLinkID(tlvs []proxyproto.TLV) (uint32, bool) { for _, tlv := range tlvs { if linkID, err := azurePrivateEndpointLinkID(tlv); err == nil { return linkID, true } } return 0, false } pires-go-proxyproto-04c9ad1/tlvparse/azure_test.go000066400000000000000000000044351514137054000224730ustar00rootroot00000000000000package tlvparse import ( "testing" "github.com/pires/go-proxyproto" ) func TestFindAzurePrivateEndpointLinkID(t *testing.T) { tests := []struct { name string tlvs []proxyproto.TLV wantLinkID uint32 wantFound bool }{ { name: "nil TLVs", tlvs: nil, wantLinkID: 0, wantFound: false, }, { name: "empty TLVs", tlvs: []proxyproto.TLV{}, wantLinkID: 0, wantFound: false, }, { name: "AWS VPC endpoint ID", tlvs: []proxyproto.TLV{ { Type: 0xEA, Value: []byte{0x01, 0x76, 0x70, 0x63, 0x65, 0x2d, 0x61, 0x62, 0x63, 0x31, 0x32, 0x33}, }, }, wantLinkID: 0, wantFound: false, }, { name: "Azure but wrong subtype", tlvs: []proxyproto.TLV{ { Type: 0xEE, Value: []byte{0x02, 0x01, 0x01, 0x01, 0x01}, }, }, wantLinkID: 0, wantFound: false, }, { name: "Azure but wrong length", tlvs: []proxyproto.TLV{ { Type: 0xEE, Value: []byte{0x02, 0x01, 0x01}, }, }, wantLinkID: 0, wantFound: false, }, { name: "Azure link ID", tlvs: []proxyproto.TLV{ { Type: 0xEE, Value: []byte{0x1, 0xc1, 0x45, 0x0, 0x21}, }, }, wantLinkID: 0x210045c1, wantFound: true, }, { name: "Multiple TLVs", tlvs: []proxyproto.TLV{ { // AWS Type: 0xEA, Value: []byte{0x01, 0x76, 0x70, 0x63, 0x65, 0x2d, 0x61, 0x62, 0x63, 0x31, 0x32, 0x33}, }, { // Azure but wrong subtype Type: 0xEE, Value: []byte{0x02, 0x01, 0x01, 0x01, 0x01}, }, { // Azure but wrong length Type: 0xEE, Value: []byte{0x02, 0x01, 0x01}, }, { // Correct Type: 0xEE, Value: []byte{0x1, 0xc1, 0x45, 0x0, 0x21}, }, { // Also correct, but second in line Type: 0xEE, Value: []byte{0x1, 0xc1, 0x45, 0x0, 0x22}, }, }, wantLinkID: 0x210045c1, wantFound: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { gotLinkID, gotFound := 
FindAzurePrivateEndpointLinkID(tt.tlvs) if gotFound != tt.wantFound { t.Errorf("FindAzurePrivateEndpointLinkID() got1 = %v, want %v", gotFound, tt.wantFound) } if gotLinkID != tt.wantLinkID { t.Errorf("FindAzurePrivateEndpointLinkID() got = %v, want %v", gotLinkID, tt.wantLinkID) } }) } } pires-go-proxyproto-04c9ad1/tlvparse/gcp.go000066400000000000000000000027011514137054000210510ustar00rootroot00000000000000package tlvparse import ( "encoding/binary" "github.com/pires/go-proxyproto" ) //nolint:revive // Names follow the spec. const ( // PP2_TYPE_GCP indicates a Google Cloud Platform header. PP2_TYPE_GCP proxyproto.PP2Type = 0xE0 ) // ExtractPSCConnectionID returns the first PSC Connection ID in the TLV if it exists and is well-formed and // a bool indicating one was found. func ExtractPSCConnectionID(tlvs []proxyproto.TLV) (uint64, bool) { for _, tlv := range tlvs { if linkID, err := pscConnectionID(tlv); err == nil { return linkID, true } } return 0, false } // pscConnectionID returns the ID of a GCP PSC extension TLV or errors with ErrIncompatibleTLV or // ErrMalformedTLV if it's the wrong TLV type or is malformed. // // Field Length (bytes) Description // Type 1 PP2_TYPE_GCP (0xE0) // Length 2 Length of value (always 0x0008) // Value 8 The 8-byte PSC Connection ID (decode to uint64; big endian) // // For example proxyproto.TLV{Type:0xea, Length:8, Value:[]byte{0xff, 0xff, 0xff, 0xff, 0xc0, 0xa8, 0x64, 0x02}} // will be decoded as 18446744072646845442. 
// // See https://cloud.google.com/vpc/docs/configure-private-service-connect-producer func pscConnectionID(t proxyproto.TLV) (uint64, error) { if !isPSCConnectionID(t) { return 0, proxyproto.ErrIncompatibleTLV } linkID := binary.BigEndian.Uint64(t.Value) return linkID, nil } func isPSCConnectionID(t proxyproto.TLV) bool { return t.Type == PP2_TYPE_GCP && len(t.Value) == 8 } pires-go-proxyproto-04c9ad1/tlvparse/gcp_test.go000066400000000000000000000035561514137054000221210ustar00rootroot00000000000000package tlvparse import ( "testing" "github.com/pires/go-proxyproto" ) func TestExtractPSCConnectionID(t *testing.T) { tests := []struct { name string tlvs []proxyproto.TLV wantPSCConnectionID uint64 wantFound bool }{ { name: "nil TLVs", tlvs: nil, wantFound: false, }, { name: "empty TLVs", tlvs: []proxyproto.TLV{}, wantFound: false, }, { name: "AWS VPC endpoint ID", tlvs: []proxyproto.TLV{ { Type: 0xEA, Value: []byte{0x01, 0x76, 0x70, 0x63, 0x65, 0x2d, 0x61, 0x62, 0x63, 0x31, 0x32, 0x33}, }, }, wantFound: false, }, { name: "GCP link ID", tlvs: []proxyproto.TLV{ { Type: PP2_TYPE_GCP, Value: []byte{'\xff', '\xff', '\xff', '\xff', '\xc0', '\xa8', '\x64', '\x02'}, }, }, wantPSCConnectionID: 18446744072646845442, wantFound: true, }, { name: "Multiple TLVs", tlvs: []proxyproto.TLV{ { // AWS Type: 0xEA, Value: []byte{0x01, 0x76, 0x70, 0x63, 0x65, 0x2d, 0x61, 0x62, 0x63, 0x31, 0x32, 0x33}, }, { // Azure Type: 0xEE, Value: []byte{0x02, 0x01, 0x01, 0x01, 0x01}, }, { // GCP but wrong length Type: 0xE0, Value: []byte{0xff, 0xff, 0xff}, }, { // Correct Type: 0xE0, Value: []byte{'\xff', '\xff', '\xff', '\xff', '\xc0', '\xa8', '\x64', '\x02'}, }, }, wantPSCConnectionID: 18446744072646845442, wantFound: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { linkID, hasLinkID := ExtractPSCConnectionID(tt.tlvs) if hasLinkID != tt.wantFound { t.Errorf("ExtractPSCConnectionID() got1 = %v, want %v", hasLinkID, tt.wantFound) } if linkID != tt.wantPSCConnectionID { 
t.Errorf("ExtractPSCConnectionID() got = %v, want %v", linkID, tt.wantPSCConnectionID) } }) } } pires-go-proxyproto-04c9ad1/tlvparse/ssl.go000066400000000000000000000142121514137054000211010ustar00rootroot00000000000000package tlvparse import ( "encoding/binary" "unicode" "unicode/utf8" "github.com/pires/go-proxyproto" ) // pp2_tlv_ssl.client bit fields. // //nolint:revive // Names follow the PROXY protocol spec. const ( // PP2_BITFIELD_CLIENT_SSL indicates the client used SSL/TLS. PP2_BITFIELD_CLIENT_SSL uint8 = 0x01 // PP2_BITFIELD_CLIENT_CERT_CONN indicates cert on the connection. PP2_BITFIELD_CLIENT_CERT_CONN uint8 = 0x02 // PP2_BITFIELD_CLIENT_CERT_SESS indicates cert in the session. PP2_BITFIELD_CLIENT_CERT_SESS uint8 = 0x04 ) const ( // tlvSSLMinLen is the minimum length of a SSL TLV. tlvSSLMinLen = 5 // len(pp2_tlv_ssl.client) + len(pp2_tlv_ssl.verify) ) // PP2SSL represents the PP2_TYPE_SSL TLV and its subtypes. // // See section 2.2.5 of the PROXY protocol spec. /* struct pp2_tlv_ssl { uint8_t client; uint32_t verify; struct pp2_tlv sub_tlv[0]; }; */ type PP2SSL struct { // The Client field is made of a bit field from the following values, // indicating which element is present: PP2_BITFIELD_CLIENT_SSL, // PP2_BITFIELD_CLIENT_CERT_CONN, PP2_BITFIELD_CLIENT_CERT_SESS Client uint8 // Verify will be zero if the client presented a certificate // and it was successfully verified, and non-zero otherwise. Verify uint32 TLV []proxyproto.TLV } // Verified is true if the client presented a certificate and it was successfully verified. func (s PP2SSL) Verified() bool { return s.Verify == 0 } // ClientSSL indicates that the client connected over SSL/TLS. When true, SSLVersion will return the version. func (s PP2SSL) ClientSSL() bool { return s.Client&PP2_BITFIELD_CLIENT_SSL == PP2_BITFIELD_CLIENT_SSL } // ClientCertConn indicates that the client provided a certificate over the current connection. 
func (s PP2SSL) ClientCertConn() bool { return s.Client&PP2_BITFIELD_CLIENT_CERT_CONN == PP2_BITFIELD_CLIENT_CERT_CONN } // ClientCertSess indicates that the client provided a certificate at least once over the TLS session this // connection belongs to. func (s PP2SSL) ClientCertSess() bool { return s.Client&PP2_BITFIELD_CLIENT_CERT_SESS == PP2_BITFIELD_CLIENT_CERT_SESS } // SSLVersion returns the US-ASCII string representation of the TLS version and whether that extension exists. func (s PP2SSL) SSLVersion() (string, bool) { for _, tlv := range s.TLV { if tlv.Type == proxyproto.PP2_SUBTYPE_SSL_VERSION { return string(tlv.Value), true } } return "", false } // SSLCipher returns the US-ASCII string representation of the used TLS cipher and whether that extension exists. func (s PP2SSL) SSLCipher() (string, bool) { for _, tlv := range s.TLV { if tlv.Type == proxyproto.PP2_SUBTYPE_SSL_CIPHER { return string(tlv.Value), true } } return "", false } // Marshal formats the PP2SSL structure as a TLV. func (s PP2SSL) Marshal() (proxyproto.TLV, error) { v := make([]byte, 5) v[0] = s.Client binary.BigEndian.PutUint32(v[1:5], s.Verify) tlvs, err := proxyproto.JoinTLVs(s.TLV) if err != nil { return proxyproto.TLV{}, err } v = append(v, tlvs...) return proxyproto.TLV{ Type: proxyproto.PP2_TYPE_SSL, Value: v, }, nil } // ClientCN returns the string representation (in UTF8) of the Common Name field (OID: 2.5.4.3) of the client // certificate's Distinguished Name and whether that extension exists. func (s PP2SSL) ClientCN() (string, bool) { for _, tlv := range s.TLV { if tlv.Type == proxyproto.PP2_SUBTYPE_SSL_CN { return string(tlv.Value), true } } return "", false } // ClientCert returns the raw X.509 client certificate encoded in ASN.1 DER and // whether that extension exists. 
func (s PP2SSL) ClientCert() ([]byte, bool) { for _, tlv := range s.TLV { if tlv.Type == proxyproto.PP2_SUBTYPE_SSL_CLIENT_CERT { return tlv.Value, true } } return nil, false } // IsSSL reports whether the TLV is of SSL type. func IsSSL(t proxyproto.TLV) bool { return t.Type == proxyproto.PP2_TYPE_SSL && len(t.Value) >= tlvSSLMinLen } // SSL returns the pp2_tlv_ssl from section 2.2.5 or errors with ErrIncompatibleTLV or ErrMalformedTLV. func SSL(t proxyproto.TLV) (PP2SSL, error) { ssl := PP2SSL{} if !IsSSL(t) { return ssl, proxyproto.ErrIncompatibleTLV } if len(t.Value) < tlvSSLMinLen { return ssl, proxyproto.ErrMalformedTLV } ssl.Client = t.Value[0] ssl.Verify = binary.BigEndian.Uint32(t.Value[1:5]) var err error ssl.TLV, err = proxyproto.SplitTLVs(t.Value[5:]) if err != nil { return PP2SSL{}, err } versionFound := !ssl.ClientSSL() for _, tlv := range ssl.TLV { switch tlv.Type { case proxyproto.PP2_SUBTYPE_SSL_VERSION: /* The PP2_CLIENT_SSL flag indicates that the client connected over SSL/TLS. When this field is present, the US-ASCII string representation of the TLS version is appended at the end of the field in the TLV format using the type PP2_SUBTYPE_SSL_VERSION. */ if len(tlv.Value) == 0 || !isASCII(tlv.Value) { return PP2SSL{}, proxyproto.ErrMalformedTLV } versionFound = true case proxyproto.PP2_SUBTYPE_SSL_CN: /* In all cases, the string representation (in UTF8) of the Common Name field (OID: 2.5.4.3) of the client certificate's Distinguished Name, is appended using the TLV format and the type PP2_SUBTYPE_SSL_CN. E.g. "example.com". */ if len(tlv.Value) == 0 || !utf8.Valid(tlv.Value) { return PP2SSL{}, proxyproto.ErrMalformedTLV } case proxyproto.PP2_SUBTYPE_SSL_CIPHER: /* The second level TLV PP2_SUBTYPE_SSL_CIPHER provides the US-ASCII string name of the used cipher, for example "ECDHE-RSA-AES128-GCM-SHA256". 
*/ if len(tlv.Value) == 0 || !isASCII(tlv.Value) { return PP2SSL{}, proxyproto.ErrMalformedTLV } } } if !versionFound { return PP2SSL{}, proxyproto.ErrMalformedTLV } return ssl, nil } // FindSSL returns the first PP2SSL if it exists and is well formed. func FindSSL(tlvs []proxyproto.TLV) (PP2SSL, bool) { for _, t := range tlvs { if ssl, err := SSL(t); err == nil { return ssl, true } } return PP2SSL{}, false } // isASCII checks whether a byte slice has all characters that fit in the ascii character set, including the null byte. func isASCII(b []byte) bool { for _, c := range b { if c > unicode.MaxASCII { return false } } return true } pires-go-proxyproto-04c9ad1/tlvparse/ssl_test.go000066400000000000000000000125371514137054000221500ustar00rootroot00000000000000package tlvparse import ( "reflect" "testing" "github.com/pires/go-proxyproto" ) // tlsVersion13 is the TLS version 1.3 string. const tlsVersion13 string = "TLSv1.3" var testCases = []struct { name string raw []byte types []proxyproto.PP2Type valid func(*testing.T, string, []proxyproto.TLV) }{ { name: "SSL haproxy cn", raw: []byte{ 0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 0x51, 0x55, 0x49, 0x54, 0x0a, 0x21, 0x11, 0x00, 0x40, 0x7f, 0x00, 0x00, 0x01, 0x7f, 0x00, 0x00, 0x01, 0xcc, 0x8a, 0x23, 0x2e, 0x20, 0x00, 0x31, 0x07, 0x00, 0x00, 0x00, 0x00, 0x21, 0x00, 0x07, 0x54, 0x4c, 0x53, 0x76, 0x31, 0x2e, 0x33, 0x22, 0x00, 0x1f, 0x45, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x20, 0x43, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x20, 0x4e, 0x61, 0x6d, 0x65, 0x20, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x20, 0x43, 0x65, 0x72, 0x74, }, types: []proxyproto.PP2Type{proxyproto.PP2_TYPE_SSL}, valid: func(t *testing.T, name string, tlvs []proxyproto.TLV) { if !IsSSL(tlvs[0]) { t.Fatalf("TestParseV2TLV %s: Expected tlvs[0] to be the SSL type", name) } ssl, err := SSL(tlvs[0]) if err != nil { t.Fatalf("TestParseV2TLV %s: Unexpected error when parsing SSL %#v", name, err) } if !ssl.ClientSSL() { t.Fatalf("TestParseV2TLV %s: Expected ClientSSL() to 
be true", name) } if !ssl.ClientCertConn() { t.Fatalf("TestParseV2TLV %s: Expected ClientCertConn() to be true", name) } if !ssl.ClientCertSess() { t.Fatalf("TestParseV2TLV %s: Expected ClientCertSess() to be true", name) } ecn := "Example Common Name Client Cert" if acn, ok := ssl.ClientCN(); !ok { t.Fatalf("TestParseV2TLV %s: Expected ClientCN to exist", name) } else if acn != ecn { t.Fatalf("TestParseV2TLV %s: Unexpected ClientCN expected %#v, actual %#v", name, ecn, acn) } esslVer := tlsVersion13 if asslVer, ok := ssl.SSLVersion(); !ok { t.Fatalf("TestParseV2TLV %s: Expected SSLVersion to exist", name) } else if asslVer != esslVer { t.Fatalf("TestParseV2TLV %s: Unexpected SSLVersion expected %#v, actual %#v", name, esslVer, asslVer) } if _, ok := ssl.SSLCipher(); ok { t.Fatalf("TestParseV2TLV %s: Unexpected SSLCipher", name) } if !ssl.Verified() { t.Fatalf("TestParseV2TLV %s: Expected Verified to be true", name) } }, }, { name: "SSL haproxy cipher", raw: []byte{ 0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 0x51, 0x55, 0x49, 0x54, 0x0a, 0x21, 0x21, 0x00, 0x4f, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x0a, 0x01, 0x5b, 0x0e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x0a, 0x01, 0x01, 0x9f, 0xf4, 0x7c, 0x01, 0xbb, 0x20, 0x00, 0x28, 0x01, 0x00, 0x00, 0x00, 0x00, 0x21, 0x00, 0x07, 0x54, 0x4c, 0x53, 0x76, 0x31, 0x2e, 0x33, 0x23, 0x00, 0x16, 0x54, 0x4c, 0x53, 0x5f, 0x41, 0x45, 0x53, 0x5f, 0x32, 0x35, 0x36, 0x5f, 0x47, 0x43, 0x4d, 0x5f, 0x53, 0x48, 0x41, 0x33, 0x38, 0x34, }, types: []proxyproto.PP2Type{proxyproto.PP2_TYPE_SSL}, valid: func(t *testing.T, name string, tlvs []proxyproto.TLV) { if !IsSSL(tlvs[0]) { t.Fatalf("TestParseV2TLV %s: Expected tlvs[0] to be the SSL type", name) } ssl, err := SSL(tlvs[0]) if err != nil { t.Fatalf("TestParseV2TLV %s: Unexpected error when parsing SSL %#v", name, err) } if !ssl.ClientSSL() { t.Fatalf("TestParseV2TLV %s: Expected ClientSSL() to be true", name) } if 
ssl.ClientCertConn() { t.Fatalf("TestParseV2TLV %s: Expected ClientCertConn() to be false", name) } if ssl.ClientCertSess() { t.Fatalf("TestParseV2TLV %s: Expected ClientCertSess() to be false", name) } if _, ok := ssl.ClientCN(); ok { t.Fatalf("TestParseV2TLV %s: Expected ClientCN to not exist", name) } esslVer := "TLSv1.3" if asslVer, ok := ssl.SSLVersion(); !ok { t.Fatalf("TestParseV2TLV %s: Expected SSLVersion to exist", name) } else if asslVer != esslVer { t.Fatalf("TestParseV2TLV %s: Unexpected SSLVersion expected %#v, actual %#v", name, esslVer, asslVer) } esslCipher := "TLS_AES_256_GCM_SHA384" if asslCipher, ok := ssl.SSLCipher(); !ok { t.Fatalf("TestParseV2TLV %s: Expected SSLCipher to exist", name) } else if asslCipher != esslCipher { t.Fatalf("TestParseV2TLV %s: Unexpected SSLCipher expected %#v, actual %#v", name, esslCipher, asslCipher) } }, }, } func TestParseV2TLV(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { tlvs := checkTLVs(t, tc.name, tc.raw, tc.types) tc.valid(t, tc.name, tlvs) }) } } func TestPP2SSLMarshal(t *testing.T) { ver := "TLSv1.3" cn := "example.org" pp2 := PP2SSL{ Client: PP2_BITFIELD_CLIENT_SSL, Verify: 0, TLV: []proxyproto.TLV{ { Type: proxyproto.PP2_SUBTYPE_SSL_VERSION, Value: []byte(ver), }, { Type: proxyproto.PP2_SUBTYPE_SSL_CN, Value: []byte(cn), }, }, } raw := []byte{0x1, 0x0, 0x0, 0x0, 0x0, 0x21, 0x0, 0x7, 0x54, 0x4c, 0x53, 0x76, 0x31, 0x2e, 0x33, 0x22, 0x0, 0xb, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x6f, 0x72, 0x67} want := proxyproto.TLV{ Type: proxyproto.PP2_TYPE_SSL, Value: raw, } tlv, err := pp2.Marshal() if err != nil { t.Fatalf("PP2SSL.Marshal() = %v", err) } if !reflect.DeepEqual(tlv, want) { t.Errorf("PP2SSL.Marshal() = %#v, want %#v", tlv, want) } } pires-go-proxyproto-04c9ad1/tlvparse/test.go000066400000000000000000000013371514137054000212630ustar00rootroot00000000000000package tlvparse import ( "bufio" "bytes" "testing" "github.com/pires/go-proxyproto" ) func 
checkTLVs(t *testing.T, name string, raw []byte, expected []proxyproto.PP2Type) []proxyproto.TLV { header, err := proxyproto.Read(bufio.NewReader(bytes.NewReader(raw))) if err != nil { t.Fatalf("%s: Unexpected error reading header %#v", name, err) } tlvs, err := header.TLVs() if err != nil { t.Fatalf("%s: Unexpected error splitting TLVS %#v", name, err) } if len(tlvs) != len(expected) { t.Fatalf("%s: Expected %d TLVs, actual %d", name, len(expected), len(tlvs)) } for i, et := range expected { if at := tlvs[i].Type; at != et { t.Fatalf("%s: Expected type %X, actual %X", name, et, at) } } return tlvs } pires-go-proxyproto-04c9ad1/v1.go000066400000000000000000000160411514137054000167700ustar00rootroot00000000000000package proxyproto import ( "bufio" "bytes" "fmt" "net" "net/netip" "strconv" "strings" ) const ( crlf = "\r\n" separator = " " ) func initVersion1() *Header { header := new(Header) header.Version = 1 // Command doesn't exist in v1 header.Command = PROXY return header } func parseVersion1(reader *bufio.Reader) (*Header, error) { //The header cannot be more than 107 bytes long. Per spec: // // (...) // - worst case (optional fields set to 0xff) : // "PROXY UNKNOWN ffff:f...f:ffff ffff:f...f:ffff 65535 65535\r\n" // => 5 + 1 + 7 + 1 + 39 + 1 + 39 + 1 + 5 + 1 + 5 + 2 = 107 chars // // So a 108-byte buffer is always enough to store all the line and a // trailing zero for string processing. // // It must also be CRLF terminated, as above. The header does not otherwise // contain a CR or LF byte. // // ISSUE #69 // We can't use Peek here as it will block trying to fill the buffer, which // will never happen if the header is TCP4 or TCP6 (max. 56 and 104 bytes // respectively) and the server is expected to speak first. // // Similarly, we can't use ReadString or ReadBytes as these will keep reading // until the delimiter is found; an abusive client could easily disrupt a // server by sending a large amount of data that do not contain a LF byte. 
// Another means of attack would be to start connections and simply not send // data after the initial PROXY signature bytes, accumulating a large // number of blocked goroutines on the server. ReadSlice will also block for // a delimiter when the internal buffer does not fill up. // // A plain Read is also problematic since we risk reading past the end of the // header without being able to easily put the excess bytes back into the reader's // buffer (with the current implementation's design). // // So we use a ReadByte loop, which solves the overflow problem and avoids // reading beyond the end of the header. However, we need one more trick to harden // against partial header attacks (slow loris) - per spec: // // (..) The sender must always ensure that the header is sent at once, so that // the transport layer maintains atomicity along the path to the receiver. The // receiver may be tolerant to partial headers or may simply drop the connection // when receiving a partial header. Recommendation is to be tolerant, but // implementation constraints may not always easily permit this. // // We are subject to such implementation constraints. So we return an error if // the header cannot be fully extracted with a single read of the underlying // reader. buf := make([]byte, 0, 107) for { b, err := reader.ReadByte() if err != nil { return nil, fmt.Errorf("%w: %w", ErrCantReadVersion1Header, err) } buf = append(buf, b) if b == '\n' { // End of header found break } if len(buf) == 107 { // No delimiter in first 107 bytes return nil, ErrVersion1HeaderTooLong } if reader.Buffered() == 0 { // Header was not buffered in a single read. Since we can't // differentiate between genuine slow writers and DoS agents, // we abort. On healthy networks, this should never happen. return nil, ErrCantReadVersion1Header } } // Check for CR before LF. if len(buf) < 2 || buf[len(buf)-2] != '\r' { return nil, ErrLineMustEndWithCrlf } // Check full signature. 
tokens := strings.Split(string(buf[:len(buf)-2]), separator) // Expect at least 2 tokens: "PROXY" and the transport protocol. if len(tokens) < 2 { return nil, ErrCantReadAddressFamilyAndProtocol } // Read address family and protocol var transportProtocol AddressFamilyAndProtocol switch tokens[1] { case "TCP4": transportProtocol = TCPv4 case "TCP6": transportProtocol = TCPv6 case "UNKNOWN": transportProtocol = UNSPEC // doesn't exist in v1 but fits UNKNOWN default: return nil, ErrCantReadAddressFamilyAndProtocol } // Expect 6 tokens only when UNKNOWN is not present. if transportProtocol != UNSPEC && len(tokens) < 6 { return nil, ErrCantReadAddressFamilyAndProtocol } // When a signature is found, allocate a v1 header with Command set to PROXY. // Command doesn't exist in v1 but set it for other parts of this library // to rely on it for determining connection details. header := initVersion1() // Transport protocol has been processed already. header.TransportProtocol = transportProtocol // When UNKNOWN, set the command to LOCAL and return early if header.TransportProtocol == UNSPEC { header.Command = LOCAL return header, nil } // Otherwise, continue to read addresses and ports sourceIP, err := parseV1IPAddress(header.TransportProtocol, tokens[2]) if err != nil { return nil, err } destIP, err := parseV1IPAddress(header.TransportProtocol, tokens[3]) if err != nil { return nil, err } sourcePort, err := parseV1PortNumber(tokens[4]) if err != nil { return nil, err } destPort, err := parseV1PortNumber(tokens[5]) if err != nil { return nil, err } header.SourceAddr = &net.TCPAddr{ IP: sourceIP, Port: sourcePort, } header.DestinationAddr = &net.TCPAddr{ IP: destIP, Port: destPort, } return header, nil } func (header *Header) formatVersion1() ([]byte, error) { // As of version 1, only "TCP4" ( \x54 \x43 \x50 \x34 ) for TCP over IPv4, // and "TCP6" ( \x54 \x43 \x50 \x36 ) for TCP over IPv6 are allowed. 
var proto string switch header.TransportProtocol { case TCPv4: proto = "TCP4" case TCPv6: proto = "TCP6" default: // Unknown connection (short form) return []byte("PROXY UNKNOWN" + crlf), nil } sourceAddr, sourceOK := header.SourceAddr.(*net.TCPAddr) destAddr, destOK := header.DestinationAddr.(*net.TCPAddr) if !sourceOK || !destOK { return nil, ErrInvalidAddress } sourceIP, destIP := sourceAddr.IP, destAddr.IP switch header.TransportProtocol { case TCPv4: sourceIP = sourceIP.To4() destIP = destIP.To4() case TCPv6: sourceIP = sourceIP.To16() destIP = destIP.To16() } if sourceIP == nil || destIP == nil { return nil, ErrInvalidAddress } buf := bytes.NewBuffer(make([]byte, 0, 108)) buf.Write(SIGV1) buf.WriteString(separator) buf.WriteString(proto) buf.WriteString(separator) buf.WriteString(sourceIP.String()) buf.WriteString(separator) buf.WriteString(destIP.String()) buf.WriteString(separator) buf.WriteString(strconv.Itoa(sourceAddr.Port)) buf.WriteString(separator) buf.WriteString(strconv.Itoa(destAddr.Port)) buf.WriteString(crlf) return buf.Bytes(), nil } func parseV1PortNumber(portStr string) (int, error) { port, err := strconv.Atoi(portStr) if err != nil { return 0, fmt.Errorf("%w: %w", ErrInvalidPortNumber, err) } if port < 0 || port > 65535 { return 0, ErrInvalidPortNumber } return port, nil } func parseV1IPAddress(protocol AddressFamilyAndProtocol, addrStr string) (net.IP, error) { addr, err := netip.ParseAddr(addrStr) if err != nil { return nil, fmt.Errorf("%w: %w", ErrInvalidAddress, err) } switch protocol { case TCPv4: if addr.Is4() { return net.IP(addr.AsSlice()), nil } case TCPv6: if addr.Is6() || addr.Is4In6() { return net.IP(addr.AsSlice()), nil } } return nil, ErrInvalidAddress } pires-go-proxyproto-04c9ad1/v1_test.go000066400000000000000000000237011514137054000200300ustar00rootroot00000000000000package proxyproto import ( "bufio" "bytes" "errors" "fmt" "io" "net" "strconv" "strings" "testing" "time" ) var ( IPv4AddressesAndPorts = 
strings.Join([]string{testLocalhostIP4Addr, testLocalhostIP4Addr, strconv.Itoa(testValidPort), strconv.Itoa(testValidPort)}, separator) IPv4In6AddressesAndPorts = strings.Join([]string{testLocalhostIP4In6Addr, testLocalhostIP4In6Addr, strconv.Itoa(testValidPort), strconv.Itoa(testValidPort)}, separator) IPv4AddressesAndInvalidPorts = strings.Join([]string{testLocalhostIP4Addr, testLocalhostIP4Addr, strconv.Itoa(testInvalidPort), strconv.Itoa(testInvalidPort)}, separator) IPv6AddressesAndPorts = strings.Join([]string{testLocalhostIP6Addr, testLocalhostIP6Addr, strconv.Itoa(testValidPort), strconv.Itoa(testValidPort)}, separator) IPv6LongAddressesAndPorts = strings.Join([]string{testIP6LongAddr, testIP6LongAddr, strconv.Itoa(testValidPort), strconv.Itoa(testValidPort)}, separator) fixtureTCP4V1 = "PROXY TCP4 " + IPv4AddressesAndPorts + crlf + "GET /" fixtureTCP6V1 = "PROXY TCP6 " + IPv6AddressesAndPorts + crlf + "GET /" fixtureTCP4IN6V1 = "PROXY TCP6 " + IPv4In6AddressesAndPorts + crlf + "GET /" fixtureTCP6V1Overflow = "PROXY TCP6 " + IPv6LongAddressesAndPorts fixtureUnknown = "PROXY UNKNOWN" + crlf fixtureUnknownWithAddresses = "PROXY UNKNOWN " + IPv4AddressesAndInvalidPorts + crlf ) var invalidParseV1Tests = []struct { desc string reader *bufio.Reader expectedError error }{ { desc: "no signature", reader: newBufioReader([]byte(testNoProtocol)), expectedError: ErrNoProxyProtocol, }, { desc: "prox", reader: newBufioReader([]byte("PROX")), expectedError: ErrNoProxyProtocol, }, { desc: "proxy lf", reader: newBufioReader([]byte("PROXY \n")), expectedError: ErrLineMustEndWithCrlf, }, { desc: "proxy crlf", reader: newBufioReader([]byte("PROXY " + crlf)), expectedError: ErrCantReadAddressFamilyAndProtocol, }, { desc: "proxy no space crlf", reader: newBufioReader([]byte("PROXY" + crlf)), expectedError: ErrCantReadAddressFamilyAndProtocol, }, { desc: "proxy something crlf", reader: newBufioReader([]byte("PROXY SOMETHING" + crlf)), expectedError: 
ErrCantReadAddressFamilyAndProtocol, }, { desc: "incomplete signature TCP4", reader: newBufioReader([]byte("PROXY TCP4 " + IPv4AddressesAndPorts)), expectedError: ErrCantReadVersion1Header, }, { desc: "invalid IP address", reader: newBufioReader([]byte("PROXY TCP4 invalid invalid 65533 65533" + crlf)), expectedError: ErrInvalidAddress, }, { desc: "TCP6 with IPv4 addresses", reader: newBufioReader([]byte("PROXY TCP6 " + IPv4AddressesAndPorts + crlf)), expectedError: ErrInvalidAddress, }, { desc: "TCP4 with IPv6 addresses", reader: newBufioReader([]byte("PROXY TCP4 " + IPv6AddressesAndPorts + crlf)), expectedError: ErrInvalidAddress, }, { desc: "TCP4 with IPv4 mapped addresses", reader: newBufioReader([]byte("PROXY TCP4 " + IPv4In6AddressesAndPorts + crlf)), expectedError: ErrInvalidAddress, }, { desc: "TCP4 with invalid port", reader: newBufioReader([]byte("PROXY TCP4 " + IPv4AddressesAndInvalidPorts + crlf)), expectedError: ErrInvalidPortNumber, }, { desc: "header too long", reader: newBufioReader([]byte("PROXY UNKNOWN " + IPv6LongAddressesAndPorts + " " + crlf)), expectedError: ErrVersion1HeaderTooLong, }, } func TestReadV1Invalid(t *testing.T) { for _, tt := range invalidParseV1Tests { t.Run(tt.desc, func(t *testing.T) { if _, err := Read(tt.reader); !errors.Is(err, tt.expectedError) { t.Fatalf("expected %s, actual %v", tt.expectedError, err) } }) } } var validParseAndWriteV1Tests = []struct { desc string reader *bufio.Reader expectedHeader *Header skipWrite bool }{ { desc: "TCP4", reader: bufio.NewReader(strings.NewReader(fixtureTCP4V1)), expectedHeader: &Header{ Version: 1, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: v4addr, DestinationAddr: v4addr, }, }, { desc: "TCP6", reader: bufio.NewReader(strings.NewReader(fixtureTCP6V1)), expectedHeader: &Header{ Version: 1, Command: PROXY, TransportProtocol: TCPv6, SourceAddr: v6addr, DestinationAddr: v6addr, }, }, { desc: "TCP4IN6", reader: bufio.NewReader(strings.NewReader(fixtureTCP4IN6V1)), expectedHeader: 
&Header{ Version: 1, Command: PROXY, TransportProtocol: TCPv6, SourceAddr: v4addr, DestinationAddr: v4addr, }, // we skip write test because net.ParseIP converts ::ffff:127.0.0.1 to v4 // instead of preserving the v4 in v6 form, so, after serializing the header, // we end up with v6 protocol and a v4 IP which is invalid skipWrite: true, }, { desc: "unknown", reader: bufio.NewReader(strings.NewReader(fixtureUnknown)), expectedHeader: &Header{ Version: 1, Command: LOCAL, TransportProtocol: UNSPEC, SourceAddr: nil, DestinationAddr: nil, }, }, { desc: "unknown with addresses and ports", reader: bufio.NewReader(strings.NewReader(fixtureUnknownWithAddresses)), expectedHeader: &Header{ Version: 1, Command: LOCAL, TransportProtocol: UNSPEC, SourceAddr: nil, DestinationAddr: nil, }, }, } func TestParseV1Valid(t *testing.T) { for _, tt := range validParseAndWriteV1Tests { t.Run(tt.desc, func(t *testing.T) { header, err := Read(tt.reader) if err != nil { t.Fatal("unexpected error", err.Error()) } if !header.EqualsTo(tt.expectedHeader) { t.Fatalf("expected %#v, actual %#v", tt.expectedHeader, header) } }) } } func TestWriteV1Valid(t *testing.T) { for _, tt := range validParseAndWriteV1Tests { if tt.skipWrite { continue } t.Run(tt.desc, func(t *testing.T) { var b bytes.Buffer w := bufio.NewWriter(&b) if _, err := tt.expectedHeader.WriteTo(w); err != nil { t.Fatal("unexpected error ", err) } if err := w.Flush(); err != nil { t.Fatal("unexpected error ", err) } // Read written bytes to validate written header r := bufio.NewReader(&b) newHeader, err := Read(r) if err != nil { t.Fatal("unexpected error ", err) } if !newHeader.EqualsTo(tt.expectedHeader) { t.Fatalf("expected %#v, actual %#v", tt.expectedHeader, newHeader) } }) } } // Tests for parseVersion1 overflow - issue #69. 
type dataSource struct { NBytes int NRead int } func (ds *dataSource) Read(b []byte) (int, error) { if ds.NRead >= ds.NBytes { return 0, io.EOF } avail := ds.NBytes - ds.NRead avail = min(avail, len(b)) for i := 0; i < avail; i++ { b[i] = 0x20 } ds.NRead += avail return avail, nil } func TestParseVersion1Overflow(t *testing.T) { ds := &dataSource{} reader := bufio.NewReader(ds) bufSize := reader.Size() ds.NBytes = bufSize * 16 _, _ = parseVersion1(reader) if ds.NRead > bufSize { t.Fatalf("read: expected max %d bytes, actual %d\n", bufSize, ds.NRead) } } func listen(t *testing.T) *Listener { l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("listen: %v", err) } return &Listener{Listener: l} } func client(t *testing.T, addr, header string, length int, terminate bool, wait time.Duration, done chan struct{}, result chan error, ) { c, err := net.Dial("tcp", addr) if err != nil { result <- fmt.Errorf("dial: %w", err) return } t.Cleanup(func() { if err := c.Close(); err != nil { t.Errorf("failed to close connection: %v", err) } }) if terminate && length < 2 { length = 2 } buf := make([]byte, len(header)+length) copy(buf, []byte(header)) for i := 0; i < length-2; i++ { buf[i+len(header)] = 0x20 } if terminate { copy(buf[len(header)+length-2:], []byte(crlf)) } n, err := c.Write(buf) if err != nil { result <- fmt.Errorf("write: %w", err) return } if n != len(buf) { result <- errors.New("write; short write") return } close(result) time.Sleep(wait) close(done) } func TestVersion1Overflow(t *testing.T) { done := make(chan struct{}) cliResult := make(chan error) l := listen(t) go client(t, l.Addr().String(), fixtureTCP6V1Overflow, 10240, true, 10*time.Second, done, cliResult) c, err := l.Accept() if err != nil { t.Fatalf("accept: %v", err) } b := []byte{} _, err = c.Read(b) if err == nil { t.Fatalf("net.Conn: no error reported for oversized header") } err = <-cliResult if err != nil { t.Fatalf("client error: %v", err) } } func TestVersion1SlowLoris(t 
*testing.T) { done := make(chan struct{}) cliResult := make(chan error) timeout := make(chan error) l := listen(t) go client(t, l.Addr().String(), fixtureTCP6V1Overflow, 0, false, 10*time.Second, done, cliResult) c, err := l.Accept() if err != nil { t.Fatalf("accept: %v", err) } go func() { b := []byte{} _, err = c.Read(b) timeout <- err }() select { case <-done: t.Fatalf("net.Conn: reader still blocked after 10 seconds") case err := <-timeout: if err == nil { t.Fatalf("net.Conn: no error reported for incomplete header") } } err = <-cliResult if err != nil { t.Fatalf("client error: %v", err) } } func TestVersion1SlowLorisOverflow(t *testing.T) { done := make(chan struct{}) cliResult := make(chan error) timeout := make(chan error) l := listen(t) go client(t, l.Addr().String(), fixtureTCP6V1Overflow, 10240, false, 10*time.Second, done, cliResult) c, err := l.Accept() if err != nil { t.Fatalf("accept: %v", err) } go func() { b := []byte{} _, err = c.Read(b) timeout <- err }() select { case <-done: t.Fatalf("net.Conn: reader still blocked after 10 seconds") case err := <-timeout: if err == nil { t.Fatalf("net.Conn: no error reported for incomplete and overflowed header") } } err = <-cliResult if err != nil { t.Fatalf("client error: %v", err) } } pires-go-proxyproto-04c9ad1/v2.go000066400000000000000000000202721514137054000167720ustar00rootroot00000000000000package proxyproto import ( "bufio" "bytes" "encoding/binary" "errors" "fmt" "io" "math" "net" ) // maxV2HeaderSize is the maximum acceptable size of a V2 header. // // A V2 header may be at most 16 bytes + 64KiB large. We enforce a lower limit // to mitigate memory allocation DoS while allowing real-world legitimate // headers. PP2_SUBTYPE_SSL_CLIENT_CERT is typically between 1 and 2KiB, so we // use a 4KiB limit to leave some room for other TLVs. 
const maxV2HeaderSize = 4096 var ( lengthUnspec = uint16(0) lengthV4 = uint16(12) lengthV6 = uint16(36) lengthUnix = uint16(216) lengthUnspecBytes = func() []byte { a := make([]byte, 2) binary.BigEndian.PutUint16(a, lengthUnspec) return a }() lengthV4Bytes = func() []byte { a := make([]byte, 2) binary.BigEndian.PutUint16(a, lengthV4) return a }() lengthV6Bytes = func() []byte { a := make([]byte, 2) binary.BigEndian.PutUint16(a, lengthV6) return a }() lengthUnixBytes = func() []byte { a := make([]byte, 2) binary.BigEndian.PutUint16(a, lengthUnix) return a }() errUint16Overflow = errors.New("proxyproto: uint16 overflow") ) type _ports struct { SrcPort uint16 DstPort uint16 } type _addr4 struct { Src [4]byte Dst [4]byte SrcPort uint16 DstPort uint16 } type _addr6 struct { Src [16]byte Dst [16]byte _ports } type _addrUnix struct { Src [108]byte Dst [108]byte } func parseVersion2(reader *bufio.Reader) (header *Header, err error) { // Skip first 12 bytes (signature) for range 12 { if _, err = reader.ReadByte(); err != nil { return nil, fmt.Errorf("%w: %w", ErrCantReadProtocolVersionAndCommand, err) } } header = new(Header) header.Version = 2 // Read the 13th byte, protocol version and command b13, err := reader.ReadByte() if err != nil { return nil, fmt.Errorf("%w: %w", ErrCantReadProtocolVersionAndCommand, err) } header.Command = ProtocolVersionAndCommand(b13) if _, ok := supportedCommand[header.Command]; !ok { return nil, ErrUnsupportedProtocolVersionAndCommand } // Read the 14th byte, address family and protocol b14, err := reader.ReadByte() if err != nil { return nil, fmt.Errorf("%w: %w", ErrCantReadAddressFamilyAndProtocol, err) } header.TransportProtocol = AddressFamilyAndProtocol(b14) // UNSPEC is only supported when LOCAL is set. 
if header.TransportProtocol == UNSPEC && header.Command != LOCAL { return nil, ErrUnsupportedAddressFamilyAndProtocol } // Make sure there are bytes available as specified in length var length uint16 if err := binary.Read(reader, binary.BigEndian, &length); err != nil { return nil, fmt.Errorf("%w: %w", ErrCantReadLength, err) } if !header.validateLength(length) { return nil, ErrInvalidLength } // Return early if the length is zero, which means that // there's no address information and TLVs present for UNSPEC. if length == 0 { return header, nil } if length > maxV2HeaderSize { return nil, ErrInvalidLength } // Length-limited reader for payload section payloadReader := io.LimitReader(reader, int64(length)).(*io.LimitedReader) // Read addresses and ports for protocols other than UNSPEC. // Ignore address information for UNSPEC, and skip straight to read TLVs, // since the length is greater than zero. if header.TransportProtocol != UNSPEC { if header.TransportProtocol.IsIPv4() { var addr _addr4 if err := binary.Read(payloadReader, binary.BigEndian, &addr); err != nil { return nil, fmt.Errorf("%w: %w", ErrInvalidAddress, err) } header.SourceAddr = newIPAddr(header.TransportProtocol, addr.Src[:], addr.SrcPort) header.DestinationAddr = newIPAddr(header.TransportProtocol, addr.Dst[:], addr.DstPort) } else if header.TransportProtocol.IsIPv6() { var addr _addr6 if err := binary.Read(payloadReader, binary.BigEndian, &addr); err != nil { return nil, fmt.Errorf("%w: %w", ErrInvalidAddress, err) } header.SourceAddr = newIPAddr(header.TransportProtocol, addr.Src[:], addr.SrcPort) header.DestinationAddr = newIPAddr(header.TransportProtocol, addr.Dst[:], addr.DstPort) } else if header.TransportProtocol.IsUnix() { var addr _addrUnix if err := binary.Read(payloadReader, binary.BigEndian, &addr); err != nil { return nil, fmt.Errorf("%w: %w", ErrInvalidAddress, err) } network := "unix" if header.TransportProtocol.IsDatagram() { network = "unixgram" } header.SourceAddr = &net.UnixAddr{ 
Net: network, Name: parseUnixName(addr.Src[:]), } header.DestinationAddr = &net.UnixAddr{ Net: network, Name: parseUnixName(addr.Dst[:]), } } } // Copy bytes for optional Type-Length-Value vector header.rawTLVs = make([]byte, payloadReader.N) // Allocate minimum size slice if _, err = io.ReadFull(payloadReader, header.rawTLVs); err != nil && err != io.EOF { return nil, err } if payloadReader.N != 0 { return nil, ErrInvalidLength } return header, nil } func (header *Header) formatVersion2() ([]byte, error) { var buf bytes.Buffer buf.Write(SIGV2) buf.WriteByte(header.Command.toByte()) buf.WriteByte(header.TransportProtocol.toByte()) if header.TransportProtocol.IsUnspec() { // For UNSPEC, write no addresses and ports but only TLVs if they are present hdrLen, err := addTLVLen(lengthUnspecBytes, len(header.rawTLVs)) if err != nil { return nil, err } buf.Write(hdrLen) } else { var addrSrc, addrDst []byte if header.TransportProtocol.IsIPv4() { hdrLen, err := addTLVLen(lengthV4Bytes, len(header.rawTLVs)) if err != nil { return nil, err } buf.Write(hdrLen) sourceIP, destIP, _ := header.IPs() addrSrc = sourceIP.To4() addrDst = destIP.To4() } else if header.TransportProtocol.IsIPv6() { hdrLen, err := addTLVLen(lengthV6Bytes, len(header.rawTLVs)) if err != nil { return nil, err } buf.Write(hdrLen) sourceIP, destIP, _ := header.IPs() addrSrc = sourceIP.To16() addrDst = destIP.To16() } else if header.TransportProtocol.IsUnix() { buf.Write(lengthUnixBytes) sourceAddr, destAddr, ok := header.UnixAddrs() if !ok { return nil, ErrInvalidAddress } addrSrc = formatUnixName(sourceAddr.Name) addrDst = formatUnixName(destAddr.Name) } if addrSrc == nil || addrDst == nil { return nil, ErrInvalidAddress } buf.Write(addrSrc) buf.Write(addrDst) if sourcePort, destPort, ok := header.Ports(); ok { if sourcePort < 0 || sourcePort > math.MaxUint16 || destPort < 0 || destPort > math.MaxUint16 { return nil, ErrInvalidPortNumber } portBytes := make([]byte, 2) //nolint:gosec // Bounds are checked 
above. binary.BigEndian.PutUint16(portBytes, uint16(sourcePort)) buf.Write(portBytes) //nolint:gosec // Bounds are checked above. binary.BigEndian.PutUint16(portBytes, uint16(destPort)) buf.Write(portBytes) } } if len(header.rawTLVs) > 0 { buf.Write(header.rawTLVs) } return buf.Bytes(), nil } func (header *Header) validateLength(length uint16) bool { if header.TransportProtocol.IsIPv4() { return length >= lengthV4 } else if header.TransportProtocol.IsIPv6() { return length >= lengthV6 } else if header.TransportProtocol.IsUnix() { return length >= lengthUnix } else if header.TransportProtocol.IsUnspec() { return length >= lengthUnspec } return false } // addTLVLen adds the length of the TLV to the header length or errors on uint16 overflow. func addTLVLen(cur []byte, tlvLen int) ([]byte, error) { if tlvLen == 0 { return cur, nil } curLen := binary.BigEndian.Uint16(cur) newLen := int(curLen) + tlvLen if newLen >= 1<<16 { return nil, errUint16Overflow } a := make([]byte, 2) //nolint:gosec // newLen bounds are validated above. binary.BigEndian.PutUint16(a, uint16(newLen)) return a, nil } func newIPAddr(transport AddressFamilyAndProtocol, ip net.IP, port uint16) net.Addr { if transport.IsStream() { return &net.TCPAddr{IP: ip, Port: int(port)} } if transport.IsDatagram() { return &net.UDPAddr{IP: ip, Port: int(port)} } return nil } func parseUnixName(b []byte) string { before, _, ok := bytes.Cut(b, []byte{0}) if !ok { return string(b) } return string(before) } func formatUnixName(name string) []byte { n := int(lengthUnix) / 2 if len(name) >= n { return []byte(name[:n]) } pad := make([]byte, n-len(name)) return append([]byte(name), pad...) } pires-go-proxyproto-04c9ad1/v2_test.go000066400000000000000000000424701514137054000200350ustar00rootroot00000000000000package proxyproto import ( "bufio" "bytes" iorand "crypto/rand" "encoding/binary" "errors" "reflect" "strings" "testing" ) var ( invalidRune = byte('\x99') // Lengths to use in tests. 
lengthPadded = uint16(84) lengthEmptyBytes = func() []byte { a := make([]byte, 2) binary.BigEndian.PutUint16(a, 0) return a }() lengthPaddedBytes = func() []byte { a := make([]byte, 2) binary.BigEndian.PutUint16(a, lengthPadded) return a }() // If life gives you lemons, make mojitos. portBytes = func() []byte { a := make([]byte, 2) binary.BigEndian.PutUint16(a, testValidPort) return a }() unixBytes = pad([]byte("socket"), 108) // Tests don't care if source and destination addresses and ports are the same. addressesIPv4 = append(v4ip.To4(), v4ip.To4()...) addressesIPv6 = append(v6ip.To16(), v6ip.To16()...) ports = append(portBytes, portBytes...) // Fixtures to use in tests. fixtureIPv4Address = append(addressesIPv4, ports...) fixtureIPv4V2 = append(lengthV4Bytes, fixtureIPv4Address...) fixtureIPv4V2Padded = append(append(lengthPaddedBytes, fixtureIPv4Address...), make([]byte, lengthPadded-lengthV4)...) fixtureIPv6Address = append(addressesIPv6, ports...) fixtureIPv6V2 = append(lengthV6Bytes, fixtureIPv6Address...) fixtureIPv6V2Padded = append(append(lengthPaddedBytes, fixtureIPv6Address...), make([]byte, lengthPadded-lengthV6)...) fixtureUnixAddress = append(unixBytes, unixBytes...) fixtureUnixV2 = append(lengthUnixBytes, fixtureUnixAddress...) fixtureTLV = func() []byte { tlv := make([]byte, 100) _, _ = iorand.Read(tlv) return tlv }() fixtureIPv4V2TLV = fixtureWithTLV(lengthV4Bytes, fixtureIPv4Address, fixtureTLV) fixtureIPv6V2TLV = fixtureWithTLV(lengthV6Bytes, fixtureIPv6Address, fixtureTLV) fixtureUnspecTLV = fixtureWithTLV(lengthUnspecBytes, []byte{}, fixtureTLV) fixtureMediumTLV = make([]byte, 2048) fixtureV2MediumTLV = fixtureWithTLV(lengthV4Bytes, fixtureIPv4Address, fixtureMediumTLV) fixtureTooLargeTLV = make([]byte, 10*1024) fixtureV2TooLargeTLV = fixtureWithTLV(lengthV4Bytes, fixtureIPv4Address, fixtureTooLargeTLV) // Arbitrary bytes following proxy bytes. 
arbitraryTailBytes = []byte{'\x99', '\x97', '\x98'} ) func pad(b []byte, n int) []byte { padding := make([]byte, n-len(b)) return append(b, padding...) } var invalidParseV2Tests = []struct { desc string reader *bufio.Reader expectedError error }{ { desc: "no signature", reader: newBufioReader([]byte(testNoProtocol)), expectedError: ErrNoProxyProtocol, }, { desc: "truncated v2 signature", reader: newBufioReader(SIGV2[2:]), expectedError: ErrNoProxyProtocol, }, { desc: "v2 signature and nothing else", reader: newBufioReader(SIGV2), expectedError: ErrCantReadProtocolVersionAndCommand, }, { desc: "v2 signature with invalid command", reader: newBufioReader(append(SIGV2, invalidRune)), expectedError: ErrUnsupportedProtocolVersionAndCommand, }, { desc: "v2 signature with command but nothing else", reader: newBufioReader(append(SIGV2, byte(PROXY))), expectedError: ErrCantReadAddressFamilyAndProtocol, }, { desc: "command proxy but inet family unspec", reader: newBufioReader(append(SIGV2, byte(PROXY), byte(UNSPEC))), expectedError: ErrUnsupportedAddressFamilyAndProtocol, }, { desc: "v2 signature with command and invalid inet family", // translated to UNSPEC reader: newBufioReader(append(SIGV2, byte(PROXY), invalidRune)), expectedError: ErrCantReadLength, }, { desc: "TCPv4 but no length", reader: newBufioReader(append(SIGV2, byte(PROXY), byte(TCPv4))), expectedError: ErrCantReadLength, }, { desc: "TCPv4 but invalid length", reader: newBufioReader(append(SIGV2, byte(PROXY), byte(TCPv4), invalidRune)), expectedError: ErrCantReadLength, }, { desc: "unspec but no length", reader: newBufioReader(append(SIGV2, byte(LOCAL), byte(UNSPEC))), expectedError: ErrCantReadLength, }, { desc: "TCPv4 with mismatching length", reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(TCPv4)), lengthV4Bytes...)), expectedError: ErrInvalidAddress, }, { desc: "TCPv6 with mismatching length", reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(TCPv6)), lengthV6Bytes...)), 
expectedError: ErrInvalidAddress, }, { desc: "TCPv4 length zero but with address and ports", reader: newBufioReader(append(append(append(SIGV2, byte(PROXY), byte(TCPv4)), lengthEmptyBytes...), fixtureIPv6Address...)), expectedError: ErrInvalidLength, }, { desc: "TCPv6 with IPv6 length but IPv4 address and ports", reader: newBufioReader(append(append(append(SIGV2, byte(PROXY), byte(TCPv6)), lengthV6Bytes...), fixtureIPv4Address...)), expectedError: ErrInvalidAddress, }, { desc: "unspec length greater than zero but no TLVs", reader: newBufioReader(append(append(SIGV2, byte(LOCAL), byte(UNSPEC)), fixtureUnspecTLV[:2]...)), expectedError: ErrInvalidLength, }, { desc: "TCPv4 with too large TLV", reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(TCPv4)), fixtureV2TooLargeTLV...)), expectedError: ErrInvalidLength, }, } func TestParseV2Invalid(t *testing.T) { for _, tt := range invalidParseV2Tests { t.Run(tt.desc, func(t *testing.T) { if _, err := Read(tt.reader); !errors.Is(err, tt.expectedError) { t.Fatalf("expected %v, actual %v", tt.expectedError, err) } }) } } var validParseAndWriteV2Tests = []struct { desc string reader *bufio.Reader expectedHeader *Header }{ { desc: "local", reader: newBufioReader(append(append(SIGV2, byte(LOCAL), byte(TCPv4)), fixtureIPv4V2...)), expectedHeader: &Header{ Version: 2, Command: LOCAL, TransportProtocol: TCPv4, SourceAddr: v4addr, DestinationAddr: v4addr, }, }, { desc: "local unspec", reader: newBufioReader(append(append(SIGV2, byte(LOCAL), byte(UNSPEC)), lengthUnspecBytes...)), expectedHeader: &Header{ Version: 2, Command: LOCAL, TransportProtocol: UNSPEC, SourceAddr: nil, DestinationAddr: nil, }, }, { desc: "proxy TCPv4", reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(TCPv4)), fixtureIPv4V2...)), expectedHeader: &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: v4addr, DestinationAddr: v4addr, }, }, { desc: "proxy TCPv6", reader: newBufioReader(append(append(SIGV2, byte(PROXY), 
byte(TCPv6)), fixtureIPv6V2...)), expectedHeader: &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv6, SourceAddr: v6addr, DestinationAddr: v6addr, }, }, { desc: "proxy TCPv4 with TLV", reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(TCPv4)), fixtureIPv4V2TLV...)), expectedHeader: &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: v4addr, DestinationAddr: v4addr, rawTLVs: fixtureTLV, }, }, { desc: "proxy TCPv6 with TLV", reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(TCPv6)), fixtureIPv6V2TLV...)), expectedHeader: &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv6, SourceAddr: v6addr, DestinationAddr: v6addr, rawTLVs: fixtureTLV, }, }, { desc: "local unspec with TLV", reader: newBufioReader(append(append(SIGV2, byte(LOCAL), byte(UNSPEC)), fixtureUnspecTLV...)), expectedHeader: &Header{ Version: 2, Command: LOCAL, TransportProtocol: UNSPEC, SourceAddr: nil, DestinationAddr: nil, rawTLVs: fixtureTLV, }, }, { desc: "proxy UDPv4", reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(UDPv4)), fixtureIPv4V2...)), expectedHeader: &Header{ Version: 2, Command: PROXY, TransportProtocol: UDPv4, SourceAddr: v4UDPAddr, DestinationAddr: v4UDPAddr, }, }, { desc: "proxy UDPv6", reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(UDPv6)), fixtureIPv6V2...)), expectedHeader: &Header{ Version: 2, Command: PROXY, TransportProtocol: UDPv6, SourceAddr: v6UDPAddr, DestinationAddr: v6UDPAddr, }, }, { desc: "proxy unix stream", reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(UnixStream)), fixtureUnixV2...)), expectedHeader: &Header{ Version: 2, Command: PROXY, TransportProtocol: UnixStream, SourceAddr: unixStreamAddr, DestinationAddr: unixStreamAddr, }, }, { desc: "proxy unix datagram", reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(UnixDatagram)), fixtureUnixV2...)), expectedHeader: &Header{ Version: 2, Command: PROXY, TransportProtocol: UnixDatagram, SourceAddr: 
unixDatagramAddr, DestinationAddr: unixDatagramAddr, }, }, { desc: "proxy TCPv4 with medium TLV", reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(TCPv4)), fixtureV2MediumTLV...)), expectedHeader: &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: v4addr, DestinationAddr: v4addr, rawTLVs: fixtureMediumTLV, }, }, } func TestParseV2Valid(t *testing.T) { for _, tt := range validParseAndWriteV2Tests { t.Run(tt.desc, func(t *testing.T) { header, err := Read(tt.reader) if err != nil { t.Fatal("unexpected error", err.Error()) } if !header.EqualsTo(tt.expectedHeader) { t.Fatalf("expected %#v, actual %#v", tt.expectedHeader, header) } }) } } func TestWriteV2Valid(t *testing.T) { for _, tt := range validParseAndWriteV2Tests { t.Run(tt.desc, func(t *testing.T) { var b bytes.Buffer w := bufio.NewWriter(&b) if _, err := tt.expectedHeader.WriteTo(w); err != nil { t.Fatal("unexpected error ", err) } if err := w.Flush(); err != nil { t.Fatal("unexpected error ", err) } // Read written bytes to validate written header r := bufio.NewReaderSize(&b, readBufferSize) newHeader, err := Read(r) if err != nil { t.Fatal("unexpected error ", err) } if !newHeader.EqualsTo(tt.expectedHeader) { t.Fatalf("expected %#v, actual %#v", tt.expectedHeader, newHeader) } }) } } var validParseV2PaddedTests = []struct { desc string value []byte expectedHeader *Header }{ { desc: "proxy TCPv4", value: append(append(SIGV2, byte(PROXY), byte(TCPv4)), fixtureIPv4V2Padded...), expectedHeader: &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: v4addr, DestinationAddr: v4addr, rawTLVs: make([]byte, lengthPadded-lengthV4), }, }, { desc: "proxy TCPv6", value: append(append(SIGV2, byte(PROXY), byte(TCPv6)), fixtureIPv6V2Padded...), expectedHeader: &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv6, SourceAddr: v6addr, DestinationAddr: v6addr, rawTLVs: make([]byte, lengthPadded-lengthV6), }, }, { desc: "proxy UDPv4", value: append(append(SIGV2, 
byte(PROXY), byte(UDPv4)), fixtureIPv4V2Padded...), expectedHeader: &Header{ Version: 2, Command: PROXY, TransportProtocol: UDPv4, SourceAddr: v4addr, DestinationAddr: v4addr, rawTLVs: make([]byte, lengthPadded-lengthV4), }, }, { desc: "proxy UDPv6", value: append(append(SIGV2, byte(PROXY), byte(UDPv6)), fixtureIPv6V2Padded...), expectedHeader: &Header{ Version: 2, Command: PROXY, TransportProtocol: UDPv6, SourceAddr: v6addr, DestinationAddr: v6addr, rawTLVs: make([]byte, lengthPadded-lengthV6), }, }, } func TestParseV2Padded(t *testing.T) { for _, tt := range validParseV2PaddedTests { t.Run(tt.desc, func(t *testing.T) { reader := newBufioReader(append(tt.value, arbitraryTailBytes...)) newHeader, err := Read(reader) if err != nil { t.Fatal("unexpected error ", err) } if !newHeader.EqualsTo(tt.expectedHeader) { t.Fatalf("expected %#v, actual %#v", tt.expectedHeader, newHeader) } // Check that remaining padding bytes have been flushed nextBytes, err := reader.Peek(len(arbitraryTailBytes)) if err != nil { t.Fatal("unexpected error ", err) } if !reflect.DeepEqual(nextBytes, arbitraryTailBytes) { t.Fatalf("expected %#v, actual %#v", arbitraryTailBytes, nextBytes) } }) } } func TestV2EqualsToTLV(t *testing.T) { eHdr := &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: v4addr, DestinationAddr: v4addr, } hdr, err := Read(newBufioReader(append(append(SIGV2, byte(PROXY), byte(TCPv4)), fixtureIPv4V2TLV...))) if err != nil { t.Fatal("unexpected error ", err) } if eHdr.EqualsTo(hdr) { t.Fatalf("unexpectedly equal created: %#v, parsed: %#v", eHdr, hdr) } eHdr.rawTLVs = fixtureTLV[:] if !eHdr.EqualsTo(hdr) { t.Fatalf("unexpectedly unequal after tlv copy created: %#v, parsed: %#v", eHdr, hdr) } eHdr.rawTLVs[0] = eHdr.rawTLVs[0] + 1 if eHdr.EqualsTo(hdr) { t.Fatalf("unexpectedly equal after changing tlv created: %#v, parsed: %#v", eHdr, hdr) } } var tlvFormatTests = []struct { desc string header *Header }{ { desc: "proxy TCPv4", header: &Header{ Version: 2, 
Command: PROXY, TransportProtocol: TCPv4, SourceAddr: v4addr, DestinationAddr: v4addr, rawTLVs: make([]byte, 1<<16), }, }, { desc: "proxy TCPv6", header: &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv6, SourceAddr: v6addr, DestinationAddr: v6addr, rawTLVs: make([]byte, 1<<16), }, }, { desc: "proxy UDPv4", header: &Header{ Version: 2, Command: PROXY, TransportProtocol: UDPv4, SourceAddr: v4addr, DestinationAddr: v4addr, rawTLVs: make([]byte, 1<<16), }, }, { desc: "proxy UDPv6", header: &Header{ Version: 2, Command: PROXY, TransportProtocol: UDPv6, SourceAddr: v6addr, DestinationAddr: v6addr, rawTLVs: make([]byte, 1<<16), }, }, { desc: "local unspec", header: &Header{ Version: 2, Command: LOCAL, TransportProtocol: UNSPEC, SourceAddr: nil, DestinationAddr: nil, rawTLVs: make([]byte, 1<<16), }, }, } func TestV2TLVFormatTooLargeTLV(t *testing.T) { for _, tt := range tlvFormatTests { t.Run(tt.desc, func(t *testing.T) { if _, err := tt.header.Format(); err != errUint16Overflow { t.Fatalf("missing or expected error when formatting too-large TLV %#v", err) } }) } } func newBufioReader(b []byte) *bufio.Reader { return bufio.NewReaderSize(bytes.NewReader(b), readBufferSize) } func fixtureWithTLV(cur []byte, addr []byte, tlv []byte) []byte { tlen, err := addTLVLen(cur, len(tlv)) if err != nil { panic(err) } return append(append(tlen, addr...), tlv...) } func Test_parseUnixName(t *testing.T) { tests := []struct { name string // description of this test case // Named input parameters for target function. 
b []byte want string }{ { name: "simple name, no null terminator", b: []byte("socketname"), want: "socketname", }, { name: "simple name with single null byte", b: append([]byte("socketname"), 0), want: "socketname", }, { name: "long name with null terminator in the middle", b: append([]byte("sock\000etname"), 0), want: "sock", }, { name: "empty input", b: []byte{}, want: "", }, { name: "all null bytes", b: []byte{0, 0, 0}, want: "", }, { name: "mixed bytes with null at end", b: append([]byte("abc123"), 0), want: "abc123", }, { name: "name with null in middle", b: []byte{'t', 'e', 0, 's', 't'}, want: "te", }, { name: "no null, binary data", b: []byte{0x7f, 0xfe, 0x3c}, want: string([]byte{0x7f, 0xfe, 0x3c}), }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := parseUnixName(tt.b) if got != tt.want { t.Errorf("parseUnixName() = %v, want %v", got, tt.want) } }) } } func Test_formatUnixName(t *testing.T) { maxLen := int(lengthUnix) / 2 longName := strings.Repeat("a", maxLen+5) shortName := "socket" longFormatted := formatUnixName(longName) if len(longFormatted) != maxLen { t.Fatalf("formatUnixName() length = %d, want %d", len(longFormatted), maxLen) } if got := parseUnixName(longFormatted); got != longName[:maxLen] { t.Errorf("formatUnixName() long parse = %q, want %q", got, longName[:maxLen]) } shortFormatted := formatUnixName(shortName) if len(shortFormatted) != maxLen { t.Fatalf("formatUnixName() length = %d, want %d", len(shortFormatted), maxLen) } if got := parseUnixName(shortFormatted); got != shortName { t.Errorf("formatUnixName() short parse = %q, want %q", got, shortName) } if !bytes.HasPrefix(shortFormatted, []byte(shortName)) { t.Errorf("formatUnixName() short prefix = %q, want prefix %q", shortFormatted, shortName) } } pires-go-proxyproto-04c9ad1/version_cmd.go000066400000000000000000000031051514137054000207470ustar00rootroot00000000000000package proxyproto // ProtocolVersionAndCommand represents the command in proxy protocol v2. 
// Command doesn't exist in v1 but it should be set since other parts of // this library may rely on it for determining connection details. type ProtocolVersionAndCommand byte const ( // LOCAL represents the LOCAL command in v2 or UNKNOWN transport in v1, // in which case no address information is expected. LOCAL ProtocolVersionAndCommand = '\x20' // PROXY represents the PROXY command in v2 or transport is not UNKNOWN in v1, // in which case valid local/remote address and port information is expected. PROXY ProtocolVersionAndCommand = '\x21' ) var supportedCommand = map[ProtocolVersionAndCommand]bool{ LOCAL: true, PROXY: true, } // IsLocal returns true if the command in v2 is LOCAL or the transport in v1 is UNKNOWN, // i.e. when no address information is expected, false otherwise. func (pvc ProtocolVersionAndCommand) IsLocal() bool { return LOCAL == pvc } // IsProxy returns true if the command in v2 is PROXY or the transport in v1 is not UNKNOWN, // i.e. when valid local/remote address and port information is expected, false otherwise. func (pvc ProtocolVersionAndCommand) IsProxy() bool { return PROXY == pvc } // IsUnspec returns true if the command is unspecified, false otherwise. func (pvc ProtocolVersionAndCommand) IsUnspec() bool { // Must be LOCAL or PROXY. 
return !pvc.IsLocal() && !pvc.IsProxy() } func (pvc ProtocolVersionAndCommand) toByte() byte { if pvc.IsLocal() { return byte(LOCAL) } else if pvc.IsProxy() { return byte(PROXY) } return byte(LOCAL) } pires-go-proxyproto-04c9ad1/version_cmd_test.go000066400000000000000000000013511514137054000220070ustar00rootroot00000000000000package proxyproto import ( "testing" ) func TestLocal(t *testing.T) { b := byte(LOCAL) if ProtocolVersionAndCommand(b).IsUnspec() { t.Fail() } if !ProtocolVersionAndCommand(b).IsLocal() { t.Fail() } if ProtocolVersionAndCommand(b).IsProxy() { t.Fail() } if ProtocolVersionAndCommand(b).toByte() != b { t.Fail() } } func TestProxy(t *testing.T) { b := byte(PROXY) if ProtocolVersionAndCommand(b).IsUnspec() { t.Fail() } if ProtocolVersionAndCommand(b).IsLocal() { t.Fail() } if !ProtocolVersionAndCommand(b).IsProxy() { t.Fail() } if ProtocolVersionAndCommand(b).toByte() != b { t.Fail() } } func TestInvalidProtocolVersion(t *testing.T) { if !ProtocolVersionAndCommand(0x00).IsUnspec() { t.Fail() } }