pax_global_header00006660000000000000000000000064150705730170014517gustar00rootroot0000000000000052 comment=c6152507b591e8b5cc6097a2fa9e569a05dd20e1 golang-github-pion-transport-v3-3.0.8/000077500000000000000000000000001507057301700176015ustar00rootroot00000000000000golang-github-pion-transport-v3-3.0.8/.gitignore000066400000000000000000000006321507057301700215720ustar00rootroot00000000000000# SPDX-FileCopyrightText: 2023 The Pion community # SPDX-License-Identifier: MIT ### JetBrains IDE ### ##################### .idea/ ### Emacs Temporary Files ### ############################# *~ ### Folders ### ############### bin/ vendor/ node_modules/ ### Files ### ############# *.ivf *.ogg tags cover.out *.sw[poe] *.wasm examples/sfu-ws/cert.pem examples/sfu-ws/key.pem wasm_exec.js golang-github-pion-transport-v3-3.0.8/.golangci.yml000066400000000000000000000202661507057301700221730ustar00rootroot00000000000000# SPDX-FileCopyrightText: 2023 The Pion community # SPDX-License-Identifier: MIT version: "2" linters: enable: - asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers - bidichk # Checks for dangerous unicode character sequences - bodyclose # checks whether HTTP response body is closed successfully - containedctx # containedctx is a linter that detects struct contained context.Context field - contextcheck # check the function whether use a non-inherited context - cyclop # checks function and package cyclomatic complexity - decorder # check declaration order and count of types, constants, variables and functions - dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f()) - dupl # Tool for code clone detection - durationcheck # check for two durations multiplied together - err113 # Golang linter to check the errors handling expressions - errcheck # Errcheck is a program for checking for unchecked errors in go programs. 
These unchecked errors can be critical bugs in some cases - errchkjson # Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occations, where the check for the returned error can be omitted. - errname # Checks that sentinel errors are prefixed with the `Err` and error types are suffixed with the `Error`. - errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13. - exhaustive # check exhaustiveness of enum switch statements - forbidigo # Forbids identifiers - forcetypeassert # finds forced type assertions - gochecknoglobals # Checks that no globals are present in Go code - gocognit # Computes and checks the cognitive complexity of functions - goconst # Finds repeated strings that could be replaced by a constant - gocritic # The most opinionated Go source code linter - gocyclo # Computes and checks the cyclomatic complexity of functions - godot # Check if comments end in a period - godox # Tool for detection of FIXME, TODO and other comment keywords - goheader # Checks is file header matches to pattern - gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod. - goprintffuncname # Checks that printf-like functions are named with `f` at the end - gosec # Inspects source code for security problems - govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string - grouper # An analyzer to analyze expression groups. - importas # Enforces consistent import aliases - ineffassign # Detects when assignments to existing variables are not used - lll # Reports long lines - maintidx # maintidx measures the maintainability index of each function. 
- makezero # Finds slice declarations with non-zero initial length - misspell # Finds commonly misspelled English words in comments - nakedret # Finds naked returns in functions greater than a specified function length - nestif # Reports deeply nested if statements - nilerr # Finds the code that returns nil even if it checks that the error is not nil. - nilnil # Checks that there is no simultaneous return of `nil` error and an invalid value. - nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity - noctx # noctx finds sending http request without context.Context - predeclared # find code that shadows one of Go's predeclared identifiers - revive # golint replacement, finds style mistakes - staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks - tagliatelle # Checks the struct tags. - thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers - unconvert # Remove unnecessary type conversions - unparam # Reports unused function parameters - unused # Checks Go code for unused constants, variables, functions and types - varnamelen # checks that the length of a variable's name matches its scope - wastedassign # wastedassign finds wasted assignment statements - whitespace # Tool for detection of leading and trailing whitespace disable: - depguard # Go linter that checks if package imports are in a list of acceptable packages - funlen # Tool for detection of long functions - gochecknoinits # Checks that no init functions are present in Go code - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations. - interfacebloat # A linter that checks length of interface. 
- ireturn # Accept Interfaces, Return Concrete Types - mnd # An analyzer to detect magic numbers - nolintlint # Reports ill-formed or insufficient nolint directives - paralleltest # paralleltest detects missing usage of t.Parallel() method in your Go test - prealloc # Finds slice declarations that could potentially be preallocated - promlinter # Check Prometheus metrics naming via promlint - rowserrcheck # checks whether Err of rows is checked successfully - sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed. - testpackage # linter that makes you use a separate _test package - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes - wrapcheck # Checks that errors returned from external packages are wrapped - wsl # Whitespace Linter - Forces you to use empty lines! settings: staticcheck: checks: - all - -QF1008 # "could remove embedded field", to keep it explicit! - -QF1003 # "could use tagged switch on enum", Cases conflicts with exhaustive! 
exhaustive: default-signifies-exhaustive: true forbidigo: forbid: - pattern: ^fmt.Print(f|ln)?$ - pattern: ^log.(Panic|Fatal|Print)(f|ln)?$ - pattern: ^os.Exit$ - pattern: ^panic$ - pattern: ^print(ln)?$ - pattern: ^testing.T.(Error|Errorf|Fatal|Fatalf|Fail|FailNow)$ pkg: ^testing$ msg: use testify/assert instead analyze-types: true gomodguard: blocked: modules: - github.com/pkg/errors: recommendations: - errors govet: enable: - shadow revive: rules: # Prefer 'any' type alias over 'interface{}' for Go 1.18+ compatibility - name: use-any severity: warning disabled: false misspell: locale: US varnamelen: max-distance: 12 min-name-length: 2 ignore-type-assert-ok: true ignore-map-index-ok: true ignore-chan-recv-ok: true ignore-decls: - i int - n int - w io.Writer - r io.Reader - b []byte exclusions: generated: lax rules: - linters: - forbidigo - gocognit path: (examples|main\.go) - linters: - gocognit path: _test\.go - linters: - forbidigo path: cmd formatters: enable: - gci # Gci control golang package import order and make it always deterministic. - gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification - gofumpt # Gofumpt checks whether code was gofumpt-ed. - goimports # Goimports does everything that gofmt does. 
Additionally it checks unused imports exclusions: generated: lax golang-github-pion-transport-v3-3.0.8/.goreleaser.yml000066400000000000000000000001711507057301700225310ustar00rootroot00000000000000# SPDX-FileCopyrightText: 2023 The Pion community # SPDX-License-Identifier: MIT builds: - skip: true golang-github-pion-transport-v3-3.0.8/.reuse/000077500000000000000000000000001507057301700210025ustar00rootroot00000000000000golang-github-pion-transport-v3-3.0.8/.reuse/dep5000066400000000000000000000011141507057301700215570ustar00rootroot00000000000000Format: https://www.debian.org/doc/packaging-manuals/copyright-format/1.0/ Upstream-Name: Pion Source: https://github.com/pion/ Files: README.md DESIGN.md **/README.md AUTHORS.txt renovate.json go.mod go.sum **/go.mod **/go.sum .eslintrc.json package.json examples.json sfu-ws/flutter/.gitignore sfu-ws/flutter/pubspec.yaml c-data-channels/webrtc.h examples/examples.json yarn.lock Copyright: 2023 The Pion community License: MIT Files: testdata/seed/* testdata/fuzz/* **/testdata/fuzz/* api/*.txt Copyright: 2023 The Pion community License: CC0-1.0 golang-github-pion-transport-v3-3.0.8/LICENSE000066400000000000000000000021051507057301700206040ustar00rootroot00000000000000MIT License Copyright (c) 2023 The Pion community Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. golang-github-pion-transport-v3-3.0.8/LICENSES/000077500000000000000000000000001507057301700210065ustar00rootroot00000000000000golang-github-pion-transport-v3-3.0.8/LICENSES/BSD-3-Clause.txt000066400000000000000000000026641507057301700235410ustar00rootroot00000000000000Copyright (c) . Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. golang-github-pion-transport-v3-3.0.8/LICENSES/MIT.txt000066400000000000000000000020661507057301700222040ustar00rootroot00000000000000MIT License Copyright (c) Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. golang-github-pion-transport-v3-3.0.8/README.md000066400000000000000000000043201507057301700210570ustar00rootroot00000000000000


Pion Transport

Transport testing for Pion

Pion transport join us on Discord Follow us on Bluesky
GitHub Workflow Status Go Reference Coverage Status Go Report Card License: MIT


### Roadmap The library is used as a part of our WebRTC implementation. Please refer to that [roadmap](https://github.com/pion/webrtc/issues/9) to track our major milestones. ### Community Pion has an active community on the [Discord](https://discord.gg/PngbdqpFbt). Follow the [Pion Bluesky](https://bsky.app/profile/pion.ly) or [Pion Twitter](https://twitter.com/_pion) for project updates and important WebRTC news. We are always looking to support **your projects**. Please reach out if you have something to build! If you need commercial support or don't want to use public methods you can contact us at [team@pion.ly](mailto:team@pion.ly) ### Contributing Check out the [contributing wiki](https://github.com/pion/webrtc/wiki/Contributing) to join the group of amazing people making this project possible ### License MIT License - see [LICENSE](LICENSE) for full text golang-github-pion-transport-v3-3.0.8/codecov.yml000066400000000000000000000007151507057301700217510ustar00rootroot00000000000000# # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # # SPDX-FileCopyrightText: 2023 The Pion community # SPDX-License-Identifier: MIT coverage: status: project: default: # Allow decreasing 2% of total coverage to avoid noise. threshold: 2% patch: default: target: 70% only_pulls: true ignore: - "examples/*" - "examples/**/*" golang-github-pion-transport-v3-3.0.8/connctx/000077500000000000000000000000001507057301700212555ustar00rootroot00000000000000golang-github-pion-transport-v3-3.0.8/connctx/connctx.go000066400000000000000000000065541507057301700232720ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT // Package connctx wraps net.Conn using context.Context. // // Deprecated: use netctx instead. package connctx import ( "context" "errors" "io" "net" "sync" "sync/atomic" "time" ) // ErrClosing is returned on Write to closed connection. 
var ErrClosing = errors.New("use of closed network connection") // Reader is an interface for context controlled reader. type Reader interface { ReadContext(context.Context, []byte) (int, error) } // Writer is an interface for context controlled writer. type Writer interface { WriteContext(context.Context, []byte) (int, error) } // ReadWriter is a composite of ReadWriter. type ReadWriter interface { Reader Writer } // ConnCtx is a wrapper of net.Conn using context.Context. type ConnCtx interface { Reader Writer io.Closer LocalAddr() net.Addr RemoteAddr() net.Addr Conn() net.Conn } type connCtx struct { nextConn net.Conn closed chan struct{} closeOnce sync.Once readMu sync.Mutex writeMu sync.Mutex } var veryOld = time.Unix(0, 1) //nolint:gochecknoglobals // New creates a new ConnCtx wrapping given net.Conn. func New(conn net.Conn) ConnCtx { c := &connCtx{ nextConn: conn, closed: make(chan struct{}), } return c } func (c *connCtx) ReadContext(ctx context.Context, b []byte) (int, error) { //nolint:cyclop c.readMu.Lock() defer c.readMu.Unlock() select { case <-c.closed: return 0, io.EOF default: } done := make(chan struct{}) var wg sync.WaitGroup var errSetDeadline atomic.Value wg.Add(1) go func() { defer wg.Done() select { case <-ctx.Done(): // context canceled if err := c.nextConn.SetReadDeadline(veryOld); err != nil { errSetDeadline.Store(err) return } <-done if err := c.nextConn.SetReadDeadline(time.Time{}); err != nil { errSetDeadline.Store(err) } case <-done: } }() n, err := c.nextConn.Read(b) close(done) wg.Wait() if e := ctx.Err(); e != nil && n == 0 { err = e } if err2, ok := errSetDeadline.Load().(error); ok && err == nil && err2 != nil { err = err2 } return n, err } func (c *connCtx) WriteContext(ctx context.Context, b []byte) (int, error) { //nolint:cyclop c.writeMu.Lock() defer c.writeMu.Unlock() select { case <-c.closed: return 0, ErrClosing default: } done := make(chan struct{}) var wg sync.WaitGroup var errSetDeadline atomic.Value wg.Add(1) go func() { 
defer wg.Done() select { case <-ctx.Done(): // context canceled if err := c.nextConn.SetWriteDeadline(veryOld); err != nil { errSetDeadline.Store(err) return } <-done if err := c.nextConn.SetWriteDeadline(time.Time{}); err != nil { errSetDeadline.Store(err) } case <-done: } }() n, err := c.nextConn.Write(b) close(done) wg.Wait() if e := ctx.Err(); e != nil && n == 0 { err = e } if err2, ok := errSetDeadline.Load().(error); ok && err == nil && err2 != nil { err = err2 } return n, err } func (c *connCtx) Close() error { err := c.nextConn.Close() c.closeOnce.Do(func() { c.writeMu.Lock() c.readMu.Lock() close(c.closed) c.readMu.Unlock() c.writeMu.Unlock() }) return err } func (c *connCtx) LocalAddr() net.Addr { return c.nextConn.LocalAddr() } func (c *connCtx) RemoteAddr() net.Addr { return c.nextConn.RemoteAddr() } func (c *connCtx) Conn() net.Conn { return c.nextConn } golang-github-pion-transport-v3-3.0.8/connctx/connctx_test.go000066400000000000000000000136311507057301700243230ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package connctx import ( "context" "errors" "io" "net" "testing" "time" "github.com/stretchr/testify/assert" ) func TestRead(t *testing.T) { ca, cb := net.Pipe() defer func() { _ = ca.Close() }() data := []byte{0x01, 0x02, 0xFF} chErr := make(chan error) go func() { _, err := cb.Write(data) chErr <- err }() c := New(ca) b := make([]byte, 100) n, err := c.ReadContext(context.Background(), b) assert.NoError(t, err) assert.Len(t, data, n) assert.Equal(t, data, b[:n]) err = <-chErr assert.NoError(t, err) } func TestReadTimeout(t *testing.T) { ca, _ := net.Pipe() defer func() { _ = ca.Close() }() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) defer cancel() c := New(ca) b := make([]byte, 100) n, err := c.ReadContext(ctx, b) assert.Error(t, err) assert.Empty(t, n) } func TestReadCancel(t *testing.T) { ca, _ := net.Pipe() defer func() { _ = ca.Close() }() 
ctx, cancel := context.WithCancel(context.Background()) go func() { time.Sleep(10 * time.Millisecond) cancel() }() c := New(ca) b := make([]byte, 100) n, err := c.ReadContext(ctx, b) assert.Error(t, err) assert.Empty(t, n) } func TestReadClosed(t *testing.T) { ca, _ := net.Pipe() c := New(ca) _ = c.Close() b := make([]byte, 100) n, err := c.ReadContext(context.Background(), b) assert.ErrorIs(t, err, io.EOF) assert.Empty(t, n) } func TestWrite(t *testing.T) { ca, cb := net.Pipe() defer func() { _ = ca.Close() }() chErr := make(chan error) chRead := make(chan []byte) go func() { b := make([]byte, 100) n, err := cb.Read(b) chErr <- err chRead <- b[:n] }() c := New(ca) data := []byte{0x01, 0x02, 0xFF} n, err := c.WriteContext(context.Background(), data) assert.NoError(t, err) assert.Len(t, data, n) err = <-chErr b := <-chRead assert.Equal(t, b, data) assert.NoError(t, err) } func TestWriteTimeout(t *testing.T) { ca, _ := net.Pipe() defer func() { _ = ca.Close() }() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) defer cancel() c := New(ca) b := make([]byte, 100) n, err := c.WriteContext(ctx, b) assert.Error(t, err) assert.Empty(t, n) } func TestWriteCancel(t *testing.T) { ca, _ := net.Pipe() defer func() { _ = ca.Close() }() ctx, cancel := context.WithCancel(context.Background()) go func() { time.Sleep(10 * time.Millisecond) cancel() }() c := New(ca) b := make([]byte, 100) n, err := c.WriteContext(ctx, b) assert.Error(t, err) assert.Empty(t, n) } func TestWriteClosed(t *testing.T) { ca, _ := net.Pipe() c := New(ca) _ = c.Close() b := make([]byte, 100) n, err := c.WriteContext(context.Background(), b) assert.ErrorIs(t, err, ErrClosing) assert.Empty(t, n) } // Test for TestLocalAddrAndRemoteAddr. 
type stringAddr struct { network string addr string } func (a stringAddr) Network() string { return a.network } func (a stringAddr) String() string { return a.addr } type connAddrMock struct{} func (*connAddrMock) RemoteAddr() net.Addr { return stringAddr{"remote_net", "remote_addr"} } func (*connAddrMock) LocalAddr() net.Addr { return stringAddr{"local_net", "local_addr"} } func (*connAddrMock) Read(_ []byte) (n int, err error) { panic("unimplemented") //nolint } func (*connAddrMock) Write(_ []byte) (n int, err error) { panic("unimplemented") //nolint } func (*connAddrMock) Close() error { panic("unimplemented") //nolint } func (*connAddrMock) SetDeadline(_ time.Time) error { panic("unimplemented") //nolint } func (*connAddrMock) SetReadDeadline(_ time.Time) error { panic("unimplemented") //nolint } func (*connAddrMock) SetWriteDeadline(_ time.Time) error { panic("unimplemented") //nolint } func TestLocalAddrAndRemoteAddr(t *testing.T) { c := New(&connAddrMock{}) al := c.LocalAddr() ar := c.RemoteAddr() assert.Equal(t, "local_addr", al.String()) assert.Equal(t, "remote_addr", ar.String()) } func BenchmarkBase(b *testing.B) { ca, cb := net.Pipe() defer func() { _ = ca.Close() }() data := make([]byte, 4096) for i := range data { data[i] = byte(i) } buf := make([]byte, len(data)) b.SetBytes(int64(len(data))) b.ResetTimer() go func(n int) { for i := 0; i < n; i++ { _, _ = cb.Write(data) } _ = cb.Close() }(b.N) count := 0 for { n, err := ca.Read(buf) if err != nil { if !errors.Is(err, io.EOF) { b.Fatal(err) } break } if n != len(data) { b.Errorf("Expected %v, got %v", len(data), n) } count++ } if count != b.N { b.Errorf("Expected %v, got %v", b.N, count) } } func BenchmarkWrite(b *testing.B) { ca, cb := net.Pipe() defer func() { _ = ca.Close() }() data := make([]byte, 4096) for i := range data { data[i] = byte(i) } buf := make([]byte, len(data)) b.SetBytes(int64(len(data))) b.ResetTimer() go func(n int) { c := New(cb) for i := 0; i < n; i++ { _, _ = 
c.WriteContext(context.Background(), data) } _ = cb.Close() }(b.N) count := 0 for { n, err := ca.Read(buf) if err != nil { if !errors.Is(err, io.EOF) { b.Fatal(err) } break } if n != len(data) { b.Errorf("Expected %v, got %v", len(data), n) } count++ } if count != b.N { b.Errorf("Expected %v, got %v", b.N, count) } } func BenchmarkRead(b *testing.B) { ca, cb := net.Pipe() defer func() { _ = ca.Close() }() data := make([]byte, 4096) for i := range data { data[i] = byte(i) } buf := make([]byte, len(data)) b.SetBytes(int64(len(data))) b.ResetTimer() go func(n int) { for i := 0; i < n; i++ { _, _ = cb.Write(data) } _ = cb.Close() }(b.N) c := New(ca) count := 0 for { n, err := c.ReadContext(context.Background(), buf) if err != nil { if !errors.Is(err, io.EOF) { b.Fatal(err) } break } if n != len(data) { b.Errorf("Expected %v, got %v", len(data), n) } count++ } if count != b.N { b.Errorf("Expected %v, got %v", b.N, count) } } golang-github-pion-transport-v3-3.0.8/connctx/pipe.go000066400000000000000000000004041507057301700225370ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package connctx import ( "net" ) // Pipe creates piped pair of ConnCtx. func Pipe() (ConnCtx, ConnCtx) { ca, cb := net.Pipe() return New(ca), New(cb) } golang-github-pion-transport-v3-3.0.8/deadline/000077500000000000000000000000001507057301700213465ustar00rootroot00000000000000golang-github-pion-transport-v3-3.0.8/deadline/deadline.go000066400000000000000000000043151507057301700234450ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT // Package deadline provides deadline timer used to implement // net.Conn compatible connection package deadline import ( "context" "sync" "time" ) type deadlineState uint8 const ( deadlineStopped deadlineState = iota deadlineStarted deadlineExceeded ) var _ context.Context = (*Deadline)(nil) // Deadline signals updatable deadline timer. 
// Also, it implements context.Context. type Deadline struct { mu sync.RWMutex timer timer done chan struct{} deadline time.Time state deadlineState pending uint8 } // New creates new deadline timer. func New() *Deadline { return &Deadline{ done: make(chan struct{}), } } func (d *Deadline) timeout() { d.mu.Lock() if d.pending--; d.pending != 0 || d.state != deadlineStarted { d.mu.Unlock() return } d.state = deadlineExceeded done := d.done d.mu.Unlock() close(done) } // Set new deadline. Zero value means no deadline. func (d *Deadline) Set(setTo time.Time) { d.mu.Lock() defer d.mu.Unlock() if d.state == deadlineStarted && d.timer.Stop() { d.pending-- } d.deadline = setTo d.pending++ if d.state == deadlineExceeded { d.done = make(chan struct{}) } if setTo.IsZero() { d.pending-- d.state = deadlineStopped return } if dur := time.Until(setTo); dur > 0 { d.state = deadlineStarted if d.timer == nil { d.timer = afterFunc(dur, d.timeout) } else { d.timer.Reset(dur) } return } d.pending-- d.state = deadlineExceeded close(d.done) } // Done receives deadline signal. func (d *Deadline) Done() <-chan struct{} { d.mu.RLock() defer d.mu.RUnlock() return d.done } // Err returns context.DeadlineExceeded if the deadline is exceeded. // Otherwise, it returns nil. func (d *Deadline) Err() error { d.mu.RLock() defer d.mu.RUnlock() if d.state == deadlineExceeded { return context.DeadlineExceeded } return nil } // Deadline returns current deadline. func (d *Deadline) Deadline() (time.Time, bool) { d.mu.RLock() defer d.mu.RUnlock() if d.deadline.IsZero() { return d.deadline, false } return d.deadline, true } // Value returns nil. 
func (d *Deadline) Value(any) any { return nil } golang-github-pion-transport-v3-3.0.8/deadline/deadline_test.go000066400000000000000000000103511507057301700245010ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package deadline import ( "context" "testing" "time" "github.com/stretchr/testify/assert" ) func TestDeadline(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() t.Run("Deadline", func(t *testing.T) { now := time.Now() ctx0, cancel0 := context.WithDeadline(ctx, now.Add(40*time.Millisecond)) defer cancel0() ctx1, cancel1 := context.WithDeadline(ctx, now.Add(60*time.Millisecond)) defer cancel1() d := New() d.Set(now.Add(50 * time.Millisecond)) ch := make(chan byte) go sendOnDone(ctx, ctx0.Done(), ch, 0) go sendOnDone(ctx, ctx1.Done(), ch, 1) go sendOnDone(ctx, d.Done(), ch, 2) calls := collectCh(ch, 3, 100*time.Millisecond) expectedCalls := []byte{0, 2, 1} assert.Equal(t, expectedCalls, calls, "Wrong order of deadline signal") }) t.Run("DeadlineExtend", func(t *testing.T) { //nolint:dupl now := time.Now() ctx0, cancel0 := context.WithDeadline(ctx, now.Add(40*time.Millisecond)) defer cancel0() ctx1, cancel1 := context.WithDeadline(ctx, now.Add(60*time.Millisecond)) defer cancel1() d := New() d.Set(now.Add(50 * time.Millisecond)) d.Set(now.Add(70 * time.Millisecond)) ch := make(chan byte) go sendOnDone(ctx, ctx0.Done(), ch, 0) go sendOnDone(ctx, ctx1.Done(), ch, 1) go sendOnDone(ctx, d.Done(), ch, 2) calls := collectCh(ch, 3, 100*time.Millisecond) expectedCalls := []byte{0, 1, 2} assert.Equal(t, expectedCalls, calls, "Wrong order of deadline signal") }) t.Run("DeadlinePretend", func(t *testing.T) { //nolint:dupl now := time.Now() ctx0, cancel0 := context.WithDeadline(ctx, now.Add(40*time.Millisecond)) defer cancel0() ctx1, cancel1 := context.WithDeadline(ctx, now.Add(60*time.Millisecond)) defer cancel1() d := New() d.Set(now.Add(50 * time.Millisecond)) 
d.Set(now.Add(30 * time.Millisecond)) ch := make(chan byte) go sendOnDone(ctx, ctx0.Done(), ch, 0) go sendOnDone(ctx, ctx1.Done(), ch, 1) go sendOnDone(ctx, d.Done(), ch, 2) calls := collectCh(ch, 3, 100*time.Millisecond) expectedCalls := []byte{2, 0, 1} assert.Equal(t, expectedCalls, calls, "Wrong order of deadline signal") }) t.Run("DeadlineCancel", func(t *testing.T) { now := time.Now() ctx0, cancel0 := context.WithDeadline(ctx, now.Add(40*time.Millisecond)) defer cancel0() d := New() d.Set(now.Add(50 * time.Millisecond)) d.Set(time.Time{}) ch := make(chan byte) go sendOnDone(ctx, ctx0.Done(), ch, 0) go sendOnDone(ctx, d.Done(), ch, 1) calls := collectCh(ch, 2, 60*time.Millisecond) expectedCalls := []byte{0} assert.Equal(t, expectedCalls, calls, "Wrong order of deadline signal") }) } func sendOnDone(ctx context.Context, done <-chan struct{}, dest chan byte, val byte) { select { case <-done: case <-ctx.Done(): return } dest <- val } func collectCh(ch <-chan byte, n int, timeout time.Duration) []byte { a := time.After(timeout) var calls []byte for len(calls) < n { select { case call := <-ch: calls = append(calls, call) case <-a: return calls } } return calls } func TestContext(t *testing.T) { //nolint:cyclop t.Run("Cancel", func(t *testing.T) { deadline := New() select { case <-deadline.Done(): assert.Fail(t, "Deadline unexpectedly done") case <-time.After(50 * time.Millisecond): } assert.NoError(t, deadline.Err()) deadline.Set(time.Unix(0, 1)) // exceeded select { case <-deadline.Done(): case <-time.After(50 * time.Millisecond): assert.Fail(t, "Timeout") } assert.ErrorIs(t, deadline.Err(), context.DeadlineExceeded) }) t.Run("Deadline", func(t *testing.T) { d := New() t0, expired0 := d.Deadline() assert.True(t, t0.IsZero(), "Initial Deadline is expected to be 0") assert.False(t, expired0, "Deadline is not expected to be expired at initial state") dl := time.Unix(12345, 0) d.Set(dl) // exceeded t1, expired1 := d.Deadline() assert.True(t, t1.Equal(dl), "Initial 
Deadline is expected to be %v, got %v", dl, t1) assert.True(t, expired1, "Deadline is expected to be expired") }) } func BenchmarkDeadline(b *testing.B) { b.Run("Set", func(b *testing.B) { d := New() t := time.Now().Add(time.Minute) for i := 0; i < b.N; i++ { d.Set(t) } }) } golang-github-pion-transport-v3-3.0.8/deadline/timer.go000066400000000000000000000003151507057301700230140ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package deadline import ( "time" ) type timer interface { Stop() bool Reset(time.Duration) bool } golang-github-pion-transport-v3-3.0.8/deadline/timer_generic.go000066400000000000000000000003731507057301700245140ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT //go:build !js // +build !js package deadline import ( "time" ) func afterFunc(d time.Duration, f func()) timer { return time.AfterFunc(d, f) } golang-github-pion-transport-v3-3.0.8/deadline/timer_js.go000066400000000000000000000017371507057301700235210ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT //go:build js // +build js package deadline import ( "sync" "time" ) // jsTimer is a timer utility for wasm with a working Reset function. 
type jsTimer struct { f func() mu sync.Mutex timer *time.Timer version uint64 started bool } func afterFunc(d time.Duration, f func()) timer { t := &jsTimer{f: f} t.Reset(d) return t } func (t *jsTimer) Stop() bool { t.mu.Lock() defer t.mu.Unlock() t.version++ t.timer.Stop() started := t.started t.started = false return started } func (t *jsTimer) Reset(d time.Duration) bool { t.mu.Lock() defer t.mu.Unlock() if t.timer != nil { t.timer.Stop() } t.version++ version := t.version t.timer = time.AfterFunc(d, func() { t.mu.Lock() if version != t.version { t.mu.Unlock() return } t.started = false t.mu.Unlock() t.f() }) started := t.started t.started = true return started } golang-github-pion-transport-v3-3.0.8/dpipe/000077500000000000000000000000001507057301700207025ustar00rootroot00000000000000golang-github-pion-transport-v3-3.0.8/dpipe/dpipe.go000066400000000000000000000057231507057301700223410ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT // Package dpipe provides the pipe works like datagram protocol on memory. // // This package is mainly intended for testing and not for production! package dpipe import ( "context" "io" "net" "sync" "time" "github.com/pion/transport/v3/deadline" ) // Pipe creates pair of non-stream conn on memory. // Close of the one end doesn't make effect to the other end. 
func Pipe() (net.Conn, net.Conn) { ch0 := make(chan []byte, 1000) ch1 := make(chan []byte, 1000) return &conn{ rCh: ch0, wCh: ch1, closed: make(chan struct{}), closing: make(chan struct{}), readDeadline: deadline.New(), writeDeadline: deadline.New(), }, &conn{ rCh: ch1, wCh: ch0, closed: make(chan struct{}), closing: make(chan struct{}), readDeadline: deadline.New(), writeDeadline: deadline.New(), } } type pipeAddr struct{} func (pipeAddr) Network() string { return "pipe" } func (pipeAddr) String() string { return ":1" } type conn struct { rCh chan []byte wCh chan []byte closed chan struct{} closing chan struct{} closeOnce sync.Once readDeadline *deadline.Deadline writeDeadline *deadline.Deadline } func (*conn) LocalAddr() net.Addr { return pipeAddr{} } func (*conn) RemoteAddr() net.Addr { return pipeAddr{} } func (c *conn) SetDeadline(t time.Time) error { c.readDeadline.Set(t) c.writeDeadline.Set(t) return nil } func (c *conn) SetReadDeadline(t time.Time) error { c.readDeadline.Set(t) return nil } func (c *conn) SetWriteDeadline(t time.Time) error { c.writeDeadline.Set(t) return nil } func (c *conn) Read(data []byte) (n int, err error) { //nolint:cyclop select { case <-c.closed: return 0, io.EOF case <-c.closing: if len(c.rCh) == 0 { return 0, io.EOF } case <-c.readDeadline.Done(): return 0, context.DeadlineExceeded default: } for { select { case d := <-c.rCh: if len(d) <= len(data) { copy(data, d) return len(d), nil } copy(data, d[:len(data)]) return len(data), nil case <-c.closed: return 0, io.EOF case <-c.closing: if len(c.rCh) == 0 { return 0, io.EOF } case <-c.readDeadline.Done(): return 0, context.DeadlineExceeded } } } func (c *conn) cleanWriteBuffer() { for { select { case <-c.wCh: default: return } } } func (c *conn) Write(data []byte) (n int, err error) { select { case <-c.closed: return 0, io.ErrClosedPipe case <-c.writeDeadline.Done(): c.cleanWriteBuffer() return 0, context.DeadlineExceeded default: } cData := make([]byte, len(data)) copy(cData, data) 
select { case <-c.closed: return 0, io.ErrClosedPipe case <-c.writeDeadline.Done(): c.cleanWriteBuffer() return 0, context.DeadlineExceeded case c.wCh <- cData: return len(cData), nil } } func (c *conn) Close() error { c.closeOnce.Do(func() { close(c.closed) }) return nil } golang-github-pion-transport-v3-3.0.8/dpipe/dpipe_test.go000066400000000000000000000043361507057301700233770ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT //go:build !js // +build !js package dpipe import ( "fmt" "io" "net" "testing" "time" "github.com/stretchr/testify/assert" "golang.org/x/net/nettest" ) var errFailedToCast = fmt.Errorf("failed to cast net.Conn to conn") func TestNetTest(t *testing.T) { nettest.TestConn(t, func() (net.Conn, net.Conn, func(), error) { ca, cb := Pipe() caConn, ok := ca.(*conn) if !ok { return nil, nil, nil, errFailedToCast } cbConn, ok := cb.(*conn) if !ok { return nil, nil, nil, errFailedToCast } return &closePropagator{caConn, cbConn}, &closePropagator{cbConn, caConn}, func() { _ = ca.Close() _ = cb.Close() }, nil }) } type closePropagator struct { *conn otherEnd *conn } func (c *closePropagator) Close() error { close(c.otherEnd.closing) return c.conn.Close() } func TestPipe(t *testing.T) { //nolint:cyclop ca, cb := Pipe() testData := []byte{0x01, 0x02} for name, cond := range map[string]struct { ca net.Conn cb net.Conn }{ "AtoB": {ca, cb}, "BtoA": {cb, ca}, } { c0 := cond.ca c1 := cond.cb t.Run(name, func(t *testing.T) { n, err := c0.Write(testData) assert.NoError(t, err) assert.Equal(t, len(testData), n) readData := make([]byte, 4) n, err = c1.Read(readData) assert.NoError(t, err) assert.Len(t, testData, n) assert.Equal(t, testData, readData[:n]) }) } assert.NoError(t, ca.Close()) _, err := ca.Write(testData) assert.ErrorIs(t, err, io.ErrClosedPipe, "Write to closed conn should fail") // Other side should be writable. 
_, err = cb.Write(testData) assert.NoError(t, err) readData := make([]byte, 4) _, err = ca.Read(readData) assert.ErrorIs(t, err, io.EOF, "Read from closed conn should fail with io.EOF") // Other side should be readable. readDone := make(chan struct{}) go func() { readData := make([]byte, 4) n, err := cb.Read(readData) assert.Errorf(t, err, "Unexpected data %v was arrived to orphaned conn", readData[:n]) close(readDone) }() select { case <-readDone: assert.Fail(t, "Read should be blocked if the other side is closed") case <-time.After(10 * time.Millisecond): } assert.NoError(t, cb.Close()) } golang-github-pion-transport-v3-3.0.8/examples/000077500000000000000000000000001507057301700214175ustar00rootroot00000000000000golang-github-pion-transport-v3-3.0.8/examples/vnet-udpproxy/000077500000000000000000000000001507057301700242635ustar00rootroot00000000000000golang-github-pion-transport-v3-3.0.8/examples/vnet-udpproxy/README.md000066400000000000000000000022251507057301700255430ustar00rootroot00000000000000# vnet-udpproxy This example demonstrates how VNet can be used to communicate with non-VNet addresses using UDPProxy. In this example we listen map the VNet Address `10.0.0.11` to a real address of our choice. We then send to our real address from three different VNet addresses. If you pass `-address 192.168.1.3:8000` the traffic will be the following ``` vnet(10.0.0.11:5787) => proxy => 192.168.1.3:8000 vnet(10.0.0.11:5788) => proxy => 192.168.1.3:8000 vnet(10.0.0.11:5789) => proxy => 192.168.1.3:8000 ``` ## Running ``` go run main.go -address 192.168.1.3:8000 ``` You should see the following in tcpdump ``` sean@SeanLaptop:~/go/src/github.com/pion/transport/examples$ sudo tcpdump -i any udp and port 8000 tcpdump: data link type LINUX_SLL2 tcpdump: verbose output suppressed, use -v[v]... 
for full protocol decode listening on any, link-type LINUX_SLL2 (Linux cooked v2), snapshot length 262144 bytes 13:21:18.239943 lo In IP 192.168.1.7.40574 > 192.168.1.7.8000: UDP, length 5 13:21:18.240105 lo In IP 192.168.1.7.40647 > 192.168.1.7.8000: UDP, length 5 13:21:18.240304 lo In IP 192.168.1.7.57744 > 192.168.1.7.8000: UDP, length 5 ``` golang-github-pion-transport-v3-3.0.8/examples/vnet-udpproxy/main.go000066400000000000000000000043431507057301700255420ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT // Package main implements an example for the virtual Net // UDP proxy. package main import ( "flag" "net" "time" "github.com/pion/logging" "github.com/pion/transport/v3/vnet" ) func main() { //nolint:cyclop address := flag.String("address", "", "Destination address that three separate vnet clients will send too") flag.Parse() // Create vnet WAN with one endpoint // See the following docs for more information // https://github.com/pion/transport/tree/master/vnet#example-wan-with-one-endpoint-vnet router, err := vnet.NewRouter(&vnet.RouterConfig{ CIDR: "0.0.0.0/0", LoggerFactory: logging.NewDefaultLoggerFactory(), }) if err != nil { panic(err) } // Create a network and add to router, for example, for client. clientNetwork, err := vnet.NewNet(&vnet.NetConfig{ StaticIP: "10.0.0.11", }) if err != nil { panic(err) } if err = router.AddNet(clientNetwork); err != nil { panic(err) } if err = router.Start(); err != nil { panic(err) } defer router.Stop() // nolint:errcheck // Create a proxy, bind to the router. proxy, err := vnet.NewProxy(router) if err != nil { panic(err) } defer proxy.Close() // nolint:errcheck serverAddr, err := net.ResolveUDPAddr("udp4", *address) if err != nil { panic(err) } // Start to proxy some addresses, clientNetwork is a hit for proxy, // that the client in vnet is from this network. 
if err = proxy.Proxy(clientNetwork, serverAddr); err != nil { panic(err) } // Now, all packets from client, will be proxy to real server, vice versa. client0, err := clientNetwork.ListenPacket("udp4", "10.0.0.11:5787") if err != nil { panic(err) } _, _ = client0.WriteTo([]byte("Hello"), serverAddr) client1, err := clientNetwork.ListenPacket("udp4", "10.0.0.11:5788") if err != nil { panic(err) } _, _ = client1.WriteTo([]byte("Hello"), serverAddr) client2, err := clientNetwork.ListenPacket("udp4", "10.0.0.11:5789") if err != nil { panic(err) } _, _ = client2.WriteTo([]byte("Hello"), serverAddr) // Packets are delivered by a goroutine so WriteTo // return doesn't mean success. This may improve in // the future. time.Sleep(time.Second * 3) } golang-github-pion-transport-v3-3.0.8/go.mod000066400000000000000000000005451507057301700207130ustar00rootroot00000000000000module github.com/pion/transport/v3 go 1.21 require ( github.com/pion/logging v0.2.4 github.com/stretchr/testify v1.11.1 github.com/wlynxg/anet v0.0.5 golang.org/x/net v0.34.0 golang.org/x/sys v0.29.0 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) golang-github-pion-transport-v3-3.0.8/go.sum000066400000000000000000000027551507057301700207450ustar00rootroot00000000000000github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8= github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod 
h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/wlynxg/anet v0.0.5 h1:J3VJGi1gvo0JwZ/P1/Yc/8p63SoW98B5dHkYDmpgvvU= github.com/wlynxg/anet v0.0.5/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA= golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0= golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= golang-github-pion-transport-v3-3.0.8/net.go000066400000000000000000000373221507057301700207250ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT // Package transport implements various networking related // functions used throughout the Pion modules. package transport import ( "errors" "io" "net" "time" ) var ( // ErrNoAddressAssigned ... ErrNoAddressAssigned = errors.New("no address assigned") // ErrNotSupported ... ErrNotSupported = errors.New("not supported yey") // ErrInterfaceNotFound ... ErrInterfaceNotFound = errors.New("interface not found") // ErrNotUDPAddress ... ErrNotUDPAddress = errors.New("not a UDP address") ) // Net is an interface providing common networking functions which are // similar to the functions provided by standard net package. type Net interface { // ListenPacket announces on the local network address. // // The network must be "udp", "udp4", "udp6", "unixgram", or an IP // transport. 
The IP transports are "ip", "ip4", or "ip6" followed by // a colon and a literal protocol number or a protocol name, as in // "ip:1" or "ip:icmp". // // For UDP and IP networks, if the host in the address parameter is // empty or a literal unspecified IP address, ListenPacket listens on // all available IP addresses of the local system except multicast IP // addresses. // To only use IPv4, use network "udp4" or "ip4:proto". // The address can use a host name, but this is not recommended, // because it will create a listener for at most one of the host's IP // addresses. // If the port in the address parameter is empty or "0", as in // "127.0.0.1:" or "[::1]:0", a port number is automatically chosen. // The LocalAddr method of PacketConn can be used to discover the // chosen port. // // See func Dial for a description of the network and address // parameters. // // ListenPacket uses context.Background internally; to specify the context, use // ListenConfig.ListenPacket. ListenPacket(network string, address string) (net.PacketConn, error) // ListenUDP acts like ListenPacket for UDP networks. // // The network must be a UDP network name; see func Dial for details. // // If the IP field of laddr is nil or an unspecified IP address, // ListenUDP listens on all available IP addresses of the local system // except multicast IP addresses. // If the Port field of laddr is 0, a port number is automatically // chosen. ListenUDP(network string, locAddr *net.UDPAddr) (UDPConn, error) // ListenTCP acts like Listen for TCP networks. // // The network must be a TCP network name; see func Dial for details. // // If the IP field of laddr is nil or an unspecified IP address, // ListenTCP listens on all available unicast and anycast IP addresses // of the local system. // If the Port field of laddr is 0, a port number is automatically // chosen. ListenTCP(network string, laddr *net.TCPAddr) (TCPListener, error) // Dial connects to the address on the named network. 
// // Known networks are "tcp", "tcp4" (IPv4-only), "tcp6" (IPv6-only), // "udp", "udp4" (IPv4-only), "udp6" (IPv6-only), "ip", "ip4" // (IPv4-only), "ip6" (IPv6-only), "unix", "unixgram" and // "unixpacket". // // For TCP and UDP networks, the address has the form "host:port". // The host must be a literal IP address, or a host name that can be // resolved to IP addresses. // The port must be a literal port number or a service name. // If the host is a literal IPv6 address it must be enclosed in square // brackets, as in "[2001:db8::1]:80" or "[fe80::1%zone]:80". // The zone specifies the scope of the literal IPv6 address as defined // in RFC 4007. // The functions JoinHostPort and SplitHostPort manipulate a pair of // host and port in this form. // When using TCP, and the host resolves to multiple IP addresses, // Dial will try each IP address in order until one succeeds. // // Examples: // // Dial("tcp", "golang.org:http") // Dial("tcp", "192.0.2.1:http") // Dial("tcp", "198.51.100.1:80") // Dial("udp", "[2001:db8::1]:domain") // Dial("udp", "[fe80::1%lo0]:53") // Dial("tcp", ":80") // // For IP networks, the network must be "ip", "ip4" or "ip6" followed // by a colon and a literal protocol number or a protocol name, and // the address has the form "host". The host must be a literal IP // address or a literal IPv6 address with zone. // It depends on each operating system how the operating system // behaves with a non-well known protocol number such as "0" or "255". // // Examples: // // Dial("ip4:1", "192.0.2.1") // Dial("ip6:ipv6-icmp", "2001:db8::1") // Dial("ip6:58", "fe80::1%lo0") // // For TCP, UDP and IP networks, if the host is empty or a literal // unspecified IP address, as in ":80", "0.0.0.0:80" or "[::]:80" for // TCP and UDP, "", "0.0.0.0" or "::" for IP, the local system is // assumed. // // For Unix networks, the address must be a file system path. Dial(network, address string) (net.Conn, error) // DialUDP acts like Dial for UDP networks. 
// // The network must be a UDP network name; see func Dial for details. // // If laddr is nil, a local address is automatically chosen. // If the IP field of raddr is nil or an unspecified IP address, the // local system is assumed. DialUDP(network string, laddr, raddr *net.UDPAddr) (UDPConn, error) // DialTCP acts like Dial for TCP networks. // // The network must be a TCP network name; see func Dial for details. // // If laddr is nil, a local address is automatically chosen. // If the IP field of raddr is nil or an unspecified IP address, the // local system is assumed. DialTCP(network string, laddr, raddr *net.TCPAddr) (TCPConn, error) // ResolveIPAddr returns an address of IP end point. // // The network must be an IP network name. // // If the host in the address parameter is not a literal IP address, // ResolveIPAddr resolves the address to an address of IP end point. // Otherwise, it parses the address as a literal IP address. // The address parameter can use a host name, but this is not // recommended, because it will return at most one of the host name's // IP addresses. // // See func Dial for a description of the network and address // parameters. ResolveIPAddr(network, address string) (*net.IPAddr, error) // ResolveUDPAddr returns an address of UDP end point. // // The network must be a UDP network name. // // If the host in the address parameter is not a literal IP address or // the port is not a literal port number, ResolveUDPAddr resolves the // address to an address of UDP end point. // Otherwise, it parses the address as a pair of literal IP address // and port number. // The address parameter can use a host name, but this is not // recommended, because it will return at most one of the host name's // IP addresses. // // See func Dial for a description of the network and address // parameters. ResolveUDPAddr(network, address string) (*net.UDPAddr, error) // ResolveTCPAddr returns an address of TCP end point. 
// // The network must be a TCP network name. // // If the host in the address parameter is not a literal IP address or // the port is not a literal port number, ResolveTCPAddr resolves the // address to an address of TCP end point. // Otherwise, it parses the address as a pair of literal IP address // and port number. // The address parameter can use a host name, but this is not // recommended, because it will return at most one of the host name's // IP addresses. // // See func Dial for a description of the network and address // parameters. ResolveTCPAddr(network, address string) (*net.TCPAddr, error) // Interfaces returns a list of the system's network interfaces. Interfaces() ([]*Interface, error) // InterfaceByIndex returns the interface specified by index. // // On Solaris, it returns one of the logical network interfaces // sharing the logical data link; for more precision use // InterfaceByName. InterfaceByIndex(index int) (*Interface, error) // InterfaceByName returns the interface specified by name. InterfaceByName(name string) (*Interface, error) // The following functions are extensions to Go's standard net package CreateDialer(dialer *net.Dialer) Dialer } // Dialer is identical to net.Dialer excepts that its methods // (Dial, DialContext) are overridden to use the Net interface. // Use vnet.CreateDialer() to create an instance of this Dialer. type Dialer interface { Dial(network, address string) (net.Conn, error) } // UDPConn is packet-oriented connection for UDP. type UDPConn interface { // Close closes the connection. // Any blocked Read or Write operations will be unblocked and return errors. Close() error // LocalAddr returns the local network address, if known. LocalAddr() net.Addr // RemoteAddr returns the remote network address, if known. RemoteAddr() net.Addr // SetDeadline sets the read and write deadlines associated // with the connection. It is equivalent to calling both // SetReadDeadline and SetWriteDeadline. 
// // A deadline is an absolute time after which I/O operations // fail instead of blocking. The deadline applies to all future // and pending I/O, not just the immediately following call to // Read or Write. After a deadline has been exceeded, the // connection can be refreshed by setting a deadline in the future. // // If the deadline is exceeded a call to Read or Write or to other // I/O methods will return an error that wraps os.ErrDeadlineExceeded. // This can be tested using errors.Is(err, os.ErrDeadlineExceeded). // The error's Timeout method will return true, but note that there // are other possible errors for which the Timeout method will // return true even if the deadline has not been exceeded. // // An idle timeout can be implemented by repeatedly extending // the deadline after successful Read or Write calls. // // A zero value for t means I/O operations will not time out. SetDeadline(t time.Time) error // SetReadDeadline sets the deadline for future Read calls // and any currently-blocked Read call. // A zero value for t means Read will not time out. SetReadDeadline(t time.Time) error // SetWriteDeadline sets the deadline for future Write calls // and any currently-blocked Write call. // Even if write times out, it may return n > 0, indicating that // some of the data was successfully written. // A zero value for t means Write will not time out. SetWriteDeadline(t time.Time) error // SetReadBuffer sets the size of the operating system's // receive buffer associated with the connection. SetReadBuffer(bytes int) error // SetWriteBuffer sets the size of the operating system's // transmit buffer associated with the connection. SetWriteBuffer(bytes int) error // Read reads data from the connection. // Read can be made to time out and return an error after a fixed // time limit; see SetDeadline and SetReadDeadline. Read(b []byte) (n int, err error) // ReadFrom reads a packet from the connection, // copying the payload into p. 
It returns the number of // bytes copied into p and the return address that // was on the packet. // It returns the number of bytes read (0 <= n <= len(p)) // and any error encountered. Callers should always process // the n > 0 bytes returned before considering the error err. // ReadFrom can be made to time out and return an error after a // fixed time limit; see SetDeadline and SetReadDeadline. ReadFrom(p []byte) (n int, addr net.Addr, err error) // ReadFromUDP acts like ReadFrom but returns a UDPAddr. ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) // ReadMsgUDP reads a message from c, copying the payload into b and // the associated out-of-band data into oob. It returns the number of // bytes copied into b, the number of bytes copied into oob, the flags // that were set on the message and the source address of the message. // // The packages golang.org/x/net/ipv4 and golang.org/x/net/ipv6 can be // used to manipulate IP-level socket options in oob. ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *net.UDPAddr, err error) // Write writes data to the connection. // Write can be made to time out and return an error after a fixed // time limit; see SetDeadline and SetWriteDeadline. Write(b []byte) (n int, err error) // WriteTo writes a packet with payload p to addr. // WriteTo can be made to time out and return an Error after a // fixed time limit; see SetDeadline and SetWriteDeadline. // On packet-oriented connections, write timeouts are rare. WriteTo(p []byte, addr net.Addr) (n int, err error) // WriteToUDP acts like WriteTo but takes a UDPAddr. WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) // WriteMsgUDP writes a message to addr via c if c isn't connected, or // to c's remote address if c is connected (in which case addr must be // nil). The payload is copied from b and the associated out-of-band // data is copied from oob. It returns the number of payload and // out-of-band bytes written. 
// // The packages golang.org/x/net/ipv4 and golang.org/x/net/ipv6 can be // used to manipulate IP-level socket options in oob. WriteMsgUDP(b, oob []byte, addr *net.UDPAddr) (n, oobn int, err error) } // TCPConn is an interface for TCP network connections. type TCPConn interface { net.Conn // CloseRead shuts down the reading side of the TCP connection. // Most callers should just use Close. CloseRead() error // CloseWrite shuts down the writing side of the TCP connection. // Most callers should just use Close. CloseWrite() error // ReadFrom implements the io.ReaderFrom ReadFrom method. ReadFrom(r io.Reader) (int64, error) // SetLinger sets the behavior of Close on a connection which still // has data waiting to be sent or to be acknowledged. // // If sec < 0 (the default), the operating system finishes sending the // data in the background. // // If sec == 0, the operating system discards any unsent or // unacknowledged data. // // If sec > 0, the data is sent in the background as with sec < 0. On // some operating systems after sec seconds have elapsed any remaining // unsent data may be discarded. SetLinger(sec int) error // SetKeepAlive sets whether the operating system should send // keep-alive messages on the connection. SetKeepAlive(keepalive bool) error // SetKeepAlivePeriod sets period between keep-alives. SetKeepAlivePeriod(d time.Duration) error // SetNoDelay controls whether the operating system should delay // packet transmission in hopes of sending fewer packets (Nagle's // algorithm). The default is true (no delay), meaning that data is // sent as soon as possible after a Write. SetNoDelay(noDelay bool) error // SetWriteBuffer sets the size of the operating system's // transmit buffer associated with the connection. SetWriteBuffer(bytes int) error // SetReadBuffer sets the size of the operating system's // receive buffer associated with the connection. SetReadBuffer(bytes int) error } // TCPListener is a TCP network listener. 
Clients should typically // use variables of type Listener instead of assuming TCP. type TCPListener interface { net.Listener // AcceptTCP accepts the next incoming call and returns the new // connection. AcceptTCP() (TCPConn, error) // SetDeadline sets the deadline associated with the listener. // A zero time value disables the deadline. SetDeadline(t time.Time) error } // Interface wraps a standard net.Interfaces and its assigned addresses. type Interface struct { net.Interface addrs []net.Addr } // NewInterface creates a new interface based of a standard net.Interface. func NewInterface(ifc net.Interface) *Interface { return &Interface{ Interface: ifc, addrs: nil, } } // AddAddress adds a new address to the interface. func (ifc *Interface) AddAddress(addr net.Addr) { ifc.addrs = append(ifc.addrs, addr) } // Addrs returns a slice of configured addresses on the interface. func (ifc *Interface) Addrs() ([]net.Addr, error) { if len(ifc.addrs) == 0 { return nil, ErrNoAddressAssigned } return ifc.addrs, nil } golang-github-pion-transport-v3-3.0.8/netctx/000077500000000000000000000000001507057301700211065ustar00rootroot00000000000000golang-github-pion-transport-v3-3.0.8/netctx/conn.go000066400000000000000000000075241507057301700224020ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT // Package netctx wraps common net interfaces using context.Context. package netctx import ( "context" "errors" "io" "net" "sync" "sync/atomic" "time" ) // ErrClosing is returned on Write to closed connection. var ErrClosing = errors.New("use of closed network connection") // Reader is an interface for context controlled reader. type Reader interface { ReadContext(context.Context, []byte) (int, error) } // Writer is an interface for context controlled writer. type Writer interface { WriteContext(context.Context, []byte) (int, error) } // ReadWriter is a composite of ReadWriter. 
type ReadWriter interface { Reader Writer } // Conn is a wrapper of net.Conn using context.Context. type Conn interface { Reader Writer io.Closer LocalAddr() net.Addr RemoteAddr() net.Addr Conn() net.Conn } type conn struct { nextConn net.Conn closed chan struct{} closeOnce sync.Once readMu sync.Mutex writeMu sync.Mutex } var veryOld = time.Unix(0, 1) //nolint:gochecknoglobals // NewConn creates a new Conn wrapping given net.Conn. func NewConn(netConn net.Conn) Conn { c := &conn{ nextConn: netConn, closed: make(chan struct{}), } return c } // ReadContext reads data from the connection. // Unlike net.Conn.Read(), the provided context is used to control timeout. func (c *conn) ReadContext(ctx context.Context, b []byte) (int, error) { //nolint:cyclop c.readMu.Lock() defer c.readMu.Unlock() select { case <-c.closed: return 0, net.ErrClosed default: } done := make(chan struct{}) var wg sync.WaitGroup var errSetDeadline atomic.Value wg.Add(1) go func() { defer wg.Done() select { case <-ctx.Done(): // context canceled if err := c.nextConn.SetReadDeadline(veryOld); err != nil { errSetDeadline.Store(err) return } <-done if err := c.nextConn.SetReadDeadline(time.Time{}); err != nil { errSetDeadline.Store(err) } case <-done: } }() n, err := c.nextConn.Read(b) close(done) wg.Wait() if e := ctx.Err(); e != nil && n == 0 { err = e } if err2, ok := errSetDeadline.Load().(error); ok && err == nil && err2 != nil { err = err2 } return n, err } // WriteContext writes data to the connection. // Unlike net.Conn.Write(), the provided context is used to control timeout. 
func (c *conn) WriteContext(ctx context.Context, b []byte) (int, error) { //nolint:cyclop c.writeMu.Lock() defer c.writeMu.Unlock() select { case <-c.closed: return 0, ErrClosing default: } done := make(chan struct{}) var wg sync.WaitGroup var errSetDeadline atomic.Value wg.Add(1) go func() { defer wg.Done() select { case <-ctx.Done(): // context canceled if err := c.nextConn.SetWriteDeadline(veryOld); err != nil { errSetDeadline.Store(err) return } <-done if err := c.nextConn.SetWriteDeadline(time.Time{}); err != nil { errSetDeadline.Store(err) } case <-done: } }() n, err := c.nextConn.Write(b) close(done) wg.Wait() if e := ctx.Err(); e != nil && n == 0 { err = e } if err2, ok := errSetDeadline.Load().(error); ok && err == nil && err2 != nil { err = err2 } return n, err } // Close closes the connection. // Any blocked ReadContext or WriteContext operations will be unblocked and // return errors. func (c *conn) Close() error { err := c.nextConn.Close() c.closeOnce.Do(func() { c.writeMu.Lock() c.readMu.Lock() close(c.closed) c.readMu.Unlock() c.writeMu.Unlock() }) return err } // LocalAddr returns the local network address, if known. func (c *conn) LocalAddr() net.Addr { return c.nextConn.LocalAddr() } // LocalAddr returns the local network address, if known. func (c *conn) RemoteAddr() net.Addr { return c.nextConn.RemoteAddr() } // Conn returns the underlying net.Conn. 
func (c *conn) Conn() net.Conn {
	return c.nextConn
}
golang-github-pion-transport-v3-3.0.8/netctx/conn_test.go000066400000000000000000000137101507057301700234330ustar00rootroot00000000000000
// SPDX-FileCopyrightText: 2023 The Pion community
// SPDX-License-Identifier: MIT

package netctx

import (
	"context"
	"errors"
	"io"
	"net"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
)

// TestRead checks that ReadContext delivers the bytes written by the peer of
// a net.Pipe.
func TestRead(t *testing.T) {
	ca, cb := net.Pipe()
	defer func() {
		_ = ca.Close()
	}()

	data := []byte{0x01, 0x02, 0xFF}
	chErr := make(chan error)

	// net.Pipe is synchronous, so the write has to run concurrently with the read.
	go func() {
		_, err := cb.Write(data)
		chErr <- err
	}()

	c := NewConn(ca)
	b := make([]byte, 100)
	n, err := c.ReadContext(context.Background(), b)
	assert.NoError(t, err)
	assert.Equal(t, len(data), n)
	assert.Equal(t, data, b[:n])
	assert.NoError(t, <-chErr)
}

// TestReadTimeout checks that ReadContext fails with no data when the context
// times out while the peer never writes.
func TestReadTimeout(t *testing.T) {
	ca, _ := net.Pipe()
	defer func() {
		_ = ca.Close()
	}()

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
	defer cancel()

	c := NewConn(ca)
	b := make([]byte, 100)
	n, err := c.ReadContext(ctx, b)
	assert.Error(t, err)
	assert.Empty(t, n)
}

// TestReadCancel checks that ReadContext fails with no data when the context
// is canceled while the read is pending.
func TestReadCancel(t *testing.T) {
	ca, _ := net.Pipe()
	defer func() {
		_ = ca.Close()
	}()

	ctx, cancel := context.WithCancel(context.Background())
	go func() {
		time.Sleep(10 * time.Millisecond)
		cancel()
	}()

	c := NewConn(ca)
	b := make([]byte, 100)
	n, err := c.ReadContext(ctx, b)
	assert.Error(t, err)
	assert.Empty(t, n)
}

// TestReadClosed checks that ReadContext on a closed Conn returns net.ErrClosed.
func TestReadClosed(t *testing.T) {
	ca, _ := net.Pipe()

	c := NewConn(ca)
	_ = c.Close()

	b := make([]byte, 100)
	n, err := c.ReadContext(context.Background(), b)
	assert.ErrorIs(t, err, net.ErrClosed)
	assert.Empty(t, n)
}

// TestWrite checks that data passed to WriteContext arrives at the peer of a
// net.Pipe.
func TestWrite(t *testing.T) {
	ca, cb := net.Pipe()
	defer func() {
		_ = ca.Close()
	}()

	chErr := make(chan error)
	chRead := make(chan []byte)

	// net.Pipe is synchronous, so the read has to run concurrently with the write.
	go func() {
		b := make([]byte, 100)
		n, err := cb.Read(b)
		chErr <- err
		chRead <- b[:n]
	}()

	c := NewConn(ca)
	data := []byte{0x01, 0x02, 0xFF}
	n, err := c.WriteContext(context.Background(), data)
	assert.NoError(t, err)
	assert.Len(t, data, n)

	err = <-chErr
	b := <-chRead
	assert.NoError(t, err)
	assert.Equal(t, data, b)
}

// TestWriteTimeout checks that WriteContext fails with nothing written when
// the context times out while the peer never reads.
func TestWriteTimeout(t *testing.T) {
	ca, _ := net.Pipe()
	defer func() {
		_ = ca.Close()
	}()

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
	defer cancel()

	c := NewConn(ca)
	b := make([]byte, 100)
	n, err := c.WriteContext(ctx, b)
	assert.Error(t, err)
	assert.Empty(t, n)
}

// TestWriteCancel checks that WriteContext fails with nothing written when
// the context is canceled while the write is pending.
func TestWriteCancel(t *testing.T) {
	ca, _ := net.Pipe()
	defer func() {
		_ = ca.Close()
	}()

	ctx, cancel := context.WithCancel(context.Background())
	go func() {
		time.Sleep(10 * time.Millisecond)
		cancel()
	}()

	c := NewConn(ca)
	b := make([]byte, 100)
	n, err := c.WriteContext(ctx, b)
	assert.Error(t, err)
	assert.Empty(t, n)
}

// TestWriteClosed checks that WriteContext on a closed Conn returns ErrClosing.
func TestWriteClosed(t *testing.T) {
	ca, _ := net.Pipe()

	c := NewConn(ca)
	_ = c.Close()

	b := make([]byte, 100)
	n, err := c.WriteContext(context.Background(), b)
	assert.ErrorIs(t, err, ErrClosing)
	assert.Empty(t, n)
}

// stringAddr is a fixed net.Addr used by TestLocalAddrAndRemoteAddr.
type stringAddr struct {
	network string
	addr    string
}

// Network returns the stored network name.
func (a stringAddr) Network() string { return a.network }

// String returns the stored address.
func (a stringAddr) String() string { return a.addr }

// connAddrMock is a net.Conn whose address getters return fixed values; every
// other method panics.
type connAddrMock struct{}

func (*connAddrMock) RemoteAddr() net.Addr { return stringAddr{"remote_net", "remote_addr"} }
func (*connAddrMock) LocalAddr() net.Addr  { return stringAddr{"local_net", "local_addr"} }
func (*connAddrMock) Read(_ []byte) (n int, err error) {
	panic("unimplemented") //nolint
}

func (*connAddrMock) Write(_ []byte) (n int, err error) {
	panic("unimplemented") //nolint
}

func (*connAddrMock) Close() error {
	panic("unimplemented") //nolint
}

func (*connAddrMock) SetDeadline(_ time.Time) error {
	panic("unimplemented") //nolint
}

func (*connAddrMock) SetReadDeadline(_ time.Time) error {
	panic("unimplemented") //nolint
}

func (*connAddrMock) SetWriteDeadline(_ time.Time) error {
	panic("unimplemented") //nolint
}

// TestLocalAddrAndRemoteAddr checks that LocalAddr and RemoteAddr are
// forwarded to the wrapped connection.
func TestLocalAddrAndRemoteAddr(t *testing.T) {
	c := NewConn(&connAddrMock{})
	al := c.LocalAddr()
	ar := c.RemoteAddr()

	assert.Equal(t, "local_addr", al.String())
	assert.Equal(t, "remote_addr", ar.String())
}

// BenchmarkBase transfers b.N 4096-byte packets over a raw net.Pipe, without
// the Conn wrapper on either side.
func BenchmarkBase(b *testing.B) {
	ca, cb := net.Pipe()
	defer func() {
		_ = ca.Close()
	}()

	data := make([]byte, 4096)
	for i := range data {
		data[i] = byte(i)
	}
	buf := make([]byte, len(data))

	b.SetBytes(int64(len(data)))
	b.ResetTimer()

	go func(n int) {
		for i := 0; i < n; i++ {
			_, _ = cb.Write(data)
		}
		_ = cb.Close()
	}(b.N)

	count := 0
	for {
		n, err := ca.Read(buf)
		if err != nil {
			// io.EOF is the expected end of the stream once the writer closes.
			if !errors.Is(err, io.EOF) {
				b.Fatal(err)
			}

			break
		}
		if n != len(data) {
			b.Errorf("Expected %v, got %v", len(data), n)
		}
		count++
	}
	if count != b.N {
		b.Errorf("Expected %v, got %v", b.N, count)
	}
}

// BenchmarkWrite is BenchmarkBase with the writing side going through
// WriteContext.
func BenchmarkWrite(b *testing.B) {
	ca, cb := net.Pipe()
	defer func() {
		_ = ca.Close()
	}()

	data := make([]byte, 4096)
	for i := range data {
		data[i] = byte(i)
	}
	buf := make([]byte, len(data))

	b.SetBytes(int64(len(data)))
	b.ResetTimer()

	go func(n int) {
		c := NewConn(cb)
		for i := 0; i < n; i++ {
			_, _ = c.WriteContext(context.Background(), data)
		}
		_ = cb.Close()
	}(b.N)

	count := 0
	for {
		n, err := ca.Read(buf)
		if err != nil {
			// io.EOF is the expected end of the stream once the writer closes.
			if !errors.Is(err, io.EOF) {
				b.Fatal(err)
			}

			break
		}
		if n != len(data) {
			b.Errorf("Expected %v, got %v", len(data), n)
		}
		count++
	}
	if count != b.N {
		b.Errorf("Expected %v, got %v", b.N, count)
	}
}

// BenchmarkRead is BenchmarkBase with the reading side going through
// ReadContext.
func BenchmarkRead(b *testing.B) {
	ca, cb := net.Pipe()
	defer func() {
		_ = ca.Close()
	}()

	data := make([]byte, 4096)
	for i := range data {
		data[i] = byte(i)
	}
	buf := make([]byte, len(data))

	b.SetBytes(int64(len(data)))
	b.ResetTimer()

	go func(n int) {
		for i := 0; i < n; i++ {
			_, _ = cb.Write(data)
		}
		_ = cb.Close()
	}(b.N)

	c := NewConn(ca)
	count := 0
	for {
		n, err := c.ReadContext(context.Background(), buf)
		if err != nil {
			// io.EOF is the expected end of the stream once the writer closes.
			if !errors.Is(err, io.EOF) {
				b.Fatal(err)
			}

			break
		}
		if n != len(data) {
			b.Errorf("Expected %v, got %v", len(data), n)
		}
		count++
	}
	if count != b.N {
		b.Errorf("Expected %v, got %v", b.N, count)
	}
}
golang-github-pion-transport-v3-3.0.8/netctx/packetconn.go000066400000000000000000000077601507057301700235740ustar00rootroot00000000000000
// SPDX-FileCopyrightText: 2023 The Pion community
// SPDX-License-Identifier: MIT

package netctx

import (
	"context"
	"io"
	"net"
	"sync"
	"sync/atomic"
	"time"
)

// ReaderFrom is an interface for context controlled packet reader.
type ReaderFrom interface {
	ReadFromContext(context.Context, []byte) (int, net.Addr, error)
}

// WriterTo is an interface for context controlled packet writer.
type WriterTo interface {
	WriteToContext(context.Context, []byte, net.Addr) (int, error)
}

// PacketConn is a wrapper of net.PacketConn using context.Context.
type PacketConn interface {
	ReaderFrom
	WriterTo
	io.Closer
	LocalAddr() net.Addr
	Conn() net.PacketConn
}

// packetConn implements PacketConn on top of a plain net.PacketConn.
type packetConn struct {
	// nextConn is the wrapped packet connection.
	nextConn net.PacketConn
	// closed is closed (once, guarded by closeOnce) when Close is called.
	closed    chan struct{}
	closeOnce sync.Once
	// readMu and writeMu are held for the whole duration of ReadFromContext
	// and WriteToContext respectively.
	readMu  sync.Mutex
	writeMu sync.Mutex
}

// NewPacketConn creates a new PacketConn wrapping the given net.PacketConn.
func NewPacketConn(pconn net.PacketConn) PacketConn {
	p := &packetConn{
		nextConn: pconn,
		closed:   make(chan struct{}),
	}

	return p
}

// ReadFromContext reads a packet from the connection,
// copying the payload into b. It returns the number of
// bytes copied into b and the return address that
// was on the packet.
// It returns the number of bytes read (0 <= n <= len(b))
// and any error encountered. Callers should always process
// the n > 0 bytes returned before considering the error err.
// Unlike net.PacketConn.ReadFrom(), the provided context is
// used to control timeout.
func (p *packetConn) ReadFromContext(ctx context.Context, b []byte) (int, net.Addr, error) { //nolint:cyclop p.readMu.Lock() defer p.readMu.Unlock() select { case <-p.closed: return 0, nil, net.ErrClosed default: } done := make(chan struct{}) var wg sync.WaitGroup var errSetDeadline atomic.Value wg.Add(1) go func() { defer wg.Done() select { case <-ctx.Done(): // context canceled if err := p.nextConn.SetReadDeadline(veryOld); err != nil { errSetDeadline.Store(err) return } <-done if err := p.nextConn.SetReadDeadline(time.Time{}); err != nil { errSetDeadline.Store(err) } case <-done: } }() n, raddr, err := p.nextConn.ReadFrom(b) close(done) wg.Wait() if e := ctx.Err(); e != nil && n == 0 { err = e } if err2, ok := errSetDeadline.Load().(error); ok && err == nil && err2 != nil { err = err2 } return n, raddr, err } // WriteToContext writes a packet with payload p to addr. // Unlike net.PacketConn.WriteTo(), the provided context // is used to control timeout. // On packet-oriented connections, write timeouts are rare. func (p *packetConn) WriteToContext(ctx context.Context, b []byte, raddr net.Addr) (int, error) { //nolint:cyclop p.writeMu.Lock() defer p.writeMu.Unlock() select { case <-p.closed: return 0, ErrClosing default: } done := make(chan struct{}) var wg sync.WaitGroup var errSetDeadline atomic.Value wg.Add(1) go func() { defer wg.Done() select { case <-ctx.Done(): // context canceled if err := p.nextConn.SetWriteDeadline(veryOld); err != nil { errSetDeadline.Store(err) return } <-done if err := p.nextConn.SetWriteDeadline(time.Time{}); err != nil { errSetDeadline.Store(err) } case <-done: } }() n, err := p.nextConn.WriteTo(b, raddr) close(done) wg.Wait() if e := ctx.Err(); e != nil && n == 0 { err = e } if err2, ok := errSetDeadline.Load().(error); ok && err == nil && err2 != nil { err = err2 } return n, err } // Close closes the connection. // Any blocked ReadFromContext or WriteToContext operations will be unblocked // and return errors. 
func (p *packetConn) Close() error {
	// The wrapped connection is closed on every call; its error is returned.
	err := p.nextConn.Close()
	// The closed channel is closed only once, with both mutexes held.
	p.closeOnce.Do(func() {
		p.writeMu.Lock()
		p.readMu.Lock()
		close(p.closed)
		p.readMu.Unlock()
		p.writeMu.Unlock()
	})

	return err
}

// LocalAddr returns the local network address, if known.
func (p *packetConn) LocalAddr() net.Addr {
	return p.nextConn.LocalAddr()
}

// Conn returns the underlying net.PacketConn.
func (p *packetConn) Conn() net.PacketConn {
	return p.nextConn
}
golang-github-pion-transport-v3-3.0.8/netctx/packetconn_test.go000066400000000000000000000156331507057301700246310ustar00rootroot00000000000000
// SPDX-FileCopyrightText: 2023 The Pion community
// SPDX-License-Identifier: MIT

package netctx

import (
	"context"
	"errors"
	"io"
	"net"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
)

// Compile-time check that wrapConn satisfies net.PacketConn.
var _ net.PacketConn = wrapConn{}

// wrapConn adapts a net.Conn to the net.PacketConn interface for tests.
// Addresses are ignored: ReadFrom reports a nil address and WriteTo discards
// its address argument.
type wrapConn struct {
	c net.Conn
}

func (w wrapConn) ReadFrom(p []byte) (int, net.Addr, error) {
	n, err := w.c.Read(p)

	return n, nil, err
}

func (w wrapConn) WriteTo(p []byte, _ net.Addr) (n int, err error) {
	return w.c.Write(p)
}

func (w wrapConn) Close() error {
	return w.c.Close()
}

func (w wrapConn) LocalAddr() net.Addr {
	return w.c.LocalAddr()
}

func (w wrapConn) RemoteAddr() net.Addr {
	return w.c.RemoteAddr()
}

func (w wrapConn) SetDeadline(t time.Time) error {
	return w.c.SetDeadline(t)
}

func (w wrapConn) SetReadDeadline(t time.Time) error {
	return w.c.SetReadDeadline(t)
}

func (w wrapConn) SetWriteDeadline(t time.Time) error {
	return w.c.SetWriteDeadline(t)
}

// pipe returns both ends of a net.Pipe wrapped as net.PacketConn values.
func pipe() (net.PacketConn, net.PacketConn) {
	a, b := net.Pipe()

	return wrapConn{a}, wrapConn{b}
}

// TestReadFrom checks that ReadFromContext delivers the bytes written by the
// peer.
func TestReadFrom(t *testing.T) {
	ca, cb := pipe()
	defer func() {
		_ = ca.Close()
	}()

	data := []byte{0x01, 0x02, 0xFF}
	chErr := make(chan error)

	// The underlying net.Pipe is synchronous, so the write runs concurrently.
	go func() {
		_, err := cb.WriteTo(data, nil)
		chErr <- err
	}()

	c := NewPacketConn(ca)
	b := make([]byte, 100)
	n, _, err := c.ReadFromContext(context.Background(), b)
	assert.NoError(t, err)
	assert.Len(t, data, n, "Wrong data length")
	assert.Equal(t, data, b[:n])
	assert.NoError(t, <-chErr)
}

// TestReadFromTimeout checks that ReadFromContext fails with no data when the
// context times out while the peer never writes.
func TestReadFromTimeout(t *testing.T) {
	ca, _ := pipe()
	defer func() {
		_ = ca.Close()
	}()

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
	defer cancel()

	c := NewPacketConn(ca)
	b := make([]byte, 100)
	n, _, err := c.ReadFromContext(ctx, b)
	assert.Error(t, err)
	assert.Empty(t, n, "Wrong data length")
}

// TestReadFromCancel checks that ReadFromContext fails with no data when the
// context is canceled while the read is pending.
func TestReadFromCancel(t *testing.T) {
	ca, _ := pipe()
	defer func() {
		_ = ca.Close()
	}()

	ctx, cancel := context.WithCancel(context.Background())
	go func() {
		time.Sleep(10 * time.Millisecond)
		cancel()
	}()

	c := NewPacketConn(ca)
	b := make([]byte, 100)
	n, _, err := c.ReadFromContext(ctx, b)
	assert.Error(t, err)
	assert.Empty(t, n, "Wrong data length")
}

// TestReadFromClosed checks that ReadFromContext on a closed PacketConn
// returns net.ErrClosed.
func TestReadFromClosed(t *testing.T) {
	ca, _ := pipe()

	c := NewPacketConn(ca)
	_ = c.Close()

	b := make([]byte, 100)
	n, _, err := c.ReadFromContext(context.Background(), b)
	assert.ErrorIs(t, err, net.ErrClosed)
	assert.Empty(t, n, "Wrong data length")
}

// TestWriteTo checks that data passed to WriteToContext arrives at the peer.
func TestWriteTo(t *testing.T) {
	ca, cb := pipe()
	defer func() {
		_ = ca.Close()
	}()

	chErr := make(chan error)
	chRead := make(chan []byte)

	// The underlying net.Pipe is synchronous, so the read runs concurrently.
	go func() {
		b := make([]byte, 100)
		n, _, err := cb.ReadFrom(b)
		chErr <- err
		chRead <- b[:n]
	}()

	c := NewPacketConn(ca)
	data := []byte{0x01, 0x02, 0xFF}
	n, err := c.WriteToContext(context.Background(), data, nil)
	assert.NoError(t, err)
	assert.Len(t, data, n, "Wrong data length")

	err = <-chErr
	b := <-chRead
	assert.NoError(t, err)
	assert.Equal(t, data, b)
}

// TestWriteToTimeout checks that WriteToContext fails with nothing written
// when the context times out while the peer never reads.
func TestWriteToTimeout(t *testing.T) {
	ca, _ := pipe()
	defer func() {
		_ = ca.Close()
	}()

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
	defer cancel()

	c := NewPacketConn(ca)
	b := make([]byte, 100)
	n, err := c.WriteToContext(ctx, b, nil)
	assert.Error(t, err)
	assert.Empty(t, n, "Wrong data length")
}

// TestWriteToCancel checks that WriteToContext fails with nothing written
// when the context is canceled while the write is pending.
func TestWriteToCancel(t *testing.T) {
	ca, _ := pipe()
	defer func() {
		_ = ca.Close()
	}()

	ctx, cancel := context.WithCancel(context.Background())
	go func() {
		time.Sleep(10 * time.Millisecond)
		cancel()
	}()

	c := NewPacketConn(ca)
	b := make([]byte, 100)
	n, err := c.WriteToContext(ctx, b, nil)
	assert.Error(t, err)
	assert.Empty(t, n, "Wrong data length")
}

// TestWriteToClosed checks that WriteToContext on a closed PacketConn returns
// ErrClosing.
func TestWriteToClosed(t *testing.T) {
	ca, _ := pipe()

	c := NewPacketConn(ca)
	_ = c.Close()

	b := make([]byte, 100)
	n, err := c.WriteToContext(context.Background(), b, nil)
	assert.ErrorIs(t, err, ErrClosing)
	assert.Empty(t, n, "Wrong data length")
}

// packetConnAddrMock is a net.PacketConn whose LocalAddr returns a fixed
// value; every other method panics.
type packetConnAddrMock struct{}

func (*packetConnAddrMock) LocalAddr() net.Addr                    { return stringAddr{"local_net", "local_addr"} }
func (*packetConnAddrMock) ReadFrom([]byte) (int, net.Addr, error) { panic("unimplemented") } //nolint:forbidigo
func (*packetConnAddrMock) WriteTo([]byte, net.Addr) (int, error)  { panic("unimplemented") } //nolint:forbidigo
func (*packetConnAddrMock) Close() error                           { panic("unimplemented") } //nolint:forbidigo
func (*packetConnAddrMock) SetDeadline(_ time.Time) error          { panic("unimplemented") } //nolint:forbidigo
func (*packetConnAddrMock) SetReadDeadline(_ time.Time) error      { panic("unimplemented") } //nolint:forbidigo
func (*packetConnAddrMock) SetWriteDeadline(_ time.Time) error     { panic("unimplemented") } //nolint:forbidigo

// TestPacketConnLocalAddrAndRemoteAddr checks that LocalAddr is forwarded to
// the wrapped connection. Despite the name, only LocalAddr is exercised.
func TestPacketConnLocalAddrAndRemoteAddr(t *testing.T) {
	c := NewPacketConn(&packetConnAddrMock{})
	al := c.LocalAddr()

	assert.Equal(t, "local_addr", al.String())
}

// BenchmarkPacketConnBase transfers b.N 4096-byte packets over the raw pipe,
// without the PacketConn wrapper on either side.
func BenchmarkPacketConnBase(b *testing.B) {
	ca, cb := pipe()
	defer func() {
		_ = ca.Close()
	}()

	data := make([]byte, 4096)
	for i := range data {
		data[i] = byte(i)
	}
	buf := make([]byte, len(data))

	b.SetBytes(int64(len(data)))
	b.ResetTimer()

	go func(n int) {
		for i := 0; i < n; i++ {
			_, _ = cb.WriteTo(data, nil)
		}
		_ = cb.Close()
	}(b.N)

	count := 0
	for {
		n, _, err := ca.ReadFrom(buf)
		if err != nil {
			// io.EOF is the expected end of the stream once the writer closes.
			if !errors.Is(err, io.EOF) {
				b.Fatal(err)
			}

			break
		}
		if n != len(data) {
			b.Errorf("Expected %v, got %v", len(data), n)
		}
		count++
	}
	if count != b.N {
		b.Errorf("Expected %v, got %v", b.N, count)
	}
}

// BenchmarkWriteTo is BenchmarkPacketConnBase with the writing side going
// through WriteToContext.
func BenchmarkWriteTo(b *testing.B) {
	ca, cb := pipe()
	defer func() {
		_ = ca.Close()
	}()

	data := make([]byte, 4096)
	for i := range data {
		data[i] = byte(i)
	}
	buf := make([]byte, len(data))

	b.SetBytes(int64(len(data)))
	b.ResetTimer()

	go func(n int) {
		c := NewPacketConn(cb)
		for i := 0; i < n; i++ {
			_, _ = c.WriteToContext(context.Background(), data, nil)
		}
		_ = cb.Close()
	}(b.N)

	count := 0
	for {
		n, _, err := ca.ReadFrom(buf)
		if err != nil {
			// io.EOF is the expected end of the stream once the writer closes.
			if !errors.Is(err, io.EOF) {
				b.Fatal(err)
			}

			break
		}
		if n != len(data) {
			b.Errorf("Expected %v, got %v", len(data), n)
		}
		count++
	}
	if count != b.N {
		b.Errorf("Expected %v, got %v", b.N, count)
	}
}

// BenchmarkReadFrom is BenchmarkPacketConnBase with the reading side going
// through ReadFromContext.
func BenchmarkReadFrom(b *testing.B) {
	ca, cb := pipe()
	defer func() {
		_ = ca.Close()
	}()

	data := make([]byte, 4096)
	for i := range data {
		data[i] = byte(i)
	}
	buf := make([]byte, len(data))

	b.SetBytes(int64(len(data)))
	b.ResetTimer()

	go func(n int) {
		for i := 0; i < n; i++ {
			_, _ = cb.WriteTo(data, nil)
		}
		_ = cb.Close()
	}(b.N)

	c := NewPacketConn(ca)
	count := 0
	for {
		n, _, err := c.ReadFromContext(context.Background(), buf)
		if err != nil {
			// io.EOF is the expected end of the stream once the writer closes.
			if !errors.Is(err, io.EOF) {
				b.Fatal(err)
			}

			break
		}
		if n != len(data) {
			b.Errorf("Expected %v, got %v", len(data), n)
		}
		count++
	}
	if count != b.N {
		b.Errorf("Expected %v, got %v", b.N, count)
	}
}
golang-github-pion-transport-v3-3.0.8/netctx/pipe.go000066400000000000000000000004021507057301700223660ustar00rootroot00000000000000
// SPDX-FileCopyrightText: 2023 The Pion community
// SPDX-License-Identifier: MIT

package netctx

import (
	"net"
)

// Pipe creates piped pair of Conn.
func Pipe() (Conn, Conn) {
	// Both ends of a net.Pipe are wrapped so each side gets the context-aware API.
	ca, cb := net.Pipe()

	return NewConn(ca), NewConn(cb)
}
golang-github-pion-transport-v3-3.0.8/packetio/000077500000000000000000000000001507057301700214005ustar00rootroot00000000000000golang-github-pion-transport-v3-3.0.8/packetio/buffer.go000066400000000000000000000167021507057301700232060ustar00rootroot00000000000000
// SPDX-FileCopyrightText: 2023 The Pion community
// SPDX-License-Identifier: MIT

// Package packetio provides packet buffer
package packetio

import (
	"errors"
	"io"
	"sync"
	"time"

	"github.com/pion/transport/v3/deadline"
)

// errPacketTooBig is returned by Buffer.Write for packets whose length does
// not fit the two-byte length prefix.
var errPacketTooBig = errors.New("packet too big")

// BufferPacketType allows the Buffer to know which packet protocol is writing.
type BufferPacketType int

const (
	// RTPBufferPacket indicates the Buffer that is handling RTP packets.
	RTPBufferPacket BufferPacketType = 1
	// RTCPBufferPacket indicates the Buffer that is handling RTCP packets.
	RTCPBufferPacket BufferPacketType = 2
)

// Buffer allows writing packets to an intermediate buffer, which can then be read from.
// This is very similar to bytes.Buffer but avoids combining multiple writes into a single read.
type Buffer struct {
	mutex sync.Mutex

	// this is a circular buffer.  If head <= tail, then the useful
	// data is in the interval [head, tail[.  If tail < head, then
	// the useful data is the union of [head, len[ and [0, tail[.
	// In order to avoid ambiguity when head = tail, we always leave
	// an unused byte in the buffer.
	data       []byte
	head, tail int

	// notify wakes up a blocked Read; it has capacity 1 (see NewBuffer).
	notify chan struct{}
	closed bool

	// count is the number of buffered packets; limitCount and limitSize are
	// the limits configured through SetLimitCount and SetLimitSize.
	count                 int
	limitCount, limitSize int

	readDeadline *deadline.Deadline
}

const (
	// minSize is the smallest size grow allocates for the buffer.
	minSize = 2048
	// cutoffSize is the size below which grow doubles the buffer; at or
	// above it the buffer grows by a quarter instead.
	cutoffSize = 128 * 1024
	// maxSize caps the buffer size in grow when no size limit is set or the
	// hard limit is enabled.
	maxSize = 4 * 1024 * 1024
)

// NewBuffer creates a new Buffer.
func NewBuffer() *Buffer {
	return &Buffer{
		notify:       make(chan struct{}, 1),
		readDeadline: deadline.New(),
	}
}

// available returns true if the buffer is large enough to fit a packet
// of the given size, taking overhead into account.
func (b *Buffer) available(size int) bool { available := b.head - b.tail if available <= 0 { available += len(b.data) } // we interpret head=tail as empty, so always keep a byte free if size+2+1 > available { return false } return true } // grow increases the size of the buffer. If it returns nil, then the // buffer has been grown. It returns ErrFull if hits a limit. func (b *Buffer) grow() error { var newSize int if len(b.data) < cutoffSize { newSize = 2 * len(b.data) } else { newSize = 5 * len(b.data) / 4 } if newSize < minSize { newSize = minSize } if (b.limitSize <= 0 || sizeHardLimit) && newSize > maxSize { newSize = maxSize } // one byte slack if b.limitSize > 0 && newSize > b.limitSize+1 { newSize = b.limitSize + 1 } if newSize <= len(b.data) { return ErrFull } newData := make([]byte, newSize) var n int if b.head <= b.tail { // data was contiguous n = copy(newData, b.data[b.head:b.tail]) } else { // data was discontinuous n = copy(newData, b.data[b.head:]) n += copy(newData[n:], b.data[:b.tail]) } b.head = 0 b.tail = n b.data = newData return nil } // Write appends a copy of the packet data to the buffer. // Returns ErrFull if the packet doesn't fit. // // Note that the packet size is limited to 65536 bytes since v0.11.0 due to the internal data structure. 
func (b *Buffer) Write(packet []byte) (int, error) { //nolint:cyclop if len(packet) >= 0x10000 { return 0, errPacketTooBig } b.mutex.Lock() if b.closed { b.mutex.Unlock() return 0, io.ErrClosedPipe } if (b.limitCount > 0 && b.count >= b.limitCount) || (b.limitSize > 0 && b.size()+2+len(packet) > b.limitSize) { b.mutex.Unlock() return 0, ErrFull } // grow the buffer until the packet fits for !b.available(len(packet)) { err := b.grow() if err != nil { b.mutex.Unlock() return 0, err } } // store the length of the packet b.data[b.tail] = uint8(len(packet) >> 8) //nolint:gosec b.tail++ if b.tail >= len(b.data) { b.tail = 0 } b.data[b.tail] = uint8(len(packet)) //nolint:gosec b.tail++ if b.tail >= len(b.data) { b.tail = 0 } // store the packet n := copy(b.data[b.tail:], packet) b.tail += n if b.tail >= len(b.data) { // we reached the end, wrap around m := copy(b.data, packet[n:]) b.tail = m } b.count++ select { case b.notify <- struct{}{}: default: } b.mutex.Unlock() return len(packet), nil } // Read populates the given byte slice, returning the number of bytes read. // Blocks until data is available or the buffer is closed. // Returns io.ErrShortBuffer is the packet is too small to copy the Write. // Returns io.EOF if the buffer is closed. func (b *Buffer) Read(packet []byte) (n int, err error) { //nolint:gocognit,cyclop // Return immediately if the deadline is already exceeded. 
select { case <-b.readDeadline.Done(): return 0, &netError{ErrTimeout, true, true} default: } for { b.mutex.Lock() if b.head != b.tail { //nolint:nestif // decode the packet size n1 := b.data[b.head] b.head++ if b.head >= len(b.data) { b.head = 0 } n2 := b.data[b.head] b.head++ if b.head >= len(b.data) { b.head = 0 } count := int((uint16(n1) << 8) | uint16(n2)) // determine the number of bytes we'll actually copy copied := count if copied > len(packet) { copied = len(packet) } // copy the data if b.head+copied < len(b.data) { copy(packet, b.data[b.head:b.head+copied]) } else { k := copy(packet, b.data[b.head:]) copy(packet[k:], b.data[:copied-k]) } // advance head, discarding any data that wasn't copied b.head += count if b.head >= len(b.data) { b.head -= len(b.data) } if b.head == b.tail { // the buffer is empty, reset to beginning // in order to improve cache locality. b.head = 0 b.tail = 0 } b.count-- b.mutex.Unlock() if copied < count { return copied, io.ErrShortBuffer } return copied, nil } if b.closed { b.mutex.Unlock() return 0, io.EOF } b.mutex.Unlock() select { case <-b.readDeadline.Done(): return 0, &netError{ErrTimeout, true, true} case <-b.notify: } } } // Close the buffer, unblocking any pending reads. // Data in the buffer can still be read, Read will return io.EOF only when empty. func (b *Buffer) Close() (err error) { b.mutex.Lock() if b.closed { b.mutex.Unlock() return nil } b.closed = true close(b.notify) b.mutex.Unlock() return nil } // Count returns the number of packets in the buffer. func (b *Buffer) Count() int { b.mutex.Lock() defer b.mutex.Unlock() return b.count } // SetLimitCount controls the maximum number of packets that can be buffered. // Causes Write to return ErrFull when this limit is reached. // A zero value will disable this limit. 
func (b *Buffer) SetLimitCount(limit int) { b.mutex.Lock() defer b.mutex.Unlock() b.limitCount = limit } // Size returns the total byte size of packets in the buffer, including // a small amount of administrative overhead. func (b *Buffer) Size() int { b.mutex.Lock() defer b.mutex.Unlock() return b.size() } func (b *Buffer) size() int { size := b.tail - b.head if size < 0 { size += len(b.data) } return size } // SetLimitSize controls the maximum number of bytes that can be buffered. // Causes Write to return ErrFull when this limit is reached. // A zero value means 4MB since v0.11.0. // // User can set packetioSizeHardLimit build tag to enable 4MB hard limit. // When packetioSizeHardLimit build tag is set, SetLimitSize exceeding // the hard limit will be silently discarded. func (b *Buffer) SetLimitSize(limit int) { b.mutex.Lock() defer b.mutex.Unlock() b.limitSize = limit } // SetReadDeadline sets the deadline for the Read operation. // Setting to zero means no deadline. func (b *Buffer) SetReadDeadline(t time.Time) error { b.readDeadline.Set(t) return nil } golang-github-pion-transport-v3-3.0.8/packetio/buffer_test.go000066400000000000000000000355231507057301700242470ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package packetio import ( "errors" "fmt" "io" "net" "sync/atomic" "testing" "time" "github.com/pion/transport/v3/test" "github.com/stretchr/testify/assert" ) func TestBuffer(t *testing.T) { assert := assert.New(t) buffer := NewBuffer() packet := make([]byte, 4) // Write once n, err := buffer.Write([]byte{0, 1}) assert.NoError(err) assert.Equal(2, n) // Read once n, err = buffer.Read(packet) assert.NoError(err) assert.Equal(2, n) assert.Equal([]byte{0, 1}, packet[:n]) // Read deadline err = buffer.SetReadDeadline(time.Unix(0, 1)) assert.NoError(err) n, err = buffer.Read(packet) var e net.Error assert.ErrorAs(err, &e) assert.True(e.Timeout()) assert.Equal(0, n) // Reset deadline err = 
buffer.SetReadDeadline(time.Time{}) assert.NoError(err) // Write twice n, err = buffer.Write([]byte{2, 3, 4}) assert.NoError(err) assert.Equal(3, n) n, err = buffer.Write([]byte{5, 6, 7}) assert.NoError(err) assert.Equal(3, n) // Read twice n, err = buffer.Read(packet) assert.NoError(err) assert.Equal(3, n) assert.Equal([]byte{2, 3, 4}, packet[:n]) n, err = buffer.Read(packet) assert.NoError(err) assert.Equal(3, n) assert.Equal([]byte{5, 6, 7}, packet[:n]) // Write once prior to close. _, err = buffer.Write([]byte{3}) assert.NoError(err) // Close err = buffer.Close() assert.NoError(err) // Future writes will error _, err = buffer.Write([]byte{4}) assert.Error(err) // But we can read the remaining data. n, err = buffer.Read(packet) assert.NoError(err) assert.Equal(1, n) assert.Equal([]byte{3}, packet[:n]) // Until EOF _, err = buffer.Read(packet) assert.Equal(io.EOF, err) } func testWraparound(t *testing.T, grow bool) { t.Helper() assert := assert.New(t) buffer := NewBuffer() err := buffer.grow() assert.NoError(err) buffer.head = len(buffer.data) - 13 buffer.tail = buffer.head p1 := []byte{1, 2, 3} p2 := []byte{4, 5, 6} p3 := []byte{7, 8, 9} p4 := []byte{10, 11, 12} _, err = buffer.Write(p1) assert.NoError(err) _, err = buffer.Write(p2) assert.NoError(err) _, err = buffer.Write(p3) assert.NoError(err) packet := make([]byte, 10) n, err := buffer.Read(packet) assert.NoError(err) assert.Equal(p1, packet[:n]) if grow { err = buffer.grow() assert.NoError(err) } n, err = buffer.Read(packet) assert.NoError(err) assert.Equal(p2, packet[:n]) _, err = buffer.Write(p4) assert.NoError(err) n, err = buffer.Read(packet) assert.NoError(err) assert.Equal(p3, packet[:n]) n, err = buffer.Read(packet) assert.NoError(err) assert.Equal(p4, packet[:n]) if !grow { assert.Equal(len(buffer.data), minSize) } else { assert.Equal(len(buffer.data), 2*minSize) } } func TestBufferWraparound(t *testing.T) { testWraparound(t, false) } func TestBufferWraparoundGrow(t *testing.T) { testWraparound(t, 
true) } func TestBufferAsync(t *testing.T) { assert := assert.New(t) buffer := NewBuffer() // Start up a goroutine to start a blocking read. done := make(chan struct{}) go func() { packet := make([]byte, 4) n, err := buffer.Read(packet) assert.NoError(err) assert.Equal(2, n) assert.Equal([]byte{0, 1}, packet[:n]) _, err = buffer.Read(packet) assert.Equal(io.EOF, err) close(done) }() // Wait for the reader to start reading. time.Sleep(time.Millisecond) // Write once n, err := buffer.Write([]byte{0, 1}) assert.NoError(err) assert.Equal(2, n) // Wait for the reader to start reading again. time.Sleep(time.Millisecond) // Close will unblock the reader. err = buffer.Close() assert.NoError(err) <-done } func TestBufferLimitCount(t *testing.T) { assert := assert.New(t) buffer := NewBuffer() buffer.SetLimitCount(2) assert.Equal(0, buffer.Count()) // Write twice n, err := buffer.Write([]byte{0, 1}) assert.NoError(err) assert.Equal(2, n) assert.Equal(1, buffer.Count()) n, err = buffer.Write([]byte{2, 3}) assert.NoError(err) assert.Equal(2, n) assert.Equal(2, buffer.Count()) // Over capacity _, err = buffer.Write([]byte{4, 5}) assert.Equal(ErrFull, err) assert.Equal(2, buffer.Count()) // Read once packet := make([]byte, 4) n, err = buffer.Read(packet) assert.NoError(err) assert.Equal(2, n) assert.Equal([]byte{0, 1}, packet[:n]) assert.Equal(1, buffer.Count()) // Write once n, err = buffer.Write([]byte{6, 7}) assert.NoError(err) assert.Equal(2, n) assert.Equal(2, buffer.Count()) // Over capacity _, err = buffer.Write([]byte{8, 9}) assert.Equal(ErrFull, err) assert.Equal(2, buffer.Count()) // Read twice n, err = buffer.Read(packet) assert.NoError(err) assert.Equal(2, n) assert.Equal([]byte{2, 3}, packet[:n]) assert.Equal(1, buffer.Count()) n, err = buffer.Read(packet) assert.NoError(err) assert.Equal(2, n) assert.Equal([]byte{6, 7}, packet[:n]) assert.Equal(0, buffer.Count()) // Nothing left. 
err = buffer.Close() assert.NoError(err) } func TestBufferLimitSize(t *testing.T) { assert := assert.New(t) buffer := NewBuffer() buffer.SetLimitSize(11) assert.Equal(0, buffer.Size()) // Write twice n, err := buffer.Write([]byte{0, 1}) assert.NoError(err) assert.Equal(2, n) assert.Equal(4, buffer.Size()) n, err = buffer.Write([]byte{2, 3}) assert.NoError(err) assert.Equal(2, n) assert.Equal(8, buffer.Size()) // Over capacity _, err = buffer.Write([]byte{4, 5}) assert.Equal(ErrFull, err) assert.Equal(8, buffer.Size()) // Cheeky write at exact size. n, err = buffer.Write([]byte{6}) assert.NoError(err) assert.Equal(1, n) assert.Equal(11, buffer.Size()) // Read once packet := make([]byte, 4) n, err = buffer.Read(packet) assert.NoError(err) assert.Equal(2, n) assert.Equal([]byte{0, 1}, packet[:n]) assert.Equal(7, buffer.Size()) // Write once n, err = buffer.Write([]byte{7, 8}) assert.NoError(err) assert.Equal(2, n) assert.Equal(11, buffer.Size()) // Over capacity _, err = buffer.Write([]byte{9, 10}) assert.Equal(ErrFull, err) assert.Equal(11, buffer.Size()) // Read everything n, err = buffer.Read(packet) assert.NoError(err) assert.Equal(2, n) assert.Equal([]byte{2, 3}, packet[:n]) assert.Equal(7, buffer.Size()) n, err = buffer.Read(packet) assert.NoError(err) assert.Equal(1, n) assert.Equal([]byte{6}, packet[:n]) assert.Equal(4, buffer.Size()) n, err = buffer.Read(packet) assert.NoError(err) assert.Equal(2, n) assert.Equal([]byte{7, 8}, packet[:n]) assert.Equal(0, buffer.Size()) // Nothing left. 
err = buffer.Close() assert.NoError(err) } func TestBufferLimitSizes(t *testing.T) { if sizeHardLimit { t.Skip("skipping since packetioSizeHardLimit is enabled") } sizes := []int{ 128 * 1024, 1024 * 1024, 8 * 1024 * 1024, 0, // default } const headerSize = 2 const packetSize = 0x8000 for _, size := range sizes { size := size name := "default" if size > 0 { name = fmt.Sprintf("%dkBytes", size/1024) } t.Run(name, func(t *testing.T) { assert := assert.New(t) buffer := NewBuffer() if size == 0 { size = maxSize } else { buffer.SetLimitSize(size + headerSize) } now := time.Now() assert.NoError(buffer.SetReadDeadline(now.Add(5 * time.Second))) // Set deadline to avoid test deadlock nPackets := size / (packetSize + headerSize) for i := 0; i < nPackets; i++ { _, err := buffer.Write(make([]byte, packetSize)) assert.NoError(err) } // Next write is expected to be errored. _, err := buffer.Write(make([]byte, packetSize)) assert.Error(err, ErrFull) packet := make([]byte, size) for i := 0; i < nPackets; i++ { n, err := buffer.Read(packet) assert.NoError(err) assert.Equal(packetSize, n) if err != nil { assert.FailNow("Read failed", err) } } }) } } func TestBufferMisc(t *testing.T) { assert := assert.New(t) buffer := NewBuffer() // Write once n, err := buffer.Write([]byte{0, 1, 2, 3}) assert.NoError(err) assert.Equal(4, n) // Try to read with a short buffer packet := make([]byte, 3) _, err = buffer.Read(packet) assert.Equal(io.ErrShortBuffer, err) // Close err = buffer.Close() assert.NoError(err) // Make sure you can Close twice err = buffer.Close() assert.NoError(err) } var errTooManyCallOfGetBuffer = errors.New("too many call of getBuffer") func TestBufferAlloc(t *testing.T) { packet := make([]byte, 1024) const countTolerance = 1 test := func(fn func(func() (*Buffer, error), int, *error) func(), count int, maxVal float64) func(t *testing.T) { return func(t *testing.T) { t.Helper() const nRuns = 100 // Create buffers in advance to avoid measuring allocs in NewBuffer() // +1 buffer 
for warm-up run buffers := make([]*Buffer, 0, nRuns+1) for i := 0; i < nRuns+1; i++ { buffers = append(buffers, NewBuffer()) } var iBuffer int getBuffer := func() (*Buffer, error) { if iBuffer >= len(buffers) { return nil, errTooManyCallOfGetBuffer } ret := buffers[iBuffer] iBuffer++ return ret, nil } var err error // AllocsPerRun calls the func once as a warm-up and then call it specified times allocs := testing.AllocsPerRun(nRuns, fn(getBuffer, count, &err)) assert.NoError(t, err) assert.LessOrEqualf(t, allocs, maxVal+countTolerance, "count=%d, max=%f+%d, got %f", count, maxVal, countTolerance, allocs) } } // Write (1024+2)*count bytes writer := func(getBuffer func() (*Buffer, error), count int, errOut *error) func() { return func() { // Call only buffer.Write() on the non-error paths to avoid wrong count of allocs buffer, err := getBuffer() // getBuffer doesn't alloc if err != nil { *errOut = err return } for i := 0; i < count; i++ { if _, err := buffer.Write(packet); err != nil { *errOut = fmt.Errorf("write: %w", err) return } } } } // Buffer size will be grown as // 2048 -> 4096 -> 8192 -> 16384 -> 32768 -> 65536 -> 131072 -> 163840 -> 204800 // -> 256000 -> 320000 -> 400000 -> 500000 -> 625000 -> 781250 -> 976562 -> 1220702 // based on the logic in Buffer.grow() t.Run("10 writes", test(writer, 10, 4)) // 10260 bytes t.Run("100 writes", test(writer, 100, 7)) // 102600 bytes t.Run("200 writes", test(writer, 200, 10)) // 205200 bytes t.Run("400 writes", test(writer, 400, 13)) // 410400 bytes t.Run("1000 writes", test(writer, 1000, 17)) // 1026000 bytes // Read and write same times, so the buffer size should never grow wr := func(getBuffer func() (*Buffer, error), count int, errOut *error) func() { return func() { // Call only buffer.Write() on the non-error paths to avoid wrong count of allocs buffer, err := getBuffer() // getBuffer doesn't alloc if err != nil { *errOut = err return } for i := 0; i < count; i++ { if _, err := buffer.Write(packet); err != nil { 
*errOut = fmt.Errorf("write: %w", err) return } if _, err := buffer.Read(packet); err != nil { *errOut = fmt.Errorf("read: %w", err) return } } } } t.Run("10 writes and reads", test(wr, 10, 1)) t.Run("100 writes and reads", test(wr, 100, 1)) t.Run("1000 writes and reads", test(wr, 1000, 1)) t.Run("10000 writes and reads", test(wr, 10000, 1)) } func benchmarkBufferWR(b *testing.B, size int64, write bool, grow int) { // nolint:unparam b.Helper() buffer := NewBuffer() packet := make([]byte, size) // Grow the buffer first pad := make([]byte, 1022) for buffer.Size() < grow { _, err := buffer.Write(pad) if err != nil { b.Fatalf("Write: %v", err) } } for buffer.Size() > 0 { _, err := buffer.Read(pad) if err != nil { b.Fatalf("Write: %v", err) } } if write { _, err := buffer.Write(packet) if err != nil { b.Fatalf("Write: %v", err) } } b.SetBytes(size) b.ResetTimer() for i := 0; i < b.N; i++ { _, err := buffer.Write(packet) if err != nil { b.Fatalf("Write: %v", err) } _, err = buffer.Read(packet) if err != nil { b.Fatalf("Read: %v", err) } } } // In this benchmark, the buffer is often empty, which is hopefully // typical of real usage. func BenchmarkBufferWR14(b *testing.B) { benchmarkBufferWR(b, 14, false, 128000) } func BenchmarkBufferWR140(b *testing.B) { benchmarkBufferWR(b, 140, false, 128000) } func BenchmarkBufferWR1400(b *testing.B) { benchmarkBufferWR(b, 1400, false, 128000) } // Here, the buffer never becomes empty, which forces wraparound. 
func BenchmarkBufferWWR14(b *testing.B) { benchmarkBufferWR(b, 14, true, 128000) } func BenchmarkBufferWWR140(b *testing.B) { benchmarkBufferWR(b, 140, true, 128000) } func BenchmarkBufferWWR1400(b *testing.B) { benchmarkBufferWR(b, 1400, true, 128000) } func benchmarkBuffer(b *testing.B, size int64) { b.Helper() buffer := NewBuffer() b.SetBytes(size) done := make(chan struct{}) go func() { packet := make([]byte, size) for { _, err := buffer.Read(packet) if errors.Is(err, io.EOF) { break } else if err != nil { b.Error(err) break } } close(done) }() packet := make([]byte, size) b.ResetTimer() for i := 0; i < b.N; i++ { var err error for { _, err = buffer.Write(packet) if !errors.Is(err, ErrFull) { break } time.Sleep(time.Microsecond) } if err != nil { b.Fatal(err) } } err := buffer.Close() if err != nil { b.Fatal(err) } <-done } func BenchmarkBuffer14(b *testing.B) { benchmarkBuffer(b, 14) } func BenchmarkBuffer140(b *testing.B) { benchmarkBuffer(b, 140) } func BenchmarkBuffer1400(b *testing.B) { benchmarkBuffer(b, 1400) } func TestBufferConcurrentRead(t *testing.T) { assert := assert.New(t) buffer := NewBuffer() packet := make([]byte, 4) // Write twice n, err := buffer.Write([]byte{2, 3, 4}) assert.NoError(err) assert.Equal(3, n) n, err = buffer.Write([]byte{5, 6, 7}) assert.NoError(err) assert.Equal(3, n) // Read twice n, err = buffer.Read(packet) assert.NoError(err) assert.Equal(3, n) assert.Equal([]byte{2, 3, 4}, packet[:n]) n, err = buffer.Read(packet) assert.NoError(err) assert.Equal(3, n) assert.Equal([]byte{5, 6, 7}, packet[:n]) errCh := make(chan error, 2) readIntoErr := func() { packet := make([]byte, 4) _, readErr := buffer.Read(packet) errCh <- readErr } go readIntoErr() go readIntoErr() // Close err = buffer.Close() assert.NoError(err) err = <-errCh assert.Equal(io.EOF, err) err = <-errCh assert.Equal(io.EOF, err) } func TestBufferConcurrentReadWrite(t *testing.T) { defer test.TimeOut(time.Second * 5).Stop() assert := assert.New(t) buffer := NewBuffer() 
// netError implements net.Error.
//
// It wraps an underlying error (embedded, so Error() is promoted from it) and
// carries the two flags reported by Timeout and Temporary.
type netError struct {
	error
	timeout, temporary bool
}

// Timeout reports the timeout flag the error was constructed with.
func (e *netError) Timeout() bool { return e.timeout }

// Temporary reports the temporary flag the error was constructed with.
func (e *netError) Temporary() bool { return e.temporary }
// fixedBigInt is the fix-sized multi-word integer.
//
// bits holds the value in little-endian word order (bits[0] is the least
// significant word), n is the width in bits and msbMask selects the bits of
// the most significant word that lie inside that width.
type fixedBigInt struct {
	bits    []uint64
	n       uint
	msbMask uint64
}

// newFixedBigInt creates a new fix-sized multi-word int that is n bits wide.
// At least one word is always allocated, even for n == 0.
func newFixedBigInt(n uint) *fixedBigInt {
	chunkSize := (n + 63) / 64
	if chunkSize == 0 {
		chunkSize = 1
	}

	// The most significant word holds n%64 valid bits, or a full word when n
	// is a multiple of 64. Deriving the mask from 64-n%64 instead would keep
	// too few bits whenever n%64 > 32 and silently clear valid bits on Lsh.
	msbMask := ^uint64(0)
	if rem := n % 64; rem != 0 {
		msbMask = (uint64(1) << rem) - 1
	}

	return &fixedBigInt{
		bits:    make([]uint64, chunkSize),
		n:       n,
		msbMask: msbMask,
	}
}

// Lsh is the left shift operation.
// Bits shifted beyond the width of the integer are discarded.
func (s *fixedBigInt) Lsh(n uint) { //nolint:varnamelen
	if n == 0 {
		return
	}
	nChunk := int(n / 64) //nolint:gosec
	nN := n % 64

	// Walk from the most significant word down so that every source word is
	// read before it is overwritten.
	for i := len(s.bits) - 1; i >= 0; i-- {
		var carry uint64
		if i-nChunk >= 0 {
			carry = s.bits[i-nChunk] << nN
			if i-nChunk-1 >= 0 {
				// For nN == 0 this shifts by 64, which yields 0 in Go.
				carry |= s.bits[i-nChunk-1] >> (64 - nN)
			}
		}
		// carry already is the complete new word: for n < 64 it contains
		// s.bits[i]<<n, and for n >= 64 that term would be zero.
		s.bits[i] = carry
	}
	s.bits[len(s.bits)-1] &= s.msbMask
}

// Bit returns i-th bit of the fixedBigInt.
// Positions at or beyond the width read as 0.
func (s *fixedBigInt) Bit(i uint) uint {
	if i >= s.n {
		return 0
	}
	chunk := i / 64
	pos := i % 64
	if s.bits[chunk]&(1<<pos) != 0 {
		return 1
	}

	return 0
}

// SetBit sets i-th bit to 1.
// Positions at or beyond the width are ignored.
func (s *fixedBigInt) SetBit(i uint) {
	if i >= s.n {
		return
	}
	chunk := i / 64
	pos := i % 64
	s.bits[chunk] |= 1 << pos
}

// String returns string representation of fixedBigInt: every word as 16
// upper-case hex digits, most significant word first.
func (s *fixedBigInt) String() string {
	var out string
	for i := len(s.bits) - 1; i >= 0; i-- {
		out += fmt.Sprintf("%016X", s.bits[i])
	}

	return out
}
0000000000000000000000000000000000000000000100000000000040200000 // 0000000000000000000000000000000000000000001000000000000402000000 // 0000000000000000000000000000000400000000001000000000000402000000 // 0000000000000004000000000010000000000004020000000000000000000000 // 0000000000000004000000000010000000000004020000000000000000000080 // 0000000004000000000000000000010000000000000000000000000000000000 // 00000000FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF } golang-github-pion-transport-v3-3.0.8/replaydetector/replaydetector.go000066400000000000000000000065561507057301700262200ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT // Package replaydetector provides packet replay detection algorithm. package replaydetector // ReplayDetector is the interface of sequence replay detector. type ReplayDetector interface { // Check returns true if given sequence number is not replayed. // Call accept() to mark the packet is received properly. // The return value of accept() indicates whether the accepted packet is // has the latest observed sequence number. Check(seq uint64) (accept func() bool, ok bool) } // nop is a no-op func that is returned in the case that Check() fails. func nop() bool { return false } type slidingWindowDetector struct { latestSeq uint64 maxSeq uint64 windowSize uint mask *fixedBigInt } // New creates ReplayDetector. // Created ReplayDetector doesn't allow wrapping. // It can handle monotonically increasing sequence number up to // full 64bit number. It is suitable for DTLS replay protection. func New(windowSize uint, maxSeq uint64) ReplayDetector { return &slidingWindowDetector{ maxSeq: maxSeq, windowSize: windowSize, mask: newFixedBigInt(windowSize), } } func (d *slidingWindowDetector) Check(seq uint64) (func() bool, bool) { if seq > d.maxSeq { // Exceeded upper limit. 
return nop, false } if seq <= d.latestSeq { if d.latestSeq >= uint64(d.windowSize)+seq { return nop, false } if d.mask.Bit(uint(d.latestSeq-seq)) != 0 { // The sequence number is duplicated. return nop, false } } return func() bool { latest := seq == 0 if seq > d.latestSeq { // Update the head of the window. d.mask.Lsh(uint(seq - d.latestSeq)) d.latestSeq = seq latest = true } diff := (d.latestSeq - seq) % d.maxSeq d.mask.SetBit(uint(diff)) return latest }, true } // WithWrap creates ReplayDetector allowing sequence wrapping. // This is suitable for short bit width counter like SRTP and SRTCP. func WithWrap(windowSize uint, maxSeq uint64) ReplayDetector { return &wrappedSlidingWindowDetector{ maxSeq: maxSeq, windowSize: windowSize, mask: newFixedBigInt(windowSize), } } type wrappedSlidingWindowDetector struct { latestSeq uint64 maxSeq uint64 windowSize uint mask *fixedBigInt init bool } func (d *wrappedSlidingWindowDetector) Check(seq uint64) (func() bool, bool) { if seq > d.maxSeq { // Exceeded upper limit. return nop, false } if !d.init { if seq != 0 { d.latestSeq = seq - 1 } else { d.latestSeq = d.maxSeq } d.init = true } diff := int64(d.latestSeq) - int64(seq) //nolint:gosec // GG115 TODO check // Wrap the number. if diff > int64(d.maxSeq)/2 { //nolint:gosec // GG115 TODO check diff -= int64(d.maxSeq + 1) //nolint:gosec // GG115 TODO check } else if diff <= -int64(d.maxSeq)/2 { //nolint:gosec // GG115 TODO check diff += int64(d.maxSeq + 1) //nolint:gosec // GG115 TODO check } if diff >= int64(d.windowSize) { //nolint:gosec // GG115 TODO check // Too old. return nop, false } if diff >= 0 { if d.mask.Bit(uint(diff)) != 0 { // The sequence number is duplicated. return nop, false } } return func() bool { latest := false if diff < 0 { // Update the head of the window. 
d.mask.Lsh(uint(-diff)) d.latestSeq = seq latest = true d.mask.SetBit(0) } else { d.mask.SetBit(uint(diff)) } return latest }, true } golang-github-pion-transport-v3-3.0.8/replaydetector/replaydetector_test.go000066400000000000000000000200461507057301700272450ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package replaydetector import ( "testing" "github.com/stretchr/testify/assert" ) type testCase struct { windowSize uint maxSeq uint64 input []uint64 valid []bool latest []bool expected []uint64 } const ( largeSeq = 0x100000000000 hugeSeq = 0x1000000000000 ) var commonCases = map[string]testCase{ //nolint:gochecknoglobals "Continuous": { 16, 0x0000FFFFFFFFFFFF, []uint64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}, []bool{ true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, }, []bool{ true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, }, []uint64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}, }, "ValidLargeJump": { 16, 0x0000FFFFFFFFFFFF, []uint64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, largeSeq, 11, largeSeq + 1, largeSeq + 2, largeSeq + 3}, []bool{ true, true, true, true, true, true, true, true, true, true, true, false, true, true, true, }, []bool{ true, true, true, true, true, true, true, true, true, true, true, false, true, true, true, }, []uint64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, largeSeq, largeSeq + 1, largeSeq + 2, largeSeq + 3}, }, "InvalidLargeJump": { 16, 0x0000FFFFFFFFFFFF, []uint64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, hugeSeq, 11, 12, 13, 14, 15}, []bool{ true, true, true, true, true, true, true, true, true, true, false, true, true, true, true, true, }, []bool{ true, true, true, true, true, true, true, true, true, true, false, true, true, true, true, true, }, []uint64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 
13, 14, 15}, }, "DuplicateAfterValidJump": { 196, 0x0000FFFFFFFFFFFF, []uint64{0, 1, 2, 129, 0, 1, 2}, []bool{ true, true, true, true, false, false, false, }, []bool{ true, true, true, true, false, false, false, }, []uint64{0, 1, 2, 129}, }, "DuplicateAfterInvalidJump": { 196, 0x0000FFFFFFFFFFFF, []uint64{0, 1, 2, hugeSeq, 0, 1, 2}, []bool{ true, true, true, false, false, false, false, }, []bool{ true, true, true, false, false, false, false, }, []uint64{0, 1, 2}, }, "ContinuousOffset": { 16, 0x0000FFFFFFFFFFFF, []uint64{100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114}, []bool{ true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, }, []bool{ true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, }, []uint64{100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114}, }, "Reordered": { 128, 0x0000FFFFFFFFFFFF, []uint64{96, 64, 16, 80, 32, 48, 8, 24, 88, 40, 128, 56, 72, 112, 104, 120}, []bool{ true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, }, []bool{ true, false, false, false, false, false, false, false, false, false, true, false, false, false, false, false, }, []uint64{96, 64, 16, 80, 32, 48, 8, 24, 88, 40, 128, 56, 72, 112, 104, 120}, }, "Old": { 100, 0x0000FFFFFFFFFFFF, []uint64{24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128, 8, 16}, []bool{ true, true, true, true, true, true, true, true, true, true, true, true, true, true, false, false, }, []bool{ true, true, true, true, true, true, true, true, true, true, true, true, true, true, false, false, }, []uint64{24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128}, }, "ContinuousReplayed": { 8, 0x0000FFFFFFFFFFFF, []uint64{16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25}, []bool{ true, true, true, true, true, true, true, true, true, true, false, false, false, false, false, false, false, false, false, false, }, 
[]bool{ true, true, true, true, true, true, true, true, true, true, false, false, false, false, false, false, false, false, false, false, }, []uint64{16, 17, 18, 19, 20, 21, 22, 23, 24, 25}, }, "ReplayedLater": { 128, 0x0000FFFFFFFFFFFF, []uint64{16, 32, 48, 64, 80, 96, 112, 128, 16, 32, 48, 64, 80, 96, 112, 128}, []bool{ true, true, true, true, true, true, true, true, false, false, false, false, false, false, false, false, }, []bool{ true, true, true, true, true, true, true, true, false, false, false, false, false, false, false, false, }, []uint64{16, 32, 48, 64, 80, 96, 112, 128}, }, "ReplayedQuick": { 128, 0x0000FFFFFFFFFFFF, []uint64{16, 16, 32, 32, 48, 48, 64, 64, 80, 80, 96, 96, 112, 112, 128, 128}, []bool{ true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, }, []bool{ true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, }, []uint64{16, 32, 48, 64, 80, 96, 112, 128}, }, "Strict": { 0, 0x0000FFFFFFFFFFFF, []uint64{1, 3, 2, 4, 5, 6, 7, 8, 9, 10}, []bool{ true, true, false, true, true, true, true, true, true, true, }, []bool{ true, true, false, true, true, true, true, true, true, true, }, []uint64{1, 3, 4, 5, 6, 7, 8, 9, 10}, }, "Overflow": { 128, 0x0000FFFFFFFFFFFF, []uint64{0x0000FFFFFFFFFFFE, 0x0000FFFFFFFFFFFF, 0x0001000000000000, 0x0001000000000001}, []bool{ true, true, false, false, }, []bool{ true, true, false, false, }, []uint64{0x0000FFFFFFFFFFFE, 0x0000FFFFFFFFFFFF}, }, } func TestReplayDetector(t *testing.T) { for name, testCase := range commonCases { testCase := testCase t.Run(name, func(t *testing.T) { det := New(testCase.windowSize, testCase.maxSeq) var out []uint64 for i, seq := range testCase.input { accept, ok := det.Check(seq) assert.Equal(t, testCase.valid[i], ok, "Unexpected validity") if ok { out = append(out, seq) } assert.Equal(t, testCase.latest[i], accept(), "Unexpected sequence latest status") } assert.Equal(t, testCase.expected, out, 
"Wrong replay detection result") }) } } func TestReplayDetectorWrapped(t *testing.T) { cases := map[string]testCase{ "WrapContinuous": { 64, 0xFFFF, []uint64{0xFFFC, 0xFFFD, 0xFFFE, 0xFFFF, 0x0000, 0x0001, 0x0002, 0x0003}, []bool{ true, true, true, true, true, true, true, true, }, []bool{ true, true, true, true, true, true, true, true, }, []uint64{0xFFFC, 0xFFFD, 0xFFFE, 0xFFFF, 0x0000, 0x0001, 0x0002, 0x0003}, }, "WrapReordered": { 64, 0xFFFF, []uint64{0xFFFD, 0xFFFC, 0x0002, 0xFFFE, 0x0000, 0x0001, 0xFFFF, 0x0003}, []bool{ true, true, true, true, true, true, true, true, }, []bool{ true, false, true, false, false, false, false, true, }, []uint64{0xFFFD, 0xFFFC, 0x0002, 0xFFFE, 0x0000, 0x0001, 0xFFFF, 0x0003}, }, "WrapReorderedReplayed": { 64, 0xFFFF, []uint64{0xFFFD, 0xFFFC, 0xFFFC, 0x0002, 0xFFFE, 0xFFFC, 0x0000, 0x0001, 0x0001, 0xFFFF, 0x0001, 0x0003}, []bool{ true, true, false, true, true, false, true, true, false, true, false, true, }, []bool{ true, false, false, true, false, false, false, false, false, false, false, true, }, []uint64{0xFFFD, 0xFFFC, 0x0002, 0xFFFE, 0x0000, 0x0001, 0xFFFF, 0x0003}, }, "BeforeWrapReplayed": { 64, 0xFFFF, []uint64{0x0, 0xFFFF, 0xFFFF}, []bool{ true, true, false, }, []bool{ true, false, false, }, []uint64{0x0, 0xFFFF}, }, } for name, c := range commonCases { _, ok := cases[name] assert.False(t, ok, "Duplicate test case name: %q", name) cases[name] = c } for name, c := range cases { testCase := c t.Run(name, func(t *testing.T) { det := WithWrap(testCase.windowSize, testCase.maxSeq) var out []uint64 for i, seq := range testCase.input { accept, ok := det.Check(seq) assert.Equal(t, testCase.valid[i], ok, "Unexpected validity") if ok { out = append(out, seq) } assert.Equal(t, testCase.latest[i], accept(), "Unexpected sequence latest status") } assert.Equal(t, testCase.expected, out, "Wrong replay detection result") }) } } 
golang-github-pion-transport-v3-3.0.8/stdnet/000077500000000000000000000000001507057301700211025ustar00rootroot00000000000000golang-github-pion-transport-v3-3.0.8/stdnet/net.go000066400000000000000000000103121507057301700222140ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT // Package stdnet implements the transport.Net interface // using methods from Go's standard net package. package stdnet import ( "fmt" "net" "github.com/pion/transport/v3" "github.com/wlynxg/anet" ) const ( lo0String = "lo0String" udpString = "udp" ) // Net is an implementation of the net.Net interface // based on functions of the standard net package. type Net struct { interfaces []*transport.Interface } // NewNet creates a new StdNet instance. func NewNet() (*Net, error) { n := &Net{} return n, n.UpdateInterfaces() } // Compile-time assertion. var _ transport.Net = &Net{} // UpdateInterfaces updates the internal list of network interfaces // and associated addresses. func (n *Net) UpdateInterfaces() error { ifs := []*transport.Interface{} oifs, err := anet.Interfaces() if err != nil { return err } for i := range oifs { ifc := transport.NewInterface(oifs[i]) addrs, err := anet.InterfaceAddrsByInterface(&oifs[i]) if err != nil { return err } for _, addr := range addrs { ifc.AddAddress(addr) } ifs = append(ifs, ifc) } n.interfaces = ifs return nil } // Interfaces returns a slice of interfaces which are available on the // system. func (n *Net) Interfaces() ([]*transport.Interface, error) { return n.interfaces, nil } // InterfaceByIndex returns the interface specified by index. // // On Solaris, it returns one of the logical network interfaces // sharing the logical data link; for more precision use // InterfaceByName. 
// InterfaceByIndex returns the interface specified by index.
//
// On Solaris, it returns one of the logical network interfaces
// sharing the logical data link; for more precision use
// InterfaceByName.
//
// The lookup is done on the list captured by the last UpdateInterfaces call.
// If no interface matches, the error wraps transport.ErrInterfaceNotFound.
func (n *Net) InterfaceByIndex(index int) (*transport.Interface, error) {
	for _, ifc := range n.interfaces {
		if ifc.Index == index {
			return ifc, nil
		}
	}

	return nil, fmt.Errorf("%w: index=%d", transport.ErrInterfaceNotFound, index)
}

// InterfaceByName returns the interface specified by name.
// The lookup is done on the list captured by the last UpdateInterfaces call.
// If no interface matches, the error wraps transport.ErrInterfaceNotFound.
func (n *Net) InterfaceByName(name string) (*transport.Interface, error) {
	for _, ifc := range n.interfaces {
		if ifc.Name == name {
			return ifc, nil
		}
	}

	return nil, fmt.Errorf("%w: %s", transport.ErrInterfaceNotFound, name)
}

// ListenPacket announces on the local network address.
// It delegates to net.ListenPacket.
func (n *Net) ListenPacket(network string, address string) (net.PacketConn, error) {
	return net.ListenPacket(network, address) //nolint: noctx
}

// ListenUDP acts like ListenPacket for UDP networks.
// It delegates to net.ListenUDP.
func (n *Net) ListenUDP(network string, locAddr *net.UDPAddr) (transport.UDPConn, error) {
	return net.ListenUDP(network, locAddr)
}

// Dial connects to the address on the named network.
// It delegates to net.Dial.
func (n *Net) Dial(network, address string) (net.Conn, error) {
	return net.Dial(network, address) //nolint: noctx
}

// DialUDP acts like Dial for UDP networks.
// It delegates to net.DialUDP.
func (n *Net) DialUDP(network string, laddr, raddr *net.UDPAddr) (transport.UDPConn, error) {
	return net.DialUDP(network, laddr, raddr)
}

// ResolveIPAddr returns an address of IP end point.
// It delegates to net.ResolveIPAddr.
func (n *Net) ResolveIPAddr(network, address string) (*net.IPAddr, error) {
	return net.ResolveIPAddr(network, address)
}

// ResolveUDPAddr returns an address of UDP end point.
// It delegates to net.ResolveUDPAddr.
func (n *Net) ResolveUDPAddr(network, address string) (*net.UDPAddr, error) {
	return net.ResolveUDPAddr(network, address)
}

// ResolveTCPAddr returns an address of TCP end point.
// It delegates to net.ResolveTCPAddr.
func (n *Net) ResolveTCPAddr(network, address string) (*net.TCPAddr, error) {
	return net.ResolveTCPAddr(network, address)
}

// DialTCP acts like Dial for TCP networks.
// It delegates to net.DialTCP.
func (n *Net) DialTCP(network string, laddr, raddr *net.TCPAddr) (transport.TCPConn, error) {
	return net.DialTCP(network, laddr, raddr)
}

// ListenTCP acts like Listen for TCP networks.
// The listener from net.ListenTCP is wrapped in tcpListener so that it can be
// returned as a transport.TCPListener.
func (n *Net) ListenTCP(network string, laddr *net.TCPAddr) (transport.TCPListener, error) {
	l, err := net.ListenTCP(network, laddr)
	if err != nil {
		return nil, err
	}

	return tcpListener{l}, nil
}

// tcpListener wraps *net.TCPListener so that AcceptTCP returns the
// transport.TCPConn interface type instead of *net.TCPConn.
type tcpListener struct {
	*net.TCPListener
}

// AcceptTCP accepts the next connection from the embedded listener.
func (l tcpListener) AcceptTCP() (transport.TCPConn, error) {
	return l.TCPListener.AcceptTCP()
}

// stdDialer wraps *net.Dialer; CreateDialer returns it as a transport.Dialer.
type stdDialer struct {
	*net.Dialer
}

// Dial connects to the address on the named network using the embedded
// net.Dialer.
func (d stdDialer) Dial(network, address string) (net.Conn, error) {
	return d.Dialer.Dial(network, address)
}

// CreateDialer wraps the given net.Dialer in a transport.Dialer.
func (n *Net) CreateDialer(d *net.Dialer) transport.Dialer {
	return stdDialer{d}
}
!assert.NoError(t, err, "should succeed") { return } assert.Contains(t, []string{"127.0.0.1", "127.0.1.1"}, udpAddr.IP.String(), "should match") assert.Equal(t, 1234, udpAddr.Port, "should match") }) t.Run("ListenPacket", func(t *testing.T) { nw, err := NewNet() if !assert.Nil(t, err, "should succeed") { return } conn, err := nw.ListenPacket(udpString, "127.0.0.1:0") if !assert.NoError(t, err, "should succeed") { return } udpConn, ok := conn.(*net.UDPConn) assert.True(t, ok, "should succeed") log.Debugf("udpConn: %+v", udpConn) laddr := conn.LocalAddr().String() log.Debugf("laddr: %s", laddr) }) t.Run("ListenUDP random port", func(t *testing.T) { nw, err := NewNet() if !assert.Nil(t, err, "should succeed") { return } srcAddr := &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), } conn, err := nw.ListenUDP(udpString, srcAddr) assert.NoError(t, err, "should succeed") laddr := conn.LocalAddr().String() log.Debugf("laddr: %s", laddr) assert.NoError(t, conn.Close(), "should succeed") }) t.Run("Dial (UDP)", func(t *testing.T) { nw, err := NewNet() assert.Nil(t, err, "should succeed") conn, err := nw.Dial(udpString, "127.0.0.1:1234") assert.NoError(t, err, "should succeed") laddr := conn.LocalAddr() log.Debugf("laddr: %s", laddr.String()) raddr := conn.RemoteAddr() log.Debugf("raddr: %s", raddr.String()) assert.Equal(t, "127.0.0.1", laddr.(*net.UDPAddr).IP.String(), "should match") //nolint:forcetypeassert assert.True(t, laddr.(*net.UDPAddr).Port != 0, "should match") //nolint:forcetypeassert assert.Equal(t, "127.0.0.1:1234", raddr.String(), "should match") assert.NoError(t, conn.Close(), "should succeed") }) t.Run("DialUDP", func(t *testing.T) { nw, err := NewNet() assert.Nil(t, err, "should succeed") locAddr := &net.UDPAddr{ IP: net.IPv4(127, 0, 0, 1), Port: 0, } remAddr := &net.UDPAddr{ IP: net.IPv4(127, 0, 0, 1), Port: 1234, } conn, err := nw.DialUDP(udpString, locAddr, remAddr) assert.NoError(t, err, "should succeed") laddr := conn.LocalAddr() log.Debugf("laddr: %s", 
laddr.String()) raddr := conn.RemoteAddr() log.Debugf("raddr: %s", raddr.String()) assert.Equal(t, "127.0.0.1", laddr.(*net.UDPAddr).IP.String(), "should match") //nolint:forcetypeassert assert.True(t, laddr.(*net.UDPAddr).Port != 0, "should match") //nolint:forcetypeassert assert.Equal(t, "127.0.0.1:1234", raddr.String(), "should match") assert.NoError(t, conn.Close(), "should succeed") }) t.Run("UDPLoopback", func(t *testing.T) { nw, err := NewNet() assert.Nil(t, err, "should succeed") conn, err := nw.ListenPacket(udpString, "127.0.0.1:0") assert.NoError(t, err, "should succeed") laddr := conn.LocalAddr() msg := "PING!" n, err := conn.WriteTo([]byte(msg), laddr) assert.NoError(t, err, "should succeed") assert.Equal(t, len(msg), n, "should match") buf := make([]byte, 1000) n, addr, err := conn.ReadFrom(buf) assert.NoError(t, err, "should succeed") assert.Equal(t, len(msg), n, "should match") assert.Equal(t, msg, string(buf[:n]), "should match") assert.Equal(t, laddr.(*net.UDPAddr).String(), addr.(*net.UDPAddr).String(), "should match") //nolint:forcetypeassert assert.NoError(t, conn.Close(), "should succeed") }) t.Run("Dialer", func(t *testing.T) { nw, err := NewNet() assert.Nil(t, err, "should succeed") dialer := nw.CreateDialer(&net.Dialer{ LocalAddr: &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: 0, }, }) conn, err := dialer.Dial(udpString, "127.0.0.1:1234") assert.NoError(t, err, "should succeed") laddr := conn.LocalAddr() log.Debugf("laddr: %s", laddr.String()) raddr := conn.RemoteAddr() log.Debugf("raddr: %s", raddr.String()) assert.Equal(t, "127.0.0.1", laddr.(*net.UDPAddr).IP.String(), "should match") //nolint:forcetypeassert assert.True(t, laddr.(*net.UDPAddr).Port != 0, "should match") //nolint:forcetypeassert assert.Equal(t, "127.0.0.1:1234", raddr.String(), "should match") assert.NoError(t, conn.Close(), "should succeed") }) t.Run("Unexpected operations", func(t *testing.T) { // For portability of test, find a name of loopback interface name first 
var loName string ifs, err := net.Interfaces() assert.NoError(t, err, "should succeed") for _, ifc := range ifs { if ifc.Flags&net.FlagLoopback != 0 { loName = ifc.Name break } } nw, err := NewNet() assert.Nil(t, err, "should succeed") if len(loName) > 0 { // InterfaceByName ifc, err2 := nw.InterfaceByName(loName) assert.NoError(t, err2, "should succeed") assert.Equal(t, loName, ifc.Name, "should match") } _, err = nw.InterfaceByName("foo0") assert.Error(t, err, "should fail") }) } golang-github-pion-transport-v3-3.0.8/test/000077500000000000000000000000001507057301700205605ustar00rootroot00000000000000golang-github-pion-transport-v3-3.0.8/test/bridge.go000066400000000000000000000244241507057301700223510ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package test import ( "errors" "fmt" "io" "math/rand" "net" "sync" "time" "github.com/pion/transport/v3/deadline" ) const ( tickWait = 10 * time.Microsecond udpString = "udp" ) var ( errIOTimeout = errors.New("i/o timeout") errBridgeConnClosed = errors.New("bridgeConn closed") errBridgeAlreadyClosed = errors.New("bridge has already been closed") errInverseArrayWithOne = errors.New("inverse requires more than one item in the array") errBadLossChanceRange = errors.New("loss must be < 100 && > 0") ) type bridgeConnAddr int func (a bridgeConnAddr) Network() string { return udpString } func (a bridgeConnAddr) String() string { return fmt.Sprintf("a%d", a) } // bridgeConn is a net.Conn that represents an endpoint of the bridge. type bridgeConn struct { br *Bridge id int closeReq bool closing bool closed bool mutex sync.RWMutex readCh chan []byte lossChance int readDeadline *deadline.Deadline writeDeadline *deadline.Deadline } type netError struct { error timeout, temporary bool } func (e *netError) Timeout() bool { return e.timeout } func (e *netError) Temporary() bool { return e.temporary } // Read reads data, block until the data becomes available. 
func (conn *bridgeConn) Read(b []byte) (int, error) { select { case <-conn.readDeadline.Done(): return 0, &netError{errIOTimeout, true, true} default: } select { case data, ok := <-conn.readCh: if !ok { return 0, io.EOF } n := copy(b, data) return n, nil case <-conn.readDeadline.Done(): return 0, &netError{errIOTimeout, true, true} } } // Write writes data to the bridge. func (conn *bridgeConn) Write(b []byte) (int, error) { select { case <-conn.writeDeadline.Done(): return 0, &netError{errIOTimeout, true, true} default: } if rand.Intn(100) < conn.lossChance { //nolint:gosec return len(b), nil } if !conn.br.Push(b, conn.id) { return 0, &netError{errBridgeConnClosed, false, false} } return len(b), nil } // Close closes the bridge (releases resources used). func (conn *bridgeConn) Close() error { conn.mutex.Lock() defer conn.mutex.Unlock() if conn.closeReq { return &netError{errBridgeAlreadyClosed, false, false} } conn.closeReq = true conn.closing = true return nil } // LocalAddr is not used. func (conn *bridgeConn) LocalAddr() net.Addr { return bridgeConnAddr(conn.id) } // RemoteAddr is not used. func (conn *bridgeConn) RemoteAddr() net.Addr { return nil } // SetDeadline sets deadline of Read/Write operation. // Setting zero means no deadline. func (conn *bridgeConn) SetDeadline(t time.Time) error { conn.writeDeadline.Set(t) conn.readDeadline.Set(t) return nil } // SetReadDeadline sets deadline of Read operation. // Setting zero means no deadline. func (conn *bridgeConn) SetReadDeadline(t time.Time) error { conn.readDeadline.Set(t) return nil } // SetWriteDeadline sets deadline of Write operation. // Setting zero means no deadline. func (conn *bridgeConn) SetWriteDeadline(t time.Time) error { conn.writeDeadline.Set(t) return nil } func (conn *bridgeConn) isClosed() bool { conn.mutex.RLock() defer conn.mutex.RUnlock() return conn.closed } // Bridge represents a network between the two endpoints. 
type Bridge struct { mutex sync.RWMutex conn0 *bridgeConn conn1 *bridgeConn queue0to1 [][]byte queue1to0 [][]byte dropNWrites0 int dropNWrites1 int reorderNWrites0 int reorderNWrites1 int stack0 [][]byte stack1 [][]byte filterCB0 func([]byte) bool filterCB1 func([]byte) bool err error // last error } func inverse(s [][]byte) error { if len(s) < 2 { return errInverseArrayWithOne } for i, j := 0, len(s)-1; i < j; i, j = i+1, j-1 { s[i], s[j] = s[j], s[i] } return nil } // drop n packets from the slice starting from offset. func drop(s [][]byte, offset, n int) [][]byte { if offset+n > len(s) { n = len(s) - offset } return append(s[:offset], s[offset+n:]...) } // NewBridge creates a new bridge with two endpoints. func NewBridge() *Bridge { br := &Bridge{ queue0to1: make([][]byte, 0), queue1to0: make([][]byte, 0), } br.conn0 = &bridgeConn{ br: br, id: 0, readCh: make(chan []byte), readDeadline: deadline.New(), writeDeadline: deadline.New(), } br.conn1 = &bridgeConn{ br: br, id: 1, readCh: make(chan []byte), readDeadline: deadline.New(), writeDeadline: deadline.New(), } return br } // GetConn0 returns an endpoint of the bridge, conn0. func (br *Bridge) GetConn0() net.Conn { return br.conn0 } // GetConn1 returns an endpoint of the bridge, conn1. func (br *Bridge) GetConn1() net.Conn { return br.conn1 } // Len returns number of queued packets. func (br *Bridge) Len(fromID int) int { br.mutex.Lock() defer br.mutex.Unlock() if fromID == 0 { return len(br.queue0to1) } return len(br.queue1to0) } // Push pushes a packet into the specified queue. func (br *Bridge) Push(packet []byte, fromID int) bool { //nolint:gocognit,cyclop data := make([]byte, len(packet)) copy(data, packet) // Push rate should be limited as same as Tick rate. // Otherwise, queue grows too fast on free running Write. 
time.Sleep(tickWait) br.mutex.Lock() defer br.mutex.Unlock() br.conn0.mutex.Lock() br.conn1.mutex.Lock() closing0 := br.conn0.closing closing1 := br.conn1.closing br.conn1.mutex.Unlock() br.conn0.mutex.Unlock() if closing0 || closing1 { if fromID == 0 && closing0 { return false } if fromID == 1 && closing1 { return false } return true } if fromID == 0 { //nolint:nestif switch { case br.dropNWrites0 > 0: br.dropNWrites0-- // fmt.Printf("br: dropped a packet of size %d (rem: %d for q0)\n", len(d), br.dropNWrites0) // nolint case br.reorderNWrites0 > 0: br.reorderNWrites0-- br.stack0 = append(br.stack0, data) // fmt.Printf("stack0 size: %d\n", len(br.stack0)) // nolint if br.reorderNWrites0 == 0 { if err := inverse(br.stack0); err == nil { // fmt.Printf("stack0 reordered!\n") // nolint br.queue0to1 = append(br.queue0to1, br.stack0...) } else { br.err = err } } case br.filterCB0 != nil && !br.filterCB0(data): // fmt.Printf("br: filtered out a packet of size %d (q0)\n", len(d)) // nolint default: // fmt.Printf("br: routed a packet of size %d (q0)\n", len(d)) // nolint br.queue0to1 = append(br.queue0to1, data) } } else { switch { case br.dropNWrites1 > 0: br.dropNWrites1-- // fmt.Printf("br: dropped a packet of size %d (rem: %d for q1)\n", len(d), br.dropNWrites0) // nolint case br.reorderNWrites1 > 0: br.reorderNWrites1-- br.stack1 = append(br.stack1, data) if br.reorderNWrites1 == 0 { if err := inverse(br.stack1); err != nil { br.err = err } br.queue1to0 = append(br.queue1to0, br.stack1...) } case br.filterCB1 != nil && !br.filterCB1(data): // fmt.Printf("br: filtered out a packet of size %d (q1)\n", len(d)) // nolint default: // fmt.Printf("br: routed a packet of size %d (q1)\n", len(d)) // nolint br.queue1to0 = append(br.queue1to0, data) } } return true } // Reorder inverses the order of packets currently in the specified queue. 
func (br *Bridge) Reorder(fromID int) error { br.mutex.Lock() defer br.mutex.Unlock() if fromID == 0 { return inverse(br.queue0to1) } return inverse(br.queue1to0) } // Drop drops the specified number of packets from the given offset index // of the specified queue. func (br *Bridge) Drop(fromID, offset, n int) { br.mutex.Lock() defer br.mutex.Unlock() if fromID == 0 { br.queue0to1 = drop(br.queue0to1, offset, n) } else { br.queue1to0 = drop(br.queue1to0, offset, n) } } // DropNextNWrites drops the next n packets that will be written // to the specified queue. func (br *Bridge) DropNextNWrites(fromID, n int) { br.mutex.Lock() defer br.mutex.Unlock() if fromID == 0 { br.dropNWrites0 = n } else { br.dropNWrites1 = n } } // ReorderNextNWrites drops the next n packets that will be written // to the specified queue. func (br *Bridge) ReorderNextNWrites(fromID, n int) { br.mutex.Lock() defer br.mutex.Unlock() if fromID == 0 { br.reorderNWrites0 = n } else { br.reorderNWrites1 = n } } func (br *Bridge) clear() { br.mutex.Lock() defer br.mutex.Unlock() br.queue1to0 = nil br.queue0to1 = nil } // Tick attempts to hand a packet from the queue for each directions, to readers, // if there are waiting on the queue. If there's no reader, it will return // immediately. 
func (br *Bridge) Tick() int { //nolint:cyclop br.mutex.Lock() defer br.mutex.Unlock() br.conn0.mutex.Lock() if br.conn0.closing && !br.conn0.closed && len(br.queue1to0) == 0 { br.conn0.closed = true close(br.conn0.readCh) } br.conn0.mutex.Unlock() br.conn1.mutex.Lock() if br.conn1.closing && !br.conn1.closed && len(br.queue0to1) == 0 { br.conn1.closed = true close(br.conn1.readCh) } br.conn1.mutex.Unlock() var n int if len(br.queue0to1) > 0 && !br.conn1.isClosed() { select { case br.conn1.readCh <- br.queue0to1[0]: n++ // fmt.Printf("conn1 received data (%d bytes)\n", len(br.queue0to1[0])) // nolint br.queue0to1 = br.queue0to1[1:] default: } } if len(br.queue1to0) > 0 && !br.conn0.isClosed() { select { case br.conn0.readCh <- br.queue1to0[0]: n++ // fmt.Printf("conn0 received data (%d bytes)\n", len(br.queue1to0[0])) // nolint br.queue1to0 = br.queue1to0[1:] default: } } return n } // Process repeats tick() calls until no more outstanding packet in the queues. func (br *Bridge) Process() { for { time.Sleep(tickWait) br.Tick() if br.Len(0) == 0 && br.Len(1) == 0 { break } } } // SetLossChance sets the probability of writes being discard // ( to introduce artificial loss). func (br *Bridge) SetLossChance(chance int) error { if chance > 100 || chance < 0 { return errBadLossChanceRange } //nolint:staticcheck rand.Seed(time.Now().UTC().UnixNano()) br.conn0.lossChance = chance br.conn1.lossChance = chance return nil } // Filter filters (drops) packets based on return value of the given callback. 
func (br *Bridge) Filter(fromID int, cb func([]byte) bool) { br.mutex.Lock() defer br.mutex.Unlock() if fromID == 0 { br.filterCB0 = cb } else { br.filterCB1 = cb } } golang-github-pion-transport-v3-3.0.8/test/bridge_test.go000066400000000000000000000201101507057301700233740ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package test import ( "fmt" "net" "sync" "testing" "time" "github.com/stretchr/testify/assert" "golang.org/x/net/nettest" ) const ( msg1 = `ADC` msg2 = `DEFG` ) // helper to close both conns. func closeBridge(br *Bridge) error { if err := br.conn0.Close(); err != nil { return err } return br.conn1.Close() } type AsyncResult struct { n int err error msg string } func TestBridge(t *testing.T) { //nolint:gocyclo,cyclop,maintidx tt := TimeOut(30 * time.Second) defer tt.Stop() buf := make([]byte, 256) t.Run("normal", func(t *testing.T) { readRes := make(chan AsyncResult) br := NewBridge() conn0 := br.GetConn0() conn1 := br.GetConn1() assert.Equal(t, "a0", conn0.LocalAddr().String()) assert.Equal(t, "a1", conn1.LocalAddr().String()) assert.Equal(t, "udp", conn0.LocalAddr().Network()) assert.Equal(t, "udp", conn1.LocalAddr().Network()) n, err := conn0.Write([]byte(msg1)) assert.NoError(t, err) assert.Len(t, msg1, n, "unexpected length") go func() { nInner, errInner := conn1.Read(buf) readRes <- AsyncResult{n: nInner, err: errInner} }() br.Process() ar := <-readRes assert.NoError(t, ar.err) assert.Len(t, msg1, ar.n, "unexpected length") assert.NoError(t, closeBridge(br)) }) t.Run("drop 1st packet from conn0", func(t *testing.T) { //nolint:dupl readRes := make(chan AsyncResult) br := NewBridge() conn0 := br.GetConn0() conn1 := br.GetConn1() n, err := conn0.Write([]byte(msg1)) assert.NoError(t, err) assert.Len(t, msg1, n, "unexpected length") n, err = conn0.Write([]byte(msg2)) assert.NoError(t, err) assert.Len(t, msg2, n, "unexpected length") go func() { nInner, errInner := conn1.Read(buf) 
readRes <- AsyncResult{n: nInner, err: errInner} }() br.Drop(0, 0, 1) br.Process() ar := <-readRes assert.NoError(t, ar.err) assert.Len(t, msg2, ar.n, "unexpected length") assert.NoError(t, closeBridge(br)) }) t.Run("drop 2nd packet from conn0", func(t *testing.T) { //nolint:dupl readRes := make(chan AsyncResult) br := NewBridge() conn0 := br.GetConn0() conn1 := br.GetConn1() n, err := conn0.Write([]byte(msg1)) assert.NoError(t, err) assert.Len(t, msg1, n, "unexpected length") n, err = conn0.Write([]byte(msg2)) assert.NoError(t, err) assert.Len(t, msg2, n, "unexpected length") go func() { nInner, errInner := conn1.Read(buf) readRes <- AsyncResult{n: nInner, err: errInner} }() br.Drop(0, 1, 1) br.Process() ar := <-readRes assert.NoError(t, ar.err) assert.Len(t, msg1, ar.n, "unexpected length") assert.NoError(t, closeBridge(br)) }) t.Run("drop 1st packet from conn1", func(t *testing.T) { //nolint:dupl readRes := make(chan AsyncResult) br := NewBridge() conn0 := br.GetConn0() conn1 := br.GetConn1() n, err := conn1.Write([]byte(msg1)) assert.NoError(t, err) assert.Len(t, msg1, n, "unexpected length") n, err = conn1.Write([]byte(msg2)) assert.NoError(t, err) assert.Len(t, msg2, n, "unexpected length") go func() { nInner, errInner := conn0.Read(buf) readRes <- AsyncResult{n: nInner, err: errInner} }() br.Drop(1, 0, 1) br.Process() ar := <-readRes assert.NoError(t, ar.err) assert.Len(t, msg2, ar.n, "unexpected length") assert.NoError(t, closeBridge(br)) }) t.Run("drop 2nd packet from conn1", func(t *testing.T) { //nolint:dupl readRes := make(chan AsyncResult) br := NewBridge() conn0 := br.GetConn0() conn1 := br.GetConn1() n, err := conn1.Write([]byte(msg1)) assert.NoError(t, err) assert.Len(t, msg1, n, "unexpected length") n, err = conn1.Write([]byte(msg2)) assert.NoError(t, err) assert.Len(t, msg2, n, "unexpected length") go func() { nInner, errInner := conn0.Read(buf) readRes <- AsyncResult{n: nInner, err: errInner} }() br.Drop(1, 1, 1) br.Process() ar := <-readRes 
assert.NoError(t, ar.err) assert.Len(t, msg1, ar.n, "unexpected length") assert.NoError(t, closeBridge(br)) }) t.Run("reorder packets from conn0", func(t *testing.T) { //nolint:dupl br := NewBridge() conn0 := br.GetConn0() conn1 := br.GetConn1() n, err := conn0.Write([]byte(msg1)) assert.NoError(t, err) assert.Len(t, msg1, n, "unexpected length") n, err = conn0.Write([]byte(msg2)) assert.NoError(t, err) assert.Len(t, msg2, n, "unexpected length") done := make(chan bool) go func() { nInner, errInner := conn1.Read(buf) assert.NoError(t, errInner) assert.Len(t, msg2, nInner, "unexpected length") nInner, errInner = conn1.Read(buf) assert.NoError(t, errInner) assert.Len(t, msg1, nInner, "unexpected length") done <- true }() err = br.Reorder(0) assert.NoError(t, err) br.Process() <-done assert.NoError(t, closeBridge(br)) }) t.Run("reorder packets from conn1", func(t *testing.T) { //nolint:dupl br := NewBridge() conn0 := br.GetConn0() conn1 := br.GetConn1() n, err := conn1.Write([]byte(msg1)) assert.NoError(t, err) assert.Len(t, msg1, n, "unexpected length") n, err = conn1.Write([]byte(msg2)) assert.NoError(t, err) assert.Len(t, msg2, n, "unexpected length") done := make(chan bool) go func() { nInner, errInner := conn0.Read(buf) assert.NoError(t, errInner) assert.Len(t, msg2, nInner, "unexpected length") nInner, errInner = conn0.Read(buf) assert.NoError(t, errInner) assert.Len(t, msg1, nInner, "unexpected length") done <- true }() err = br.Reorder(1) assert.NoError(t, err) br.Process() <-done assert.NoError(t, closeBridge(br)) }) t.Run("inverse error", func(t *testing.T) { q := [][]byte{} q = append(q, []byte("ABC")) assert.Error(t, inverse(q), "inverse should fail if less than 2 pkts") }) t.Run("drop next N packets", func(t *testing.T) { testFrom := func(t *testing.T, fromID int) { t.Helper() readRes := make(chan AsyncResult, 5) br := NewBridge() conn0 := br.GetConn0() conn1 := br.GetConn1() var wg sync.WaitGroup wg.Add(1) var srcConn, dstConn net.Conn if fromID == 0 { 
br.DropNextNWrites(0, 3) srcConn = conn0 dstConn = conn1 } else { br.DropNextNWrites(1, 3) srcConn = conn1 dstConn = conn0 } go func() { defer wg.Done() for { nInner, errInner := dstConn.Read(buf) if errInner != nil { break } readRes <- AsyncResult{ n: nInner, err: nil, msg: string(buf), } } }() msgs := make([]string, 0) for i := 0; i < 5; i++ { msg := fmt.Sprintf("msg%d", i) msgs = append(msgs, msg) n, err := srcConn.Write([]byte(msg)) assert.NoErrorf(t, err, "Test: %d", fromID) assert.Len(t, msg, n, "[%d] unexpected length", fromID) br.Process() } assert.NoErrorf(t, closeBridge(br), "Test: %d", fromID) br.Process() wg.Wait() assert.Lenf(t, readRes, 2, "[%d] unexpected number of packets", fromID) for i := 0; i < 2; i++ { ar := <-readRes assert.NoErrorf(t, ar.err, "Test: %d", fromID) assert.Len(t, msgs[i+3], ar.n, "[%d] unexpected length", fromID) } } testFrom(t, 0) testFrom(t, 1) }) } type closePropagator struct { *bridgeConn otherEnd *bridgeConn } func (c *closePropagator) Close() error { c.otherEnd.mutex.Lock() c.otherEnd.closing = true c.otherEnd.mutex.Unlock() return c.bridgeConn.Close() } func TestNetTest(t *testing.T) { nettest.TestConn(t, func() (net.Conn, net.Conn, func(), error) { br := NewBridge() conn0 := br.GetConn0().(*bridgeConn) //nolint:forcetypeassert conn1 := br.GetConn1().(*bridgeConn) //nolint:forcetypeassert var wg sync.WaitGroup wg.Add(1) go func() { for { br.Process() if conn0.isClosed() && conn1.isClosed() { wg.Done() return } } }() return &closePropagator{conn0, conn1}, &closePropagator{conn1, conn0}, func() { // RacyRead test leave receive buffer filled. // As net.Conn.Read() should return received data even after Close()-ed, // queue must be cleared explicitly. br.clear() _ = conn0.Close() _ = conn1.Close() // Tick the clock to actually close conns. 
br.Tick() wg.Wait() }, nil }) } golang-github-pion-transport-v3-3.0.8/test/connctx.go000066400000000000000000000013321507057301700225620ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package test import ( "context" "io" ) type wrappedReader struct { io.Reader } func (r *wrappedReader) ReadContext(_ context.Context, b []byte) (int, error) { return r.Reader.Read(b) } type wrappedWriter struct { io.Writer } func (r *wrappedWriter) WriteContext(_ context.Context, b []byte) (int, error) { return r.Writer.Write(b) } type wrappedReadWriter struct { io.ReadWriter } func (r *wrappedReadWriter) ReadContext(_ context.Context, b []byte) (int, error) { return r.ReadWriter.Read(b) } func (r *wrappedReadWriter) WriteContext(_ context.Context, b []byte) (int, error) { return r.ReadWriter.Write(b) } golang-github-pion-transport-v3-3.0.8/test/rand.go000066400000000000000000000015301507057301700220320ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package test import ( crand "crypto/rand" "errors" "fmt" mrand "math/rand" ) var errRequestTooLargeBuffer = errors.New("requested too large buffer") type randomizer struct { randomness []byte } func initRand() randomizer { // read 1MB of randomness randomness := make([]byte, 1<<20) if _, err := crand.Read(randomness); err != nil { fmt.Println("Failed to initiate randomness:", err) // nolint } return randomizer{ randomness: randomness, } } func (r *randomizer) randBuf(size int) ([]byte, error) { n := len(r.randomness) - size if n < 1 { return nil, fmt.Errorf("%w (%d). 
max is %d", errRequestTooLargeBuffer, size, len(r.randomness)) } start := mrand.Intn(n) //nolint:gosec return r.randomness[start : start+size], nil } golang-github-pion-transport-v3-3.0.8/test/stress.go000066400000000000000000000053421507057301700224360ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package test import ( "bytes" "context" "errors" "fmt" "io" "sync" "github.com/pion/transport/v3/netctx" ) var errByteSequenceChanged = errors.New("byte sequence changed") // Options represents the configuration of the stress test. type Options struct { MsgSize int MsgCount int } // Stress enables stress testing of a io.ReadWriter. // It checks that packets are received correctly and in order. func Stress(ca io.Writer, cb io.Reader, opt Options) error { return StressContext(context.Background(), &wrappedWriter{ca}, &wrappedReader{cb}, opt) } // StressContext enables stress testing of a io.ReadWriter. // It checks that packets are received correctly and in order. func StressContext(ctx context.Context, ca netctx.Writer, cb netctx.Reader, opt Options) error { bufs := make(chan []byte, opt.MsgCount) errCh := make(chan error) // Write go func() { err := write(ctx, ca, bufs, opt) errCh <- err close(bufs) }() // Read go func() { result := make([]byte, opt.MsgSize) for original := range bufs { err := read(ctx, cb, original, result) if err != nil { errCh <- err } } close(errCh) }() return FlattenErrs(GatherErrs(errCh)) } func read(ctx context.Context, r netctx.Reader, original, result []byte) error { n, err := r.ReadContext(ctx, result) if err != nil { return err } if !bytes.Equal(original, result[:n]) { return fmt.Errorf("%w %#v != %#v", errByteSequenceChanged, original, result) } return nil } // StressDuplex enables duplex stress testing of a io.ReadWriter. // It checks that packets are received correctly and in order. 
func StressDuplex(ca io.ReadWriter, cb io.ReadWriter, opt Options) error { return StressDuplexContext(context.Background(), &wrappedReadWriter{ca}, &wrappedReadWriter{cb}, opt) } // StressDuplexContext enables duplex stress testing of a io.ReadWriter. // It checks that packets are received correctly and in order. func StressDuplexContext(ctx context.Context, ca netctx.ReadWriter, cb netctx.ReadWriter, opt Options) error { errCh := make(chan error) var wg sync.WaitGroup wg.Add(2) go func() { defer wg.Done() errCh <- StressContext(ctx, ca, cb, opt) }() go func() { defer wg.Done() errCh <- StressContext(ctx, cb, ca, opt) }() go func() { wg.Wait() close(errCh) }() return FlattenErrs(GatherErrs(errCh)) } func write(ctx context.Context, c netctx.Writer, bufs chan []byte, opt Options) error { randomizer := initRand() for i := 0; i < opt.MsgCount; i++ { buf, err := randomizer.randBuf(opt.MsgSize) if err != nil { return err } bufs <- buf if _, err = c.WriteContext(ctx, buf); err != nil { return err } } return nil } golang-github-pion-transport-v3-3.0.8/test/test.go000066400000000000000000000005351507057301700220710ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT // Package test provides helpers to test the various pion transports implementations. // The tests are standardized around the io.ReadWriteCloser interface. // This package is meant to be used in addition to golang.org/x/net/nettest. 
package test golang-github-pion-transport-v3-3.0.8/test/test_test.go000066400000000000000000000012641507057301700231300ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package test import ( "io" "net" "testing" "github.com/stretchr/testify/assert" ) func TestStressIOPipe(t *testing.T) { r, w := io.Pipe() opt := Options{ MsgSize: 2048, MsgCount: 100, } assert.NoError(t, Stress(w, r, opt)) } func TestStressDuplexNetPipe(t *testing.T) { ca, cb := net.Pipe() opt := Options{ MsgSize: 2048, MsgCount: 100, } assert.NoError(t, StressDuplex(ca, cb, opt)) } func BenchmarkPipe(b *testing.B) { ca, cb := net.Pipe() b.ResetTimer() opt := Options{ MsgSize: 2048, MsgCount: b.N, } assert.NoError(b, Stress(ca, cb, opt)) } golang-github-pion-transport-v3-3.0.8/test/util.go000066400000000000000000000077641507057301700221020ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package test import ( "errors" "fmt" "os" "runtime" "runtime/pprof" "strings" "testing" "time" ) var errFlattenErrs = errors.New("") // TimeOut is used to panic if a test takes to long. // It will print the current goroutines and panic. // It is meant as an aid in debugging deadlocks. func TimeOut(t time.Duration) *time.Timer { return time.AfterFunc(t, func() { if err := pprof.Lookup("goroutine").WriteTo(os.Stdout, 1); err != nil { fmt.Printf("failed to print goroutines: %v \n", err) // nolint } panic("timeout") // nolint }) } func tryCheckRoutinesLoop(tb testing.TB, failMessage string) { tb.Helper() try := 0 ticker := time.NewTicker(200 * time.Millisecond) defer ticker.Stop() for range ticker.C { runtime.GC() routines := getRoutines() if len(routines) == 0 { return } if try >= 50 { tb.Fatalf("%s: \n%s", failMessage, strings.Join(routines, "\n\n")) // nolint } try++ } } // CheckRoutines is used to check for leaked go-routines. 
func CheckRoutines(t *testing.T) func() { t.Helper() tryCheckRoutinesLoop(t, "Unexpected routines on test startup") return func() { tryCheckRoutinesLoop(t, "Unexpected routines on test end") } } // CheckRoutinesStrict is used to check for leaked go-routines. // It differs from CheckRoutines in that it has very little tolerance // for lingering goroutines. This is helpful for tests that need // to ensure clean closure of resources. // Checking the state of goroutines exactly is tricky. As users writing // goroutines, we tend to clean up gracefully using some synchronization // pattern. When used correctly, we won't leak goroutines, but we cannot // guarantee *when* the goroutines will end. This is the nature of waiting // on the runtime's goexit1 being called which is the final subroutine // called, which is after any user written code. This small, but possible // chance to have a thread (not goroutine) be preempted before this is // called, can have our goroutine stack be not quite correct yet. The // best we can do is sleep a little bit and try to encourage the runtime // to run that goroutine (G) on the machine (M) it belongs to. func CheckRoutinesStrict(tb testing.TB) func() { tb.Helper() tryCheckRoutinesLoop(tb, "Unexpected routines on test startup") return func() { runtime.Gosched() runtime.GC() routines := getRoutines() if len(routines) == 0 { return } // arbitrarily short choice to allow the runtime to cleanup any // goroutines that really aren't doing anything but haven't yet // completed. 
time.Sleep(time.Millisecond) runtime.Gosched() runtime.GC() routines = getRoutines() if len(routines) == 0 { return } tb.Fatalf("%s: \n%s", "Unexpected routines on test end", strings.Join(routines, "\n\n")) // nolint } } func getRoutines() []string { buf := make([]byte, 2<<20) buf = buf[:runtime.Stack(buf, true)] return filterRoutines(strings.Split(string(buf), "\n\n")) } func filterRoutines(routines []string) []string { result := []string{} for _, stack := range routines { if stack == "" || // Empty filterRoutineWASM(stack) || // WASM specific exception strings.Contains(stack, "testing.Main(") || // Tests strings.Contains(stack, "testing.(*T).Run(") || // Test run strings.Contains(stack, "test.getRoutines(") { // This routine continue } result = append(result, stack) } return result } // GatherErrs gathers all errors returned by a channel. // It blocks until the channel is closed. func GatherErrs(c chan error) []error { errs := make([]error, 0) for err := range c { errs = append(errs, err) } return errs } // FlattenErrs flattens a slice of errors into a single error. 
func FlattenErrs(errs []error) error { var errStrings []string for _, err := range errs { if err != nil { errStrings = append(errStrings, err.Error()) } } if len(errStrings) == 0 { return nil } return fmt.Errorf("%w %s", errFlattenErrs, strings.Join(errStrings, "\n")) } golang-github-pion-transport-v3-3.0.8/test/util_nowasm.go000066400000000000000000000003141507057301700234460ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT //go:build !wasm // +build !wasm package test func filterRoutineWASM(string) bool { return false } golang-github-pion-transport-v3-3.0.8/test/util_test.go000066400000000000000000000021021507057301700231160ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package test import ( "fmt" "testing" "time" "github.com/stretchr/testify/assert" ) func TestCheckRoutines(t *testing.T) { // Limit runtime in case of deadlocks lim := TimeOut(time.Second * 20) defer lim.Stop() // Check for leaking routines report := CheckRoutines(t) defer report() go func() { time.Sleep(1 * time.Second) }() } func TestCheckRoutinesStrict(t *testing.T) { mock := &tbMock{TB: t} // Limit runtime in case of deadlocks lim := TimeOut(time.Second * 20) defer lim.Stop() // Check for leaking routines report := CheckRoutinesStrict(mock) defer func() { report() assert.NotEmpty(t, mock.fatalfCalled, "expected Fatalf to be called") assert.Contains(t, mock.fatalfCalled[0], "Unexpected routines") }() go func() { time.Sleep(1 * time.Second) }() } type tbMock struct { testing.TB fatalfCalled []string } func (m *tbMock) Fatalf(format string, args ...any) { m.fatalfCalled = append(m.fatalfCalled, fmt.Sprintf(format, args...)) } golang-github-pion-transport-v3-3.0.8/test/util_wasm.go000066400000000000000000000005451507057301700231170ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package test import ( "strings" 
) func filterRoutineWASM(stack string) bool { // Nested t.Run on Go 1.14-1.21 and go1.22 WASM have these routines return strings.Contains(stack, "runtime.goexit()") || strings.Contains(stack, "runtime.goexit({})") } golang-github-pion-transport-v3-3.0.8/udp/000077500000000000000000000000001507057301700203715ustar00rootroot00000000000000golang-github-pion-transport-v3-3.0.8/udp/batchconn.go000066400000000000000000000101751507057301700226630ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package udp import ( "io" "net" "runtime" "sync" "sync/atomic" "time" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" ) // BatchWriter represents conn can write messages in batch. type BatchWriter interface { WriteBatch(ms []ipv4.Message, flags int) (int, error) } // BatchReader represents conn can read messages in batch. type BatchReader interface { ReadBatch(msg []ipv4.Message, flags int) (int, error) } // BatchPacketConn represents conn can read/write messages in batch. type BatchPacketConn interface { BatchWriter BatchReader io.Closer } // BatchConn uses ipv4/v6.NewPacketConn to wrap a net.PacketConn to write/read messages in batch, // only available in linux. In other platform, it will use single Write/Read as same as net.Conn. type BatchConn struct { net.PacketConn batchConn BatchPacketConn batchWriteMutex sync.Mutex batchWriteMessages []ipv4.Message batchWritePos int batchWriteLast time.Time batchWriteSize int batchWriteInterval time.Duration closed int32 } // NewBatchConn creates a *BatchConn from net.PacketConn with batch configs. 
func NewBatchConn(conn net.PacketConn, batchWriteSize int, batchWriteInterval time.Duration) *BatchConn { bc := &BatchConn{ PacketConn: conn, batchWriteLast: time.Now(), batchWriteInterval: batchWriteInterval, batchWriteSize: batchWriteSize, batchWriteMessages: make([]ipv4.Message, batchWriteSize), } for i := range bc.batchWriteMessages { bc.batchWriteMessages[i].Buffers = [][]byte{make([]byte, sendMTU)} } // batch write only supports linux if runtime.GOOS == "linux" { if pc4 := ipv4.NewPacketConn(conn); pc4 != nil { bc.batchConn = pc4 } else if pc6 := ipv6.NewPacketConn(conn); pc6 != nil { bc.batchConn = pc6 } } if bc.batchConn != nil { go func() { writeTicker := time.NewTicker(batchWriteInterval / 2) defer writeTicker.Stop() for atomic.LoadInt32(&bc.closed) != 1 { <-writeTicker.C bc.batchWriteMutex.Lock() if bc.batchWritePos > 0 && time.Since(bc.batchWriteLast) >= bc.batchWriteInterval { _ = bc.flush() } bc.batchWriteMutex.Unlock() } }() } return bc } // Close batchConn and the underlying PacketConn. func (c *BatchConn) Close() error { atomic.StoreInt32(&c.closed, 1) c.batchWriteMutex.Lock() if c.batchWritePos > 0 { _ = c.flush() } c.batchWriteMutex.Unlock() if c.batchConn != nil { return c.batchConn.Close() } return c.PacketConn.Close() } // WriteTo write message to an UDPAddr, addr should be nil if it is a connected socket. 
func (c *BatchConn) WriteTo(b []byte, addr net.Addr) (int, error) { if c.batchConn == nil { return c.PacketConn.WriteTo(b, addr) } return c.enqueueMessage(b, addr) } func (c *BatchConn) enqueueMessage(buf []byte, raddr net.Addr) (int, error) { var err error c.batchWriteMutex.Lock() defer c.batchWriteMutex.Unlock() msg := &c.batchWriteMessages[c.batchWritePos] // reset buffers msg.Buffers = msg.Buffers[:1] msg.Buffers[0] = msg.Buffers[0][:cap(msg.Buffers[0])] c.batchWritePos++ if raddr != nil { msg.Addr = raddr } if n := copy(msg.Buffers[0], buf); n < len(buf) { extraBuffer := make([]byte, len(buf)-n) copy(extraBuffer, buf[n:]) msg.Buffers = append(msg.Buffers, extraBuffer) } else { msg.Buffers[0] = msg.Buffers[0][:n] } if c.batchWritePos == c.batchWriteSize { err = c.flush() } return len(buf), err } // ReadBatch reads messages in batch, return length of message readed and error. func (c *BatchConn) ReadBatch(msgs []ipv4.Message, flags int) (int, error) { if c.batchConn == nil { n, addr, err := c.PacketConn.ReadFrom(msgs[0].Buffers[0]) if err == nil { msgs[0].N = n msgs[0].Addr = addr return 1, nil } return 0, err } return c.batchConn.ReadBatch(msgs, flags) } func (c *BatchConn) flush() error { var writeErr error var txN int for txN < c.batchWritePos { n, err := c.batchConn.WriteBatch(c.batchWriteMessages[txN:c.batchWritePos], 0) if err != nil { writeErr = err break } txN += n } c.batchWritePos = 0 c.batchWriteLast = time.Now() return writeErr } golang-github-pion-transport-v3-3.0.8/udp/conn.go000066400000000000000000000224021507057301700216550ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT // Package udp provides a connection-oriented listener over a UDP PacketConn package udp import ( "context" "errors" "net" "sync" "sync/atomic" "time" "github.com/pion/transport/v3/deadline" "github.com/pion/transport/v3/packetio" "golang.org/x/net/ipv4" ) const ( receiveMTU = 8192 sendMTU = 1500 defaultListenBacklog 
= 128 // same as Linux default ) // Typed errors. var ( ErrClosedListener = errors.New("udp: listener closed") ErrListenQueueExceeded = errors.New("udp: listen queue exceeded") ErrInvalidBatchConfig = errors.New("udp: invalid batch config") ) // listener augments a connection-oriented Listener over a UDP PacketConn. type listener struct { pConn net.PacketConn readBatchSize int accepting atomic.Value // bool acceptCh chan *Conn doneCh chan struct{} doneOnce sync.Once acceptFilter func([]byte) bool connLock sync.Mutex conns map[string]*Conn connWG *sync.WaitGroup readWG sync.WaitGroup errClose atomic.Value // error readDoneCh chan struct{} errRead atomic.Value // error } // Accept waits for and returns the next connection to the listener. func (l *listener) Accept() (net.Conn, error) { select { case c := <-l.acceptCh: l.connWG.Add(1) return c, nil case <-l.readDoneCh: err, _ := l.errRead.Load().(error) return nil, err case <-l.doneCh: return nil, ErrClosedListener } } // Close closes the listener. // Any blocked Accept operations will be unblocked and return errors. func (l *listener) Close() error { var err error l.doneOnce.Do(func() { l.accepting.Store(false) close(l.doneCh) l.connLock.Lock() // Close unaccepted connections lclose: for { select { case c := <-l.acceptCh: close(c.doneCh) delete(l.conns, c.rAddr.String()) default: break lclose } } nConns := len(l.conns) l.connLock.Unlock() l.connWG.Done() if nConns == 0 { // Wait if this is the final connection l.readWG.Wait() if errClose, ok := l.errClose.Load().(error); ok { err = errClose } } else { err = nil } }) return err } // Addr returns the listener's network address. func (l *listener) Addr() net.Addr { return l.pConn.LocalAddr() } // BatchIOConfig indicates config to batch read/write packets, // it will use ReadBatch/WriteBatch to improve throughput for UDP. 
type BatchIOConfig struct {
	// Enable turns batch I/O on. When it is set, Listen wraps the UDP socket
	// in a BatchConn and requires WriteBatchSize and WriteBatchInterval to be
	// positive, returning ErrInvalidBatchConfig otherwise.
	Enable bool
	// ReadBatchSize indicates the maximum number of packets to be read in one batch, a batch size less than 2 means
	// disable read batch.
	ReadBatchSize int
	// WriteBatchSize indicates the maximum number of packets to be written in one batch.
	// Queued packets are flushed as soon as this many have accumulated.
	WriteBatchSize int
	// WriteBatchInterval indicates the maximum interval to wait before writing packets in one batch.
	// A small interval will reduce latency/jitter, but increase the I/O count.
	WriteBatchInterval time.Duration
}

// ListenConfig stores options for listening to an address.
type ListenConfig struct {
	// Backlog defines the maximum length of the queue of pending
	// connections. It is equivalent of the backlog argument of
	// POSIX listen function.
	// If a connection request arrives when the queue is full,
	// the request will be silently discarded, unlike TCP.
	// Set zero to use default value 128 which is same as Linux default.
	// Note that Listen stores the default back into this field when it is zero.
	Backlog int

	// AcceptFilter determines whether the new conn should be made for
	// the incoming packet. If not set, any packet creates new conn.
	// It is only consulted for packets from a remote address that has no
	// Conn yet; packets it rejects are dropped.
	AcceptFilter func([]byte) bool

	// ReadBufferSize sets the size of the operating system's
	// receive buffer associated with the listener.
	// It is applied only when positive, and a failure to apply it is ignored.
	ReadBufferSize int

	// WriteBufferSize sets the size of the operating system's
	// send buffer associated with the connection.
	// It is applied only when positive, and a failure to apply it is ignored.
	WriteBufferSize int

	// Batch configures optional batched reads and writes on the listening socket.
	Batch BatchIOConfig
}

// Listen creates a new listener based on the ListenConfig.
func (lc *ListenConfig) Listen(network string, laddr *net.UDPAddr) (net.Listener, error) { if lc.Backlog == 0 { lc.Backlog = defaultListenBacklog } if lc.Batch.Enable && (lc.Batch.WriteBatchSize <= 0 || lc.Batch.WriteBatchInterval <= 0) { return nil, ErrInvalidBatchConfig } conn, err := net.ListenUDP(network, laddr) if err != nil { return nil, err } if lc.ReadBufferSize > 0 { _ = conn.SetReadBuffer(lc.ReadBufferSize) } if lc.WriteBufferSize > 0 { _ = conn.SetWriteBuffer(lc.WriteBufferSize) } listnerer := &listener{ pConn: conn, acceptCh: make(chan *Conn, lc.Backlog), conns: make(map[string]*Conn), doneCh: make(chan struct{}), acceptFilter: lc.AcceptFilter, connWG: &sync.WaitGroup{}, readDoneCh: make(chan struct{}), } if lc.Batch.Enable { listnerer.pConn = NewBatchConn(conn, lc.Batch.WriteBatchSize, lc.Batch.WriteBatchInterval) listnerer.readBatchSize = lc.Batch.ReadBatchSize } listnerer.accepting.Store(true) listnerer.connWG.Add(1) listnerer.readWG.Add(2) // wait readLoop and Close execution routine go listnerer.readLoop() go func() { listnerer.connWG.Wait() if err := listnerer.pConn.Close(); err != nil { listnerer.errClose.Store(err) } listnerer.readWG.Done() }() return listnerer, nil } // Listen creates a new listener using default ListenConfig. func Listen(network string, laddr *net.UDPAddr) (net.Listener, error) { return (&ListenConfig{}).Listen(network, laddr) } // readLoop has to tasks: // 1. Dispatching incoming packets to the correct Conn. // It can therefore not be ended until all Conns are closed. // 2. Creating a new Conn when receiving from a new remote. 
func (l *listener) readLoop() { defer l.readWG.Done() defer close(l.readDoneCh) if br, ok := l.pConn.(BatchReader); ok && l.readBatchSize > 1 { l.readBatch(br) } else { l.read() } } func (l *listener) readBatch(br BatchReader) { msgs := make([]ipv4.Message, l.readBatchSize) for i := range msgs { msg := &msgs[i] msg.Buffers = [][]byte{make([]byte, receiveMTU)} msg.OOB = make([]byte, 40) } for { n, err := br.ReadBatch(msgs, 0) if err != nil { l.errRead.Store(err) return } for i := 0; i < n; i++ { l.dispatchMsg(msgs[i].Addr, msgs[i].Buffers[0][:msgs[i].N]) } } } func (l *listener) read() { buf := make([]byte, receiveMTU) for { n, raddr, err := l.pConn.ReadFrom(buf) if err != nil { l.errRead.Store(err) return } l.dispatchMsg(raddr, buf[:n]) } } func (l *listener) dispatchMsg(addr net.Addr, buf []byte) { conn, ok, err := l.getConn(addr, buf) if err != nil { return } if ok { _, _ = conn.buffer.Write(buf) } } func (l *listener) getConn(raddr net.Addr, buf []byte) (*Conn, bool, error) { l.connLock.Lock() defer l.connLock.Unlock() conn, ok := l.conns[raddr.String()] if !ok { if isAccepting, ok := l.accepting.Load().(bool); !isAccepting || !ok { return nil, false, ErrClosedListener } if l.acceptFilter != nil { if !l.acceptFilter(buf) { return nil, false, nil } } conn = l.newConn(raddr) select { case l.acceptCh <- conn: l.conns[raddr.String()] = conn default: return nil, false, ErrListenQueueExceeded } } return conn, true, nil } // Conn augments a connection-oriented connection over a UDP PacketConn. type Conn struct { listener *listener rAddr net.Addr buffer *packetio.Buffer doneCh chan struct{} doneOnce sync.Once writeDeadline *deadline.Deadline } func (l *listener) newConn(rAddr net.Addr) *Conn { return &Conn{ listener: l, rAddr: rAddr, buffer: packetio.NewBuffer(), doneCh: make(chan struct{}), writeDeadline: deadline.New(), } } // Read reads from c into p. 
func (c *Conn) Read(p []byte) (int, error) { return c.buffer.Read(p) } // Write writes len(p) bytes from p to the DTLS connection. func (c *Conn) Write(p []byte) (n int, err error) { select { case <-c.writeDeadline.Done(): return 0, context.DeadlineExceeded default: } return c.listener.pConn.WriteTo(p, c.rAddr) } // Close closes the conn and releases any Read calls. func (c *Conn) Close() error { var err error c.doneOnce.Do(func() { c.listener.connWG.Done() close(c.doneCh) c.listener.connLock.Lock() delete(c.listener.conns, c.rAddr.String()) nConns := len(c.listener.conns) c.listener.connLock.Unlock() if isAccepting, ok := c.listener.accepting.Load().(bool); nConns == 0 && !isAccepting && ok { // Wait if this is the final connection c.listener.readWG.Wait() if errClose, ok := c.listener.errClose.Load().(error); ok { err = errClose } } else { err = nil } if errBuf := c.buffer.Close(); errBuf != nil && err == nil { err = errBuf } }) return err } // LocalAddr implements net.Conn.LocalAddr. func (c *Conn) LocalAddr() net.Addr { return c.listener.pConn.LocalAddr() } // RemoteAddr implements net.Conn.RemoteAddr. func (c *Conn) RemoteAddr() net.Addr { return c.rAddr } // SetDeadline implements net.Conn.SetDeadline. func (c *Conn) SetDeadline(t time.Time) error { c.writeDeadline.Set(t) return c.SetReadDeadline(t) } // SetReadDeadline implements net.Conn.SetDeadline. func (c *Conn) SetReadDeadline(t time.Time) error { return c.buffer.SetReadDeadline(t) } // SetWriteDeadline implements net.Conn.SetDeadline. func (c *Conn) SetWriteDeadline(t time.Time) error { c.writeDeadline.Set(t) // Write deadline of underlying connection should not be changed // since the connection can be shared. 
return nil } golang-github-pion-transport-v3-3.0.8/udp/conn_test.go000066400000000000000000000272001507057301700227150ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT //go:build !js // +build !js package udp import ( "errors" "fmt" "io" "net" "os" "sync" "sync/atomic" "testing" "time" "github.com/pion/transport/v3/test" "github.com/stretchr/testify/assert" ) var errHandshakeFailed = errors.New("handshake failed") // Note: doesn't work since closing isn't propagated to the other side // func TestNetTest(t *testing.T) { // lim := test.TimeOut(time.Minute*1 + time.Second*10) // defer lim.Stop() // // nettest.TestConn(t, func() (c1, c2 net.Conn, stop func(), err error) { // listener, c1, c2, err = pipe() // if err != nil { // return nil, nil, nil, err // } // stop = func() { // c1.Close() // c2.Close() // listener.Close(1 * time.Second) // } // return // }) //} func TestStressDuplex(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() // Run the test stressDuplex(t) } func stressDuplex(t *testing.T) { t.Helper() listener, ca, cb, err := pipe() assert.NoError(t, err) defer func() { err = ca.Close() assert.NoError(t, err) err = cb.Close() assert.NoError(t, err) err = listener.Close() assert.NoError(t, err) }() opt := test.Options{ MsgSize: 2048, MsgCount: 1, // Can't rely on UDP message order in CI } err = test.StressDuplex(ca, cb, opt) assert.NoError(t, err) } func TestListenerCloseTimeout(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() listener, ca, _, err := pipe() assert.NoError(t, err) err = listener.Close() assert.NoError(t, err) // Close client after server closes to cleanup err = ca.Close() assert.NoError(t, err) } func 
TestListenerCloseUnaccepted(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() const backlog = 2 network, addr := getConfig() listener, err := (&ListenConfig{ Backlog: backlog, }).Listen(network, addr) assert.NoError(t, err) for i := 0; i < backlog; i++ { addr, ok := listener.Addr().(*net.UDPAddr) assert.True(t, ok) conn, err := net.DialUDP(network, nil, addr) assert.NoError(t, err) _, err = conn.Write([]byte{byte(i)}) assert.NoError(t, err) assert.NoError(t, conn.Close()) } time.Sleep(100 * time.Millisecond) // Wait all packets being processed by readLoop // Unaccepted connections must be closed by listener.Close() assert.NoError(t, listener.Close()) } func TestListenerAcceptFilter(t *testing.T) { //nolint:cyclop // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() testCases := map[string]struct { packet []byte accept bool }{ "CreateConn": { packet: []byte{0xAA}, accept: true, }, "Discarded": { packet: []byte{0x00}, accept: false, }, } for name, testCase := range testCases { testCase := testCase t.Run(name, func(t *testing.T) { network, addr := getConfig() listener, err := (&ListenConfig{ AcceptFilter: func(pkt []byte) bool { return pkt[0] == 0xAA }, }).Listen(network, addr) assert.NoError(t, err) var wgAcceptLoop sync.WaitGroup wgAcceptLoop.Add(1) defer func() { assert.NoError(t, listener.Close()) wgAcceptLoop.Wait() }() addr, ok := listener.Addr().(*net.UDPAddr) assert.True(t, ok) conn, err := net.DialUDP(network, nil, addr) assert.NoError(t, err) _, err = conn.Write(testCase.packet) assert.NoError(t, err) defer func() { assert.NoError(t, conn.Close()) }() chAccepted := make(chan struct{}) go func() { defer wgAcceptLoop.Done() conn, aArr := listener.Accept() if aArr != nil { assert.ErrorIs(t, aArr, 
ErrClosedListener) return } close(chAccepted) assert.NoError(t, conn.Close()) }() var accepted bool select { case <-chAccepted: accepted = true case <-time.After(10 * time.Millisecond): } if testCase.accept { assert.Equal(t, testCase.accept, accepted, "Packet should create new conn") } else { assert.Equal(t, testCase.accept, accepted, "Packet should not create new conn") } }) } } func TestListenerConcurrent(t *testing.T) { //nolint:cyclop // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() const backlog = 2 network, addr := getConfig() listener, err := (&ListenConfig{ Backlog: backlog, }).Listen(network, addr) assert.NoError(t, err) for i := 0; i < backlog+1; i++ { addr, ok := listener.Addr().(*net.UDPAddr) assert.True(t, ok) conn, connErr := net.DialUDP(network, nil, addr) assert.NoError(t, connErr) _, connErr = conn.Write([]byte{byte(i)}) assert.NoError(t, connErr) assert.NoError(t, conn.Close()) } time.Sleep(100 * time.Millisecond) // Wait all packets being processed by readLoop for i := 0; i < backlog; i++ { conn, connErr := listener.Accept() assert.NoError(t, connErr) b := make([]byte, 1) n, connErr := conn.Read(b) assert.NoError(t, connErr) assert.Equalf(t, []byte{byte(i)}, b[:n], "Packet from connection %d is wrong", i) assert.NoError(t, conn.Close()) } var wg sync.WaitGroup wg.Add(1) go func() { defer wg.Done() conn, connErr := listener.Accept() assert.ErrorIs(t, connErr, ErrClosedListener, "Connection exceeding backlog limit must be discarded") if connErr == nil { _ = conn.Close() } }() time.Sleep(100 * time.Millisecond) // Last Accept should be discarded err = listener.Close() assert.NoError(t, err) wg.Wait() } func pipe() (net.Listener, net.Conn, *net.UDPConn, error) { // Start listening network, addr := getConfig() listener, err := Listen(network, addr) if err != nil { return nil, nil, nil, fmt.Errorf("failed to listen: %w", err) } // 
Open a connection var dConn *net.UDPConn addr, ok := listener.Addr().(*net.UDPAddr) if !ok { return nil, nil, nil, fmt.Errorf("failed to get listener addr: %w", os.ErrInvalid) } dConn, err = net.DialUDP(network, nil, addr) if err != nil { return nil, nil, nil, fmt.Errorf("failed to dial: %w", err) } // Write to the connection to initiate it handshake := "hello" _, err = dConn.Write([]byte(handshake)) if err != nil { return nil, nil, nil, fmt.Errorf("failed to write to dialed Conn: %w", err) } // Accept the connection var lConn net.Conn lConn, err = listener.Accept() if err != nil { return nil, nil, nil, fmt.Errorf("failed to accept Conn: %w", err) } var n int buf := make([]byte, len(handshake)) if n, err = lConn.Read(buf); err != nil { return nil, nil, nil, fmt.Errorf("failed to read handshake: %w", err) } result := string(buf[:n]) if handshake != result { return nil, nil, nil, fmt.Errorf("%w: %s != %s", errHandshakeFailed, handshake, result) } return listener, lConn, dConn, nil } func getConfig() (string, *net.UDPAddr) { return "udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0} } func TestConnClose(t *testing.T) { //nolint:cyclop lim := test.TimeOut(time.Second * 5) defer lim.Stop() t.Run("Close", func(t *testing.T) { // Check for leaking routines report := test.CheckRoutines(t) defer report() l, ca, cb, err := pipe() assert.NoError(t, err) assert.NoError(t, ca.Close(), "Failed to close A side") assert.NoError(t, cb.Close(), "Failed to close B side") assert.NoError(t, l.Close(), "Failed to close listener") }) t.Run("CloseError1", func(t *testing.T) { // Check for leaking routines report := test.CheckRoutines(t) defer report() listn, ca, cb, err := pipe() assert.NoError(t, err) // Close l.pConn to inject error. 
list, ok := listn.(*listener) assert.True(t, ok) assert.NoError(t, list.pConn.Close()) assert.NoError(t, ca.Close(), "Failed to close A side") assert.NoError(t, cb.Close(), "Failed to close B side") assert.Error(t, listn.Close(), "Error is not propagated to Listener.Close") }) t.Run("CloseError2", func(t *testing.T) { // Check for leaking routines report := test.CheckRoutines(t) defer report() l, ca, cb, err := pipe() assert.NoError(t, err) // Close l.pConn to inject error. list, ok := l.(*listener) assert.True(t, ok) assert.NoError(t, list.pConn.Close()) assert.NoError(t, cb.Close(), "Failed to close B side") assert.NoError(t, l.Close(), "Failed to close listener") assert.Error(t, ca.Close(), "Error is not propagated to Conn.Close") }) t.Run("CancelRead", func(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 5) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() listn, ca, cb, err := pipe() assert.NoError(t, err) errC := make(chan error, 1) go func() { buf := make([]byte, 1024) // This read will block because we don't write on the other side. // Calling Close must unblock the call. _, err := ca.Read(buf) errC <- err }() assert.NoError(t, ca.Close(), "Failed to close A side") // Main test condition, Read should return // after ca.Close() by closing the buffer. 
assert.ErrorIs(t, <-errC, io.EOF) assert.NoError(t, cb.Close(), "Failed to close A side") assert.NoError(t, listn.Close(), "Failed to close listener") }) } func TestBatchIO(t *testing.T) { lc := ListenConfig{ Batch: BatchIOConfig{ Enable: true, ReadBatchSize: 10, WriteBatchSize: 3, WriteBatchInterval: 5 * time.Millisecond, }, ReadBufferSize: 64 * 1024, WriteBufferSize: 64 * 1024, } laddr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 15678} listener, err := lc.Listen("udp", laddr) assert.NoError(t, err) var serverConnWg sync.WaitGroup serverConnWg.Add(1) go func() { var exit int32 defer func() { defer serverConnWg.Done() atomic.StoreInt32(&exit, 1) }() for { buf := make([]byte, 1400) conn, lerr := listener.Accept() if errors.Is(lerr, ErrClosedListener) { break } assert.NoError(t, lerr) serverConnWg.Add(1) go func() { defer func() { _ = conn.Close() serverConnWg.Done() }() for atomic.LoadInt32(&exit) != 1 { _ = conn.SetReadDeadline(time.Now().Add(time.Second)) n, rerr := conn.Read(buf) if rerr != nil { assert.ErrorContains(t, rerr, "timeout") } else { _, rerr = conn.Write(buf[:n]) assert.NoError(t, rerr) } } }() } }() raddr, _ := listener.Addr().(*net.UDPAddr) // test flush by WriteBatchInterval expired readBuf := make([]byte, 1400) cli, err := net.DialUDP("udp", nil, raddr) assert.NoError(t, err) flushStr := "flushbytimer" _, err = cli.Write([]byte("flushbytimer")) assert.NoError(t, err) n, err := cli.Read(readBuf) assert.NoError(t, err) assert.Equal(t, flushStr, string(readBuf[:n])) wgs := sync.WaitGroup{} cc := 3 wgs.Add(cc) for i := 0; i < cc; i++ { sendStr := fmt.Sprintf("hello %d", i) go func() { defer wgs.Done() buf := make([]byte, 1400) client, err := net.DialUDP("udp", nil, raddr) assert.NoError(t, err) defer func() { _ = client.Close() }() for i := 0; i < 1; i++ { _, err := client.Write([]byte(sendStr)) assert.NoError(t, err) err = client.SetReadDeadline(time.Now().Add(time.Second)) assert.NoError(t, err) n, err := client.Read(buf) assert.NoError(t, 
err) assert.Equal(t, sendStr, string(buf[:n]), i) } }() } wgs.Wait() _ = listener.Close() serverConnWg.Wait() } golang-github-pion-transport-v3-3.0.8/utils/000077500000000000000000000000001507057301700207415ustar00rootroot00000000000000golang-github-pion-transport-v3-3.0.8/utils/xor/000077500000000000000000000000001507057301700215515ustar00rootroot00000000000000golang-github-pion-transport-v3-3.0.8/utils/xor/xor_arm.go000066400000000000000000000025171507057301700235540ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2022 The Pion community // SPDX-License-Identifier: MIT //go:build !gccgo // +build !gccgo // Package xor provides utility functions used by other Pion // packages. ARM arch. package xor import ( "unsafe" "golang.org/x/sys/cpu" ) const wordSize = int(unsafe.Sizeof(uintptr(0))) // nolint:gosec var hasNEON = cpu.ARM.HasNEON // nolint:gochecknoglobals func isAligned(a *byte) bool { return uintptr(unsafe.Pointer(a))%uintptr(wordSize) == 0 } // XorBytes xors the bytes in a and b. The destination should have enough // space, otherwise xorBytes will panic. Returns the number of bytes xor'd. // //revive:disable-next-line func XorBytes(dst, a, b []byte) int { n := len(a) if len(b) < n { n = len(b) } if n == 0 { return 0 } // make sure dst has enough space _ = dst[n-1] if hasNEON { xorBytesNEON32(&dst[0], &a[0], &b[0], n) } else if isAligned(&dst[0]) && isAligned(&a[0]) && isAligned(&b[0]) { xorBytesARM32(&dst[0], &a[0], &b[0], n) } else { safeXORBytes(dst, a, b, n) } return n } // n needs to be smaller or equal than the length of a and b. 
func safeXORBytes(dst, a, b []byte, n int) { for i := 0; i < n; i++ { dst[i] = a[i] ^ b[i] } } //go:noescape func xorBytesARM32(dst, a, b *byte, n int) //go:noescape func xorBytesNEON32(dst, a, b *byte, n int) golang-github-pion-transport-v3-3.0.8/utils/xor/xor_arm.s000066400000000000000000000036231507057301700234100ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2022 The Pion community // SPDX-License-Identifier: MIT // go:build !gccgo // +build !gccgo #include "textflag.h" // func xorBytesARM32(dst, a, b *byte, n int) TEXT ·xorBytesARM32(SB), NOSPLIT|NOFRAME, $0 MOVW dst+0(FP), R0 MOVW a+4(FP), R1 MOVW b+8(FP), R2 MOVW n+12(FP), R3 CMP $4, R3 BLT less_than4 loop_4: MOVW.P 4(R1), R4 MOVW.P 4(R2), R5 EOR R4, R5, R5 MOVW.P R5, 4(R0) SUB $4, R3 CMP $4, R3 BGE loop_4 less_than4: CMP $2, R3 BLT less_than2 MOVH.P 2(R1), R4 MOVH.P 2(R2), R5 EOR R4, R5, R5 MOVH.P R5, 2(R0) SUB $2, R3 less_than2: CMP $0, R3 BEQ end MOVB (R1), R4 MOVB (R2), R5 EOR R4, R5, R5 MOVB R5, (R0) end: RET // func xorBytesNEON32(dst, a, b *byte, n int) TEXT ·xorBytesNEON32(SB), NOSPLIT|NOFRAME, $0 MOVW dst+0(FP), R0 MOVW a+4(FP), R1 MOVW b+8(FP), R2 MOVW n+12(FP), R3 CMP $32, R3 BLT less_than32 loop_32: WORD $0xF421020D // vld1.u8 {q0, q1}, [r1]! WORD $0xF422420D // vld1.u8 {q2, q3}, [r2]! WORD $0xF3004154 // veor q2, q0, q2 WORD $0xF3026156 // veor q3, q1, q3 WORD $0xF400420D // vst1.u8 {q2, q3}, [r0]! SUB $32, R3 CMP $32, R3 BGE loop_32 less_than32: CMP $16, R3 BLT less_than16 WORD $0xF4210A0D // vld1.u8 q0, [r1]! WORD $0xF4222A0D // vld1.u8 q1, [r2]! WORD $0xF3002152 // veor q1, q0, q1 WORD $0xF4002A0D // vst1.u8 {q1}, [r0]! SUB $16, R3 less_than16: CMP $8, R3 BLT less_than8 WORD $0xF421070D // vld1.u8 d0, [r1]! WORD $0xF422170D // vld1.u8 d1, [r2]! WORD $0xF3001111 // veor d1, d0, d1 WORD $0xF400170D // vst1.u8 {d1}, [r0]! 
SUB $8, R3 less_than8: CMP $4, R3 BLT less_than4 MOVW.P 4(R1), R4 MOVW.P 4(R2), R5 EOR R4, R5, R5 MOVW.P R5, 4(R0) SUB $4, R3 less_than4: CMP $2, R3 BLT less_than2 MOVH.P 2(R1), R4 MOVH.P 2(R2), R5 EOR R4, R5, R5 MOVH.P R5, 2(R0) SUB $2, R3 less_than2: CMP $0, R3 BEQ end MOVB (R1), R4 MOVB (R2), R5 EOR R4, R5, R5 MOVB R5, (R0) end: RET golang-github-pion-transport-v3-3.0.8/utils/xor/xor_generic.go000066400000000000000000000007461507057301700244130ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2013 The Go Authors. All rights reserved. // SPDX-License-Identifier: BSD-3-Clause // SPDX-FileCopyrightText: 2024 The Pion community // SPDX-License-Identifier: MIT //go:build go1.20 && !arm && !gccgo // Package xor provides the XorBytes function. package xor import ( "crypto/subtle" ) // XorBytes calls [crypto/suble.XORBytes]. // //revive:disable-next-line func XorBytes(dst, a, b []byte) int { return subtle.XORBytes(dst, a, b) } golang-github-pion-transport-v3-3.0.8/utils/xor/xor_old.go000066400000000000000000000042171507057301700235520ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2013 The Go Authors. All rights reserved. // SPDX-License-Identifier: BSD-3-Clause // SPDX-FileCopyrightText: 2022 The Pion community // SPDX-License-Identifier: MIT //go:build (!go1.20 && !arm) || gccgo // Package xor provides the XorBytes function. // This version is only used on Go up to version 1.19. package xor import ( "runtime" "unsafe" ) const ( wordSize = int(unsafe.Sizeof(uintptr(0))) // nolint:gosec supportsUnaligned = runtime.GOARCH == "386" || runtime.GOARCH == "amd64" || runtime.GOARCH == "arm64" || runtime.GOARCH == "ppc64" || runtime.GOARCH == "ppc64le" || runtime.GOARCH == "s390x" // nolint:gochecknoglobals ) func isAligned(a *byte) bool { return uintptr(unsafe.Pointer(a))%uintptr(wordSize) == 0 } // XorBytes xors the bytes in a and b. The destination should have enough // space, otherwise xorBytes will panic. Returns the number of bytes xor'd. 
// //revive:disable-next-line func XorBytes(dst, a, b []byte) int { n := len(a) if len(b) < n { n = len(b) } if n == 0 { return 0 } switch { case supportsUnaligned: fastXORBytes(dst, a, b, n) case isAligned(&dst[0]) && isAligned(&a[0]) && isAligned(&b[0]): fastXORBytes(dst, a, b, n) default: safeXORBytes(dst, a, b, n) } return n } // fastXORBytes xors in bulk. It only works on architectures that // support unaligned read/writes. // n needs to be smaller or equal than the length of a and b. func fastXORBytes(dst, a, b []byte, n int) { // Assert dst has enough space _ = dst[n-1] w := n / wordSize if w > 0 { dw := *(*[]uintptr)(unsafe.Pointer(&dst)) // nolint:gosec aw := *(*[]uintptr)(unsafe.Pointer(&a)) // nolint:gosec bw := *(*[]uintptr)(unsafe.Pointer(&b)) // nolint:gosec for i := 0; i < w; i++ { dw[i] = aw[i] ^ bw[i] } } for i := (n - n%wordSize); i < n; i++ { dst[i] = a[i] ^ b[i] } } // n needs to be smaller or equal than the length of a and b. func safeXORBytes(dst, a, b []byte, n int) { for i := 0; i < n; i++ { dst[i] = a[i] ^ b[i] } } golang-github-pion-transport-v3-3.0.8/utils/xor/xor_test.go000066400000000000000000000037621507057301700237570ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2013 The Go Authors. All rights reserved. 
// SPDX-License-Identifier: BSD-3-Clause package xor import ( "crypto/rand" "fmt" "io" "testing" "github.com/stretchr/testify/assert" ) func TestXOR(t *testing.T) { //nolint:cyclop for j := 1; j <= 1024; j++ { //nolint:varnamelen if testing.Short() && j > 16 { break } for alignP := 0; alignP < 2; alignP++ { for alignQ := 0; alignQ < 2; alignQ++ { for alignD := 0; alignD < 2; alignD++ { p := make([]byte, j)[alignP:] //nolint:varnamelen q := make([]byte, j)[alignQ:] //nolint:varnamelen d0 := make([]byte, j+alignD+1) d0[j+alignD] = 42 d1 := d0[alignD : j+alignD] d2 := make([]byte, j+alignD)[alignD:] _, err := io.ReadFull(rand.Reader, p) assert.NoError(t, err) _, err = io.ReadFull(rand.Reader, q) assert.NoError(t, err) XorBytes(d1, p, q) n := minInt(p, q) for i := 0; i < n; i++ { d2[i] = p[i] ^ q[i] } assert.Equalf(t, d1, d2, "p: %#v, q: %#v", p, q) assert.Equal(t, byte(42), d0[j+alignD], "guard overwritten") } } } } } func minInt(a, b []byte) int { n := len(a) if len(b) < n { n = len(b) } return n } func BenchmarkXORAligned(b *testing.B) { dst := make([]byte, 1<<15) data0 := make([]byte, 1<<15) data1 := make([]byte, 1<<15) sizes := []int64{1 << 3, 1 << 7, 1 << 11, 1 << 15} for _, size := range sizes { b.Run(fmt.Sprintf("%dBytes", size), func(b *testing.B) { s0 := data0[:size] s1 := data1[:size] b.SetBytes(size) for i := 0; i < b.N; i++ { XorBytes(dst, s0, s1) } }) } } func BenchmarkXORUnalignedDst(b *testing.B) { dst := make([]byte, 1<<15+1) data0 := make([]byte, 1<<15) data1 := make([]byte, 1<<15) sizes := []int64{1 << 3, 1 << 7, 1 << 11, 1 << 15} for _, size := range sizes { b.Run(fmt.Sprintf("%dBytes", size), func(b *testing.B) { s0 := data0[:size] s1 := data1[:size] b.SetBytes(size) for i := 0; i < b.N; i++ { XorBytes(dst[1:], s0, s1) } }) } } 
golang-github-pion-transport-v3-3.0.8/vnet/000077500000000000000000000000001507057301700205555ustar00rootroot00000000000000golang-github-pion-transport-v3-3.0.8/vnet/.gitignore000066400000000000000000000001561507057301700225470ustar00rootroot00000000000000# SPDX-FileCopyrightText: 2023 The Pion community # SPDX-License-Identifier: MIT *.sw[poe] golang-github-pion-transport-v3-3.0.8/vnet/README.md000066400000000000000000000220101507057301700220270ustar00rootroot00000000000000# vnet A virtual network layer for pion. ## Overview ### Goals * To make NAT traversal tests easy. * To emulate packet impairment at application level for testing. * To monitor packets at specified arbitrary interfaces. ### Features * Configurable virtual LAN and WAN * Virtually hosted ICE servers ### Virtual network components #### Top View ``` ...................................... : Virtual Network (vnet) : : : +-------+ * 1 +----+ +--------+ : | :App |------------>|:Net|--o<-----|:Router | : +-------+ +----+ | | : +-----------+ * 1 +----+ | | : |:STUNServer|-------->|:Net|--o<-----| | : +-----------+ +----+ | | : +-----------+ * 1 +----+ | | : |:TURNServer|-------->|:Net|--o<-----| | : +-----------+ +----+ [1] | | : : 1 | | 1 <> : : +---<>| |<>----+ [2] : : | +--------+ | : To form | *| v 0..1 : a subnet tree | o [3] +-----+ : : | ^ |:NAT | : : | | +-----+ : : +-------+ : ...................................... Note: o: NIC (Network Interface Controller) [1]: Net implements NIC interface. [2]: Root router has no NAT. All child routers have a NAT always. [3]: Router implements NIC interface for accesses from the parent router. ``` #### Net Net provides 3 interfaces: * Configuration API (direct) * Network API via Net (equivalent to net.Xxx()) * Router access via NIC interface ``` (Pion module/app, ICE servers, etc.) 
+-----------+ | :App | +-----------+ * | | <> 1 v +---------+ 1 * +-----------+ 1 * +-----------+ 1 * +------+ ..| :Router |----+------>o--| :Net |<>------|:Interface |<>------|:Addr | +---------+ | NIC +-----------+ +-----------+ +------+ | <> (transport.Interface) (net.Addr) | | * +-----------+ 1 * +-----------+ 1 * +------+ +------>o--| :Router |<>------|:Interface |<>------|:Addr | NIC +-----------+ +-----------+ +------+ <> (transport.Interface) (net.Addr) ``` > The instance of `Net` will be the one passed around the project. > Net class has public methods for configuration and for application use. ## Implementation ### Design Policy * Each pion package should have config object which has `Net` (of type `transport.Net`) property. - Just like how we distribute `LoggerFactory` throughout the pion project. * DNS => a simple dictionary (global)? * Each Net has routing capability (a goroutine) * Use interface provided net package as much as possible * Routers are connected in a tree structure (no loop is allowed) - To simplify routing - Easy to control / monitor (stats, etc) * Root router has no NAT (== Internet / WAN) * Non-root router has a NAT always * When a Net is instantiated, it will automatically add `lo0` and `eth0` interface, and `lo0` will have one IP address, 127.0.0.1. (this is not used in pion/ice, however) * When a Net is added to a router, the router automatically assign an IP address for `eth0` interface. - For simplicity * User data won't fragment, but optionally drop chunk larger than MTU * IPv6 is not supported ### Basic steps for setting up virtual network 1. Create a root router (WAN) 1. Create child routers and add to its parent (forms a tree, don't create a loop!) 1. Add instances of Net to each routers 1. 
Call Start(), or Stop(), on the top router, which propagates to all other routers #### Example: WAN with one endpoint (vnet) ```go import ( "net" "github.com/pion/transport/v3" "github.com/pion/transport/v3/vnet" "github.com/pion/logging" ) // Create WAN (a root router). wan, err := vnet.NewRouter(&vnet.RouterConfig{ CIDR: "0.0.0.0/0", LoggerFactory: logging.NewDefaultLoggerFactory(), }) // Create a network. // You can specify a static IP for the instance of Net to use. If not specified, // router will assign an IP address that is contained in the router's CIDR. nw := vnet.NewNet(&vnet.NetConfig{ StaticIP: "27.1.2.3", }) // Add the network to the router. // The router will assign an IP address to `nw`. if err = wan.AddNet(nw); err != nil { // handle error } // Start router. // This will start internal goroutine to route packets. // If you set child routers (using AddRouter), the call on the root router // will start the rest of routers for you. if err = wan.Start(); err != nil { // handle error } // // Your application runs here using `nw`. // // Stop the router. // This will stop all internal goroutines in the router tree. // (No need to call Stop() on child routers) if err = wan.Stop(); err != nil { // handle error } ``` #### Example of how to pass around the instance of vnet.Net The instance of vnet.Net wraps a subset of net package to enable operations on the virtual network. Your project must be able to pass the instance to all your routines that do network operation with net package. A typical way is to use a config param to create your instances with the virtual network instance (`nw` in the above example) like this: ```go type AgentConfig struct { : Net: transport.Net, } type Agent struct { : net: transport.Net, } func NetAgent(config *AgentConfig) *Agent { if config.Net == nil { config.Net = vnet.NewNet() } return &Agent { : net: config.Net, } } ``` ```go // a.net is the instance of vnet.Net class func (a *Agent) listenUDP(...) 
error { conn, err := a.net.ListenPacket(udpString, ...) if err != nil { return err } : } ``` ### Compatibility and Support Status |`net`
(built-in) |`vnet` |Note | |:--- |:--- |:--- | | net.Interfaces() | a.net.Interfaces() | | | net.InterfaceByName() | a.net.InterfaceByName() | | | net.ResolveUDPAddr() | a.net.ResolveUDPAddr() | | | net.ListenPacket() | a.net.ListenPacket() | | | net.ListenUDP() | a.net.ListenUDP() | ListenPacket() is recommended | | net.Listen() | a.net.Listen() | (TODO) | | net.ListenTCP() | (not supported) | Listen() would be recommended | | net.Dial() | a.net.Dial() | | | net.DialUDP() | a.net.DialUDP() | | | net.DialTCP() | (not supported) | | | net.Interface | transport.Interface | | | net.PacketConn | (use it as-is) | | | net.UDPConn | transport.UDPConn | | | net.TCPConn | transport.TCPConn | TODO: Use net.Conn in your code | | net.Dialer | transport.Dialer | Use a.net.CreateDialer() to create it.
The use of vnet.Dialer is currently experimental. | > `a.net` is an instance of Net class, and types are defined under the package name `vnet` > Most of other `interface` types in net package can be used as is. > Please post a github issue when other types/methods need to be added to vnet/vnet.Net. ## TODO / Next Step * Implement TCP (TCPConn, Listen) * Support of IPv6 * Write a bunch of examples for building virtual networks. * Add network impairment features (on Router) - Introduce latency / jitter - Packet filtering handler (allow selectively drop packets, etc.) * Add statistics data retrieval - Total number of packets forward by each router - Total number of packet loss - Total number of connection failure (TCP) ## References * [Comparing Simulated Packet Loss and RealWorld Network Congestion](https://www.riverbed.com/document/fpo/WhitePaper-Riverbed-SimulatedPacketLoss.pdf) * [wireguard-go using GVisor's netstack](https://github.com/WireGuard/wireguard-go/tree/master/tun/netstack)golang-github-pion-transport-v3-3.0.8/vnet/chunk.go000066400000000000000000000131061507057301700222150ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package vnet import ( "fmt" "net" "strconv" "strings" "sync/atomic" "time" ) type tcpFlag uint8 const ( tcpFIN tcpFlag = 0x01 tcpSYN tcpFlag = 0x02 tcpRST tcpFlag = 0x04 tcpPSH tcpFlag = 0x08 tcpACK tcpFlag = 0x10 ) func (f tcpFlag) String() string { var sa []string if f&tcpFIN != 0 { sa = append(sa, "FIN") } if f&tcpSYN != 0 { sa = append(sa, "SYN") } if f&tcpRST != 0 { sa = append(sa, "RST") } if f&tcpPSH != 0 { sa = append(sa, "PSH") } if f&tcpACK != 0 { sa = append(sa, "ACK") } return strings.Join(sa, "-") } // Generate a base36-encoded unique tag // See: https://play.golang.org/p/0ZaAID1q-HN var assignChunkTag = func() func() string { //nolint:gochecknoglobals var tagCtr uint64 return func() string { n := atomic.AddUint64(&tagCtr, 1) return strconv.FormatUint(n, 36) } 
}() // Chunk represents a packet passed around in the vnet. type Chunk interface { setTimestamp() time.Time // used by router getTimestamp() time.Time // used by router getSourceIP() net.IP // used by router getDestinationIP() net.IP // used by router setSourceAddr(address string) error // used by nat setDestinationAddr(address string) error // used by nat SourceAddr() net.Addr DestinationAddr() net.Addr UserData() []byte Tag() string Clone() Chunk Network() string // returns "udp" or "tcp" String() string } type chunkIP struct { timestamp time.Time sourceIP net.IP destinationIP net.IP tag string } func (c *chunkIP) setTimestamp() time.Time { c.timestamp = time.Now() return c.timestamp } func (c *chunkIP) getTimestamp() time.Time { return c.timestamp } func (c *chunkIP) getDestinationIP() net.IP { return c.destinationIP } func (c *chunkIP) getSourceIP() net.IP { return c.sourceIP } func (c *chunkIP) Tag() string { return c.tag } type chunkUDP struct { chunkIP sourcePort int destinationPort int userData []byte } func newChunkUDP(srcAddr, dstAddr *net.UDPAddr) *chunkUDP { return &chunkUDP{ chunkIP: chunkIP{ sourceIP: srcAddr.IP, destinationIP: dstAddr.IP, tag: assignChunkTag(), }, sourcePort: srcAddr.Port, destinationPort: dstAddr.Port, } } func (c *chunkUDP) SourceAddr() net.Addr { return &net.UDPAddr{ IP: c.sourceIP, Port: c.sourcePort, } } func (c *chunkUDP) DestinationAddr() net.Addr { return &net.UDPAddr{ IP: c.destinationIP, Port: c.destinationPort, } } func (c *chunkUDP) UserData() []byte { return c.userData } func (c *chunkUDP) Clone() Chunk { var userData []byte if c.userData != nil { userData = make([]byte, len(c.userData)) copy(userData, c.userData) } return &chunkUDP{ chunkIP: chunkIP{ timestamp: c.timestamp, sourceIP: c.sourceIP, destinationIP: c.destinationIP, tag: c.tag, }, sourcePort: c.sourcePort, destinationPort: c.destinationPort, userData: userData, } } func (c *chunkUDP) Network() string { return udp } func (c *chunkUDP) String() string { src := 
c.SourceAddr() dst := c.DestinationAddr() return fmt.Sprintf("%s chunk %s %s => %s", src.Network(), c.tag, src.String(), dst.String(), ) } func (c *chunkUDP) setSourceAddr(address string) error { addr, err := net.ResolveUDPAddr(udp, address) if err != nil { return err } c.sourceIP = addr.IP c.sourcePort = addr.Port return nil } func (c *chunkUDP) setDestinationAddr(address string) error { addr, err := net.ResolveUDPAddr(udp, address) if err != nil { return err } c.destinationIP = addr.IP c.destinationPort = addr.Port return nil } type chunkTCP struct { chunkIP sourcePort int destinationPort int flags tcpFlag // control bits userData []byte // only with PSH flag // seq uint32 // always starts with 0 // ack uint32 // always starts with 0 } func newChunkTCP(srcAddr, dstAddr *net.TCPAddr, flags tcpFlag) *chunkTCP { return &chunkTCP{ chunkIP: chunkIP{ sourceIP: srcAddr.IP, destinationIP: dstAddr.IP, tag: assignChunkTag(), }, sourcePort: srcAddr.Port, destinationPort: dstAddr.Port, flags: flags, } } func (c *chunkTCP) SourceAddr() net.Addr { return &net.TCPAddr{ IP: c.sourceIP, Port: c.sourcePort, } } func (c *chunkTCP) DestinationAddr() net.Addr { return &net.TCPAddr{ IP: c.destinationIP, Port: c.destinationPort, } } func (c *chunkTCP) UserData() []byte { return c.userData } func (c *chunkTCP) Clone() Chunk { userData := make([]byte, len(c.userData)) copy(userData, c.userData) return &chunkTCP{ chunkIP: chunkIP{ timestamp: c.timestamp, sourceIP: c.sourceIP, destinationIP: c.destinationIP, }, sourcePort: c.sourcePort, destinationPort: c.destinationPort, userData: userData, } } func (c *chunkTCP) Network() string { return "tcp" } func (c *chunkTCP) String() string { src := c.SourceAddr() dst := c.DestinationAddr() return fmt.Sprintf("%s %s chunk %s %s => %s", src.Network(), c.flags.String(), c.tag, src.String(), dst.String(), ) } func (c *chunkTCP) setSourceAddr(address string) error { addr, err := net.ResolveTCPAddr("tcp", address) if err != nil { return err } c.sourceIP 
= addr.IP c.sourcePort = addr.Port return nil } func (c *chunkTCP) setDestinationAddr(address string) error { addr, err := net.ResolveTCPAddr("tcp", address) if err != nil { return err } c.destinationIP = addr.IP c.destinationPort = addr.Port return nil } golang-github-pion-transport-v3-3.0.8/vnet/chunk_queue.go000066400000000000000000000023611507057301700234220ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package vnet import ( "sync" ) type chunkQueue struct { chunks []Chunk maxSize int // 0 or negative value: unlimited maxBytes int // 0 or negative value: unlimited currentBytes int mutex sync.RWMutex } func newChunkQueue(maxSize int, maxBytes int) *chunkQueue { return &chunkQueue{ chunks: []Chunk{}, maxSize: maxSize, maxBytes: maxBytes, currentBytes: 0, mutex: sync.RWMutex{}, } } func (q *chunkQueue) push(c Chunk) bool { q.mutex.Lock() defer q.mutex.Unlock() if q.maxSize > 0 && len(q.chunks) >= q.maxSize { return false // dropped } if q.maxBytes > 0 && q.currentBytes+len(c.UserData()) >= q.maxBytes { return false } q.currentBytes += len(c.UserData()) q.chunks = append(q.chunks, c) return true } func (q *chunkQueue) pop() (Chunk, bool) { q.mutex.Lock() defer q.mutex.Unlock() if len(q.chunks) == 0 { return nil, false } c := q.chunks[0] q.chunks = q.chunks[1:] q.currentBytes -= len(c.UserData()) return c, true } func (q *chunkQueue) peek() Chunk { q.mutex.RLock() defer q.mutex.RUnlock() if len(q.chunks) == 0 { return nil } return q.chunks[0] } golang-github-pion-transport-v3-3.0.8/vnet/chunk_queue_test.go000066400000000000000000000024421507057301700244610ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package vnet import ( "net" "testing" "github.com/stretchr/testify/assert" ) func TestChunkQueue(t *testing.T) { chunk := newChunkUDP(&net.UDPAddr{ IP: net.ParseIP("192.188.0.2"), Port: 1234, }, &net.UDPAddr{ IP: net.ParseIP(demoIP), Port: 
5678, }) chunk.userData = make([]byte, 1200) var ok bool var queue *chunkQueue var chunk2 Chunk queue = newChunkQueue(0, 0) chunk2 = queue.peek() assert.Nil(t, chunk2, "should return nil") ok = queue.push(chunk) assert.True(t, ok, "should succeed") chunk2, ok = queue.pop() assert.True(t, ok, "should succeed") assert.Equal(t, chunk, chunk2, "should be the same") chunk2, ok = queue.pop() assert.False(t, ok, "should fail") assert.Nil(t, chunk2, "should be nil") queue = newChunkQueue(1, 0) ok = queue.push(chunk) assert.True(t, ok, "should succeed") ok = queue.push(chunk) assert.False(t, ok, "should fail") chunk2 = queue.peek() assert.Equal(t, chunk, chunk2, "should be the same") queue = newChunkQueue(0, 1500) ok = queue.push(chunk) assert.True(t, ok, "should succeed") ok = queue.push(chunk) assert.False(t, ok, "should fail") chunk2 = queue.peek() assert.Equal(t, chunk, chunk2, "should be the same") } golang-github-pion-transport-v3-3.0.8/vnet/chunk_test.go000066400000000000000000000107331507057301700232570ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package vnet import ( "net" "strings" "testing" "github.com/pion/logging" "github.com/stretchr/testify/assert" ) func TestTCPFragString(t *testing.T) { f := tcpFIN assert.Equal(t, "FIN", f.String(), "should match") f = tcpSYN assert.Equal(t, "SYN", f.String(), "should match") f = tcpRST assert.Equal(t, "RST", f.String(), "should match") f = tcpPSH assert.Equal(t, "PSH", f.String(), "should match") f = tcpACK assert.Equal(t, "ACK", f.String(), "should match") f = tcpSYN | tcpACK assert.Equal(t, "SYN-ACK", f.String(), "should match") } func TestChunk(t *testing.T) { loggerFactory := logging.NewDefaultLoggerFactory() log := loggerFactory.NewLogger("test") t.Run("ChunkUDP", func(t *testing.T) { src := &net.UDPAddr{ IP: net.ParseIP("192.168.0.2"), Port: 1234, } dst := &net.UDPAddr{ IP: net.ParseIP(demoIP), Port: 5678, } var chunk Chunk = newChunkUDP(src, dst) str 
:= chunk.String() log.Debugf("chunk: %s", str) assert.Equal(t, udp, chunk.Network(), "should match") assert.True(t, strings.Contains(str, src.Network()), "should include network type") assert.True(t, strings.Contains(str, src.String()), "should include address") assert.True(t, strings.Contains(str, dst.String()), "should include address") assert.True(t, chunk.getSourceIP().Equal(src.IP), "ip should match") assert.True(t, chunk.getDestinationIP().Equal(dst.IP), "ip should match") // Test timestamp ts := chunk.setTimestamp() assert.Equal(t, ts, chunk.getTimestamp(), "timestamp should match") uc := chunk.(*chunkUDP) //nolint:forcetypeassert uc.userData = []byte("Hello") cloned := chunk.Clone().(*chunkUDP) //nolint:forcetypeassert // Test setSourceAddr err := uc.setSourceAddr("2.3.4.5:4000") assert.Nil(t, err, "should succeed") assert.Equal(t, "2.3.4.5:4000", uc.SourceAddr().String()) // Test Tag() assert.True(t, len(uc.tag) > 0, "should not be empty") assert.Equal(t, uc.tag, uc.Tag(), "should match") // Verify cloned chunk was not affected by the changes to original chunk uc.userData[0] = []byte("!")[0] // original: "Hello" -> "Hell!" 
assert.Equal(t, "Hello", string(cloned.userData), "should match") assert.Equal(t, "192.168.0.2:1234", cloned.SourceAddr().String()) assert.True(t, cloned.getSourceIP().Equal(src.IP), "ip should match") assert.True(t, cloned.getDestinationIP().Equal(dst.IP), "ip should match") }) t.Run("ChunkTCP", func(t *testing.T) { src := &net.TCPAddr{ IP: net.ParseIP("192.168.0.2"), Port: 1234, } dst := &net.TCPAddr{ IP: net.ParseIP(demoIP), Port: 5678, } var chunk Chunk = newChunkTCP(src, dst, tcpSYN) str := chunk.String() log.Debugf("chunk: %s", str) assert.Equal(t, "tcp", chunk.Network(), "should match") assert.True(t, strings.Contains(str, src.Network()), "should include network type") assert.True(t, strings.Contains(str, src.String()), "should include address") assert.True(t, strings.Contains(str, dst.String()), "should include address") assert.True(t, chunk.getSourceIP().Equal(src.IP), "ip should match") assert.True(t, chunk.getDestinationIP().Equal(dst.IP), "ip should match") tcp, ok := chunk.(*chunkTCP) assert.True(t, ok, "type should match") assert.Equal(t, tcp.flags, tcpSYN, "flags should match") // Test timestamp ts := chunk.setTimestamp() assert.Equal(t, ts, chunk.getTimestamp(), "timestamp should match") tc := chunk.(*chunkTCP) //nolint:forcetypeassert tc.userData = []byte("Hello") cloned := chunk.Clone().(*chunkTCP) //nolint:forcetypeassert // Test setSourceAddr err := tc.setSourceAddr("2.3.4.5:4000") assert.Nil(t, err, "should succeed") assert.Equal(t, "2.3.4.5:4000", tc.SourceAddr().String()) // Test Tag() assert.True(t, len(tc.tag) > 0, "should not be empty") assert.Equal(t, tc.tag, tc.Tag(), "should match") // Verify cloned chunk was not affected by the changes to original chunk tc.userData[0] = []byte("!")[0] // original: "Hello" -> "Hell!" 
assert.Equal(t, "Hello", string(cloned.userData), "should match") assert.Equal(t, "192.168.0.2:1234", cloned.SourceAddr().String()) assert.True(t, cloned.getSourceIP().Equal(src.IP), "ip should match") assert.True(t, cloned.getDestinationIP().Equal(dst.IP), "ip should match") // Test setDestinationAddr err = tc.setDestinationAddr("3.4.5.6:7000") assert.Nil(t, err, "should succeed") assert.Equal(t, "3.4.5.6:7000", tc.DestinationAddr().String()) }) } golang-github-pion-transport-v3-3.0.8/vnet/conn.go000066400000000000000000000204341507057301700220440ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package vnet import ( "errors" "fmt" "io" "math" "net" "sync" "time" "github.com/pion/transport/v3" ) const ( maxReadQueueSize = 1024 ) var ( errObsCannotBeNil = errors.New("obs cannot be nil") errUseClosedNetworkConn = errors.New("use of closed network connection") errAddrNotUDPAddr = errors.New("addr is not a net.UDPAddr") errLocAddr = errors.New("something went wrong with locAddr") errAlreadyClosed = errors.New("already closed") errNoRemAddr = errors.New("no remAddr defined") ) // vNet implements this. type connObserver interface { write(c Chunk) error onClosed(addr net.Addr) determineSourceIP(locIP, dstIP net.IP) net.IP } // UDPConn is the implementation of the Conn and PacketConn interfaces for UDP network connections. // compatible with net.PacketConn and net.Conn. 
type UDPConn struct { locAddr *net.UDPAddr // read-only remAddr *net.UDPAddr // read-only obs connObserver // read-only readCh chan Chunk // thread-safe closed bool // requires mutex mu sync.Mutex // to mutex closed flag readTimer *time.Timer // thread-safe } var _ transport.UDPConn = &UDPConn{} func newUDPConn(locAddr, remAddr *net.UDPAddr, obs connObserver) (*UDPConn, error) { if obs == nil { return nil, errObsCannotBeNil } return &UDPConn{ locAddr: locAddr, remAddr: remAddr, obs: obs, readCh: make(chan Chunk, maxReadQueueSize), readTimer: time.NewTimer(time.Duration(math.MaxInt64)), }, nil } // Close closes the connection. // Any blocked ReadFrom or WriteTo operations will be unblocked and return errors. func (c *UDPConn) Close() error { c.mu.Lock() defer c.mu.Unlock() if c.closed { return errAlreadyClosed } c.closed = true close(c.readCh) c.obs.onClosed(c.locAddr) return nil } // LocalAddr returns the local network address. func (c *UDPConn) LocalAddr() net.Addr { return c.locAddr } // RemoteAddr returns the remote network address. func (c *UDPConn) RemoteAddr() net.Addr { return c.remAddr } // SetDeadline sets the read and write deadlines associated // with the connection. It is equivalent to calling both // SetReadDeadline and SetWriteDeadline. // // A deadline is an absolute time after which I/O operations // fail with a timeout (see type Error) instead of // blocking. The deadline applies to all future and pending // I/O, not just the immediately following call to ReadFrom or // WriteTo. After a deadline has been exceeded, the connection // can be refreshed by setting a deadline in the future. // // An idle timeout can be implemented by repeatedly extending // the deadline after successful ReadFrom or WriteTo calls. // // A zero value for t means I/O operations will not time out. 
func (c *UDPConn) SetDeadline(t time.Time) error {
	// Only the read side has a deadline to arm: SetWriteDeadline is a no-op.
	return c.SetReadDeadline(t)
}

// SetReadDeadline sets the deadline for future ReadFrom calls
// and any currently-blocked ReadFrom call.
// A zero value for t means ReadFrom will not time out.
func (c *UDPConn) SetReadDeadline(t time.Time) error {
	var d time.Duration
	if t.IsZero() {
		// "No deadline" is modeled as the longest representable duration.
		d = time.Duration(math.MaxInt64)
	} else {
		d = time.Until(t)
	}

	// Reset must only be called on a stopped timer with a drained channel.
	// Without this, an expiry from a previous deadline that nobody consumed
	// stays buffered in readTimer.C and makes the next ReadFrom report a
	// timeout even though the deadline was just moved into the future.
	// The drain is non-blocking because a concurrent ReadFrom may already have
	// received the expiry.
	if !c.readTimer.Stop() {
		select {
		case <-c.readTimer.C:
		default:
		}
	}
	c.readTimer.Reset(d)

	return nil
}

// SetWriteDeadline sets the deadline for future WriteTo calls
// and any currently-blocked WriteTo call.
// Even if write times out, it may return n > 0, indicating that
// some of the data was successfully written.
// A zero value for t means WriteTo will not time out.
func (c *UDPConn) SetWriteDeadline(time.Time) error {
	// Write never blocks.
	return nil
}

// Read reads data from the connection.
// Read can be made to time out and return an Error with Timeout() == true
// after a fixed time limit; see SetDeadline and SetReadDeadline.
func (c *UDPConn) Read(b []byte) (int, error) {
	// Same as ReadFrom with the sender address discarded.
	n, _, err := c.ReadFrom(b)

	return n, err
}

// ReadFrom reads a packet from the connection,
// copying the payload into p. It returns the number of
// bytes copied into p and the return address that
// was on the packet.
// It returns the number of bytes read (0 <= n <= len(p))
// and any error encountered. Callers should always process
// the n > 0 bytes returned before considering the error err.
// ReadFrom can be made to time out and return
// an Error with Timeout() == true after a fixed time limit;
// see SetDeadline and SetReadDeadline.
func (c *UDPConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { loop: for { select { case chunk, ok := <-c.readCh: if !ok { break loop } var err error n := copy(p, chunk.UserData()) addr := chunk.SourceAddr() if n < len(chunk.UserData()) { err = io.ErrShortBuffer } if c.remAddr != nil { if addr.String() != c.remAddr.String() { break // discard (shouldn't happen) } } return n, addr, err case <-c.readTimer.C: return 0, nil, &net.OpError{ Op: "read", Net: c.locAddr.Network(), Addr: c.locAddr, Err: newTimeoutError("i/o timeout"), } } } return 0, nil, &net.OpError{ Op: "read", Net: c.locAddr.Network(), Addr: c.locAddr, Err: errUseClosedNetworkConn, } } // ReadFromUDP acts like ReadFrom but returns a UDPAddr. func (c *UDPConn) ReadFromUDP(b []byte) (int, *net.UDPAddr, error) { n, addr, err := c.ReadFrom(b) udpAddr, ok := addr.(*net.UDPAddr) if !ok { return -1, nil, fmt.Errorf("%w: %s", transport.ErrNotUDPAddress, addr) } return n, udpAddr, err } // ReadMsgUDP reads a message from c, copying the payload into b and // the associated out-of-band data into oob. It returns the number of // bytes copied into b, the number of bytes copied into oob, the flags // that were set on the message and the source address of the message. // // The packages golang.org/x/net/ipv4 and golang.org/x/net/ipv6 can be // used to manipulate IP-level socket options in oob. func (c *UDPConn) ReadMsgUDP([]byte, []byte) (n, oobn, flags int, addr *net.UDPAddr, err error) { return -1, -1, -1, nil, transport.ErrNotSupported } // Write writes data to the connection. // Write can be made to time out and return an Error with Timeout() == true // after a fixed time limit; see SetDeadline and SetWriteDeadline. func (c *UDPConn) Write(b []byte) (int, error) { if c.remAddr == nil { return 0, errNoRemAddr } return c.WriteTo(b, c.remAddr) } // WriteTo writes a packet with payload to addr. 
// WriteTo can be made to time out and return // an Error with Timeout() == true after a fixed time limit; // see SetDeadline and SetWriteDeadline. // On packet-oriented connections, write timeouts are rare. func (c *UDPConn) WriteTo(payload []byte, addr net.Addr) (n int, err error) { dstAddr, ok := addr.(*net.UDPAddr) if !ok { return 0, errAddrNotUDPAddr } srcIP := c.obs.determineSourceIP(c.locAddr.IP, dstAddr.IP) if srcIP == nil { return 0, errLocAddr } srcAddr := &net.UDPAddr{ IP: srcIP, Port: c.locAddr.Port, } chunk := newChunkUDP(srcAddr, dstAddr) chunk.userData = make([]byte, len(payload)) copy(chunk.userData, payload) if err := c.obs.write(chunk); err != nil { return 0, err } return len(payload), nil } // WriteToUDP acts like WriteTo but takes a UDPAddr. func (c *UDPConn) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) { return c.WriteTo(b, addr) } // WriteMsgUDP writes a message to addr via c if c isn't connected, or // to c's remote address if c is connected (in which case addr must be // nil). The payload is copied from b and the associated out-of-band // data is copied from oob. It returns the number of payload and // out-of-band bytes written. // // The packages golang.org/x/net/ipv4 and golang.org/x/net/ipv6 can be // used to manipulate IP-level socket options in oob. func (c *UDPConn) WriteMsgUDP([]byte, []byte, *net.UDPAddr) (n, oobn int, err error) { return -1, -1, transport.ErrNotSupported } // SetReadBuffer sets the size of the operating system's // receive buffer associated with the connection. func (c *UDPConn) SetReadBuffer(int) error { return transport.ErrNotSupported } // SetWriteBuffer sets the size of the operating system's // transmit buffer associated with the connection. 
func (c *UDPConn) SetWriteBuffer(int) error { return transport.ErrNotSupported } func (c *UDPConn) onInboundChunk(chunk Chunk) { c.mu.Lock() defer c.mu.Unlock() if c.closed { return } select { case c.readCh <- chunk: default: } } golang-github-pion-transport-v3-3.0.8/vnet/conn_map.go000066400000000000000000000055221507057301700227020ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package vnet import ( "errors" "net" "sync" ) var ( errAddressAlreadyInUse = errors.New("address already in use") errNoSuchUDPConn = errors.New("no such UDPConn") errCannotRemoveUnspecifiedIP = errors.New("cannot remove unspecified IP by the specified IP") ) type udpConnMap struct { portMap map[int][]*UDPConn mutex sync.RWMutex } func newUDPConnMap() *udpConnMap { return &udpConnMap{ portMap: map[int][]*UDPConn{}, } } func (m *udpConnMap) insert(conn *UDPConn) error { m.mutex.Lock() defer m.mutex.Unlock() udpAddr := conn.LocalAddr().(*net.UDPAddr) //nolint:forcetypeassert // check if the port has a listener conns, ok := m.portMap[udpAddr.Port] if ok { if udpAddr.IP.IsUnspecified() { return errAddressAlreadyInUse } for _, conn := range conns { laddr := conn.LocalAddr().(*net.UDPAddr) //nolint:forcetypeassert if laddr.IP.IsUnspecified() || laddr.IP.Equal(udpAddr.IP) { return errAddressAlreadyInUse } } conns = append(conns, conn) } else { conns = []*UDPConn{conn} } m.portMap[udpAddr.Port] = conns return nil } func (m *udpConnMap) find(addr net.Addr) (*UDPConn, bool) { m.mutex.Lock() // could be RLock, but we have delete() op defer m.mutex.Unlock() udpAddr := addr.(*net.UDPAddr) //nolint:forcetypeassert if conns, ok := m.portMap[udpAddr.Port]; ok { if udpAddr.IP.IsUnspecified() { // pick the first one appears in the iteration if len(conns) == 0 { // This can't happen! 
delete(m.portMap, udpAddr.Port) return nil, false } return conns[0], true } for _, conn := range conns { laddr := conn.LocalAddr().(*net.UDPAddr) //nolint:forcetypeassert if laddr.IP.IsUnspecified() || laddr.IP.Equal(udpAddr.IP) { return conn, ok } } } return nil, false } func (m *udpConnMap) delete(addr net.Addr) error { m.mutex.Lock() defer m.mutex.Unlock() udpAddr := addr.(*net.UDPAddr) //nolint:forcetypeassert conns, ok := m.portMap[udpAddr.Port] if !ok { return errNoSuchUDPConn } if udpAddr.IP.IsUnspecified() { // remove all from this port delete(m.portMap, udpAddr.Port) return nil } newConns := []*UDPConn{} for _, conn := range conns { laddr := conn.LocalAddr().(*net.UDPAddr) //nolint:forcetypeassert if laddr.IP.IsUnspecified() { // This can't happen! return errCannotRemoveUnspecifiedIP } if laddr.IP.Equal(udpAddr.IP) { continue } newConns = append(newConns, conn) } if len(newConns) == 0 { delete(m.portMap, udpAddr.Port) } else { m.portMap[udpAddr.Port] = newConns } return nil } // size returns the number of UDPConns (UDP listeners). 
func (m *udpConnMap) size() int { m.mutex.RLock() defer m.mutex.RUnlock() n := 0 for _, conns := range m.portMap { n += len(conns) } return n } golang-github-pion-transport-v3-3.0.8/vnet/conn_map_test.go000066400000000000000000000157261507057301700237500ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package vnet import ( "net" "testing" "github.com/stretchr/testify/assert" ) type myConnObserver struct{} func (obs *myConnObserver) write(Chunk) error { return nil } func (obs *myConnObserver) onClosed(net.Addr) { } func (obs *myConnObserver) determineSourceIP(net.IP, net.IP) net.IP { return net.IP{} } func TestUDPConnMap(t *testing.T) { // log := logging.NewDefaultLoggerFactory().NewLogger("test") t.Run("insert an UDPConn and remove it", func(t *testing.T) { connMap := newUDPConnMap() obs := &myConnObserver{} connIn, err := newUDPConn(&net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: 1234, }, nil, obs) assert.NoError(t, err, "should succeed") err = connMap.insert(connIn) assert.NoError(t, err, "should succeed") connOut, ok := connMap.find(connIn.LocalAddr()) assert.True(t, ok, "should succeed") assert.Equal(t, connIn, connOut, "should match") assert.Equal(t, 1, len(connMap.portMap), "should match") err = connMap.delete(connIn.LocalAddr()) assert.NoError(t, err, "should succeed") assert.Empty(t, connMap.portMap, "should match") err = connMap.delete(connIn.LocalAddr()) assert.Error(t, err, "should fail") }) t.Run("insert an UDPConn on 0.0.0.0 and remove it", func(t *testing.T) { connMap := newUDPConnMap() obs := &myConnObserver{} connIn, err := newUDPConn(&net.UDPAddr{ IP: net.ParseIP("0.0.0.0"), Port: 1234, }, nil, obs) assert.NoError(t, err, "should succeed") err = connMap.insert(connIn) assert.NoError(t, err, "should succeed") connOut, ok := connMap.find(connIn.LocalAddr()) assert.True(t, ok, "should succeed") assert.Equal(t, connIn, connOut, "should match") assert.Equal(t, 1, len(connMap.portMap), 
"should match") err = connMap.delete(connIn.LocalAddr()) assert.NoError(t, err, "should succeed") err = connMap.delete(connIn.LocalAddr()) assert.Error(t, err, "should fail") }) t.Run("find UDPConn on 0.0.0.0 by specified IP", func(t *testing.T) { connMap := newUDPConnMap() obs := &myConnObserver{} connIn, err := newUDPConn(&net.UDPAddr{ IP: net.ParseIP("0.0.0.0"), Port: 1234, }, nil, obs) assert.NoError(t, err, "should succeed") err = connMap.insert(connIn) assert.NoError(t, err, "should succeed") connOut, ok := connMap.find(&net.UDPAddr{ IP: net.ParseIP("192.168.0.1"), Port: 1234, }) assert.True(t, ok, "should succeed") assert.Equal(t, connIn, connOut, "should match") assert.Equal(t, 1, len(connMap.portMap), "should match") }) t.Run("insert many IPs with the same port", func(t *testing.T) { connMap := newUDPConnMap() obs := &myConnObserver{} connIn1, err := newUDPConn(&net.UDPAddr{ IP: net.ParseIP("10.1.2.1"), Port: 5678, }, nil, obs) assert.NoError(t, err, "should succeed") err = connMap.insert(connIn1) assert.NoError(t, err, "should succeed") connIn2, err := newUDPConn(&net.UDPAddr{ IP: net.ParseIP("10.1.2.2"), Port: 5678, }, nil, obs) assert.NoError(t, err, "should succeed") err = connMap.insert(connIn2) assert.NoError(t, err, "should succeed") connOut1, ok := connMap.find(&net.UDPAddr{ IP: net.ParseIP("10.1.2.1"), Port: 5678, }) assert.True(t, ok, "should succeed") assert.Equal(t, connIn1, connOut1, "should match") connOut2, ok := connMap.find(&net.UDPAddr{ IP: net.ParseIP("10.1.2.2"), Port: 5678, }) assert.True(t, ok, "should succeed") assert.Equal(t, connIn2, connOut2, "should match") assert.Equal(t, 1, len(connMap.portMap), "should match") }) t.Run("already in-use when inserting 0.0.0.0", func(t *testing.T) { connMap := newUDPConnMap() obs := &myConnObserver{} connIn1, err := newUDPConn(&net.UDPAddr{ IP: net.ParseIP("10.1.2.1"), Port: 5678, }, nil, obs) assert.NoError(t, err, "should succeed") err = connMap.insert(connIn1) assert.NoError(t, err, "should 
succeed") connIn2, err := newUDPConn(&net.UDPAddr{ IP: net.ParseIP("0.0.0.0"), Port: 5678, }, nil, obs) assert.NoError(t, err, "should succeed") err = connMap.insert(connIn2) assert.Error(t, err, "should fail") }) t.Run("already in-use when inserting a specified IP", func(t *testing.T) { connMap := newUDPConnMap() obs := &myConnObserver{} connIn1, err := newUDPConn(&net.UDPAddr{ IP: net.ParseIP("0.0.0.0"), Port: 5678, }, nil, obs) assert.NoError(t, err, "should succeed") err = connMap.insert(connIn1) assert.NoError(t, err, "should succeed") connIn2, err := newUDPConn(&net.UDPAddr{ IP: net.ParseIP("192.168.0.1"), Port: 5678, }, nil, obs) assert.NoError(t, err, "should succeed") err = connMap.insert(connIn2) assert.Error(t, err, "should fail") }) t.Run("already in-use when inserting the same specified IP", func(t *testing.T) { connMap := newUDPConnMap() obs := &myConnObserver{} connIn1, err := newUDPConn(&net.UDPAddr{ IP: net.ParseIP("192.168.0.1"), Port: 5678, }, nil, obs) assert.NoError(t, err, "should succeed") err = connMap.insert(connIn1) assert.NoError(t, err, "should succeed") connIn2, err := newUDPConn(&net.UDPAddr{ IP: net.ParseIP("192.168.0.1"), Port: 5678, }, nil, obs) assert.NoError(t, err, "should succeed") err = connMap.insert(connIn2) assert.Error(t, err, "should fail") }) t.Run("find failure 1", func(t *testing.T) { connMap := newUDPConnMap() obs := &myConnObserver{} connIn, err := newUDPConn(&net.UDPAddr{ IP: net.ParseIP("192.168.0.1"), Port: 5678, }, nil, obs) assert.NoError(t, err, "should succeed") err = connMap.insert(connIn) assert.NoError(t, err, "should succeed") _, ok := connMap.find(&net.UDPAddr{ IP: net.ParseIP("192.168.0.2"), Port: 5678, }) assert.False(t, ok, "should fail") }) t.Run("find failure 2", func(t *testing.T) { connMap := newUDPConnMap() obs := &myConnObserver{} connIn, err := newUDPConn(&net.UDPAddr{ IP: net.ParseIP("192.168.0.1"), Port: 5678, }, nil, obs) assert.NoError(t, err, "should succeed") err = connMap.insert(connIn) 
assert.NoError(t, err, "should succeed") _, ok := connMap.find(&net.UDPAddr{ IP: net.ParseIP("192.168.0.1"), Port: 1234, }) assert.False(t, ok, "should fail") }) t.Run("insert two UDPConns on the same port, then remove them", func(t *testing.T) { connMap := newUDPConnMap() obs := &myConnObserver{} connIn1, err := newUDPConn(&net.UDPAddr{ IP: net.ParseIP("192.168.0.1"), Port: 5678, }, nil, obs) assert.NoError(t, err, "should succeed") err = connMap.insert(connIn1) assert.NoError(t, err, "should succeed") connIn2, err := newUDPConn(&net.UDPAddr{ IP: net.ParseIP("192.168.0.2"), Port: 5678, }, nil, obs) assert.NoError(t, err, "should succeed") err = connMap.insert(connIn2) assert.NoError(t, err, "should succeed") err = connMap.delete(connIn1.LocalAddr()) assert.NoError(t, err, "should succeed") err = connMap.delete(connIn2.LocalAddr()) assert.NoError(t, err, "should succeed") }) } golang-github-pion-transport-v3-3.0.8/vnet/conn_test.go000066400000000000000000000143161507057301700231050ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package vnet import ( "errors" "net" "sync/atomic" "testing" "time" "github.com/pion/logging" "github.com/stretchr/testify/assert" ) var errFailedToCovertToChuckUDP = errors.New("failed to convert chunk to chunkUDP") type dummyObserver struct { onWrite func(Chunk) error onOnClosed func(net.Addr) } func (o *dummyObserver) write(c Chunk) error { return o.onWrite(c) } func (o *dummyObserver) onClosed(addr net.Addr) { o.onOnClosed(addr) } func (o *dummyObserver) determineSourceIP(locIP, _ net.IP) net.IP { return locIP } func TestUDPConn(t *testing.T) { //nolint:cyclop,maintidx log := logging.NewDefaultLoggerFactory().NewLogger("test") t.Run("WriteTo ReadFrom", func(t *testing.T) { var nClosed int32 var conn *UDPConn var err error data := []byte("Hello") srcAddr := &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: 1234, } dstAddr := &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: 
5678, } obs := &dummyObserver{ onWrite: func(c Chunk) error { uc, ok := c.(*chunkUDP) if !ok { return errFailedToCovertToChuckUDP } chunk := newChunkUDP(uc.DestinationAddr().(*net.UDPAddr), uc.SourceAddr().(*net.UDPAddr)) //nolint:forcetypeassert chunk.userData = make([]byte, len(uc.userData)) copy(chunk.userData, uc.userData) conn.readCh <- chunk // echo back return nil }, onOnClosed: func(net.Addr) { atomic.AddInt32(&nClosed, 1) }, } conn, err = newUDPConn(srcAddr, nil, obs) assert.NoError(t, err, "should succeed") rcvdCh := make(chan struct{}) doneCh := make(chan struct{}) go func() { buf := make([]byte, 1500) for { n, addr, err2 := conn.ReadFrom(buf) if err2 != nil { log.Debug("conn closed. exiting the read loop") break } log.Debug("read data") assert.Equal(t, len(data), n, "should match") assert.Equal(t, string(data), string(data), "should match") assert.Equal(t, dstAddr.String(), addr.String(), "should match") rcvdCh <- struct{}{} } close(doneCh) }() var n int n, err = conn.WriteTo(data, dstAddr) if !assert.Nil(t, err, "should succeed") { return } assert.Equal(t, len(data), n, "should match") loop: for { select { case <-rcvdCh: log.Debug("closing conn..") err2 := conn.Close() assert.Nil(t, err2, "should succeed") case <-doneCh: break loop } } assert.Equal(t, int32(1), atomic.LoadInt32(&nClosed), "should be closed once") }) t.Run("Write Read", func(t *testing.T) { var nClosed int32 var conn *UDPConn var err error data := []byte("Hello") srcAddr := &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: 1234, } dstAddr := &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: 5678, } obs := &dummyObserver{ onWrite: func(c Chunk) error { uc, ok := c.(*chunkUDP) if !ok { return errFailedToCovertToChuckUDP } //nolint:forcetypeassert chunk := newChunkUDP( uc.DestinationAddr().(*net.UDPAddr), uc.SourceAddr().(*net.UDPAddr), ) chunk.userData = make([]byte, len(uc.userData)) copy(chunk.userData, uc.userData) conn.readCh <- chunk // echo back return nil }, onOnClosed: 
func(net.Addr) { atomic.AddInt32(&nClosed, 1) }, } conn, err = newUDPConn(srcAddr, nil, obs) assert.NoError(t, err, "should succeed") conn.remAddr = dstAddr rcvdCh := make(chan struct{}) doneCh := make(chan struct{}) go func() { buf := make([]byte, 1500) for { n, err2 := conn.Read(buf) if err2 != nil { log.Debug("conn closed. exiting the read loop") break } log.Debug("read data") assert.Equal(t, len(data), n, "should match") assert.Equal(t, string(data), string(data), "should match") rcvdCh <- struct{}{} } close(doneCh) }() var n int n, err = conn.Write(data) if !assert.Nil(t, err, "should succeed") { return } assert.Equal(t, len(data), n, "should match") loop: for { select { case <-rcvdCh: log.Debug("closing conn..") err = conn.Close() assert.Nil(t, err, "should succeed") case <-doneCh: break loop } } assert.Equal(t, int32(1), atomic.LoadInt32(&nClosed), "should be closed once") }) deadlineTest := func(t *testing.T, readOnly bool) { t.Helper() var nClosed int32 var conn *UDPConn var err error srcAddr := &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: 1234, } obs := &dummyObserver{ onOnClosed: func(net.Addr) { atomic.AddInt32(&nClosed, 1) }, } conn, err = newUDPConn(srcAddr, nil, obs) assert.NoError(t, err, "should succeed") doneCh := make(chan struct{}) if readOnly { err = conn.SetReadDeadline(time.Now().Add(50 * time.Millisecond)) } else { err = conn.SetDeadline(time.Now().Add(50 * time.Millisecond)) } assert.Nil(t, err, "should succeed") go func() { buf := make([]byte, 1500) _, _, err := conn.ReadFrom(buf) assert.NotNil(t, err, "should return error") var ne *net.OpError if errors.As(err, &ne) { assert.True(t, ne.Timeout(), "should be a timeout") } else { assert.True(t, false, "should be an net.OpError") } assert.Nil(t, conn.Close(), "should succeed") close(doneCh) }() <-doneCh assert.Equal(t, int32(1), atomic.LoadInt32(&nClosed), "should be closed once") } t.Run("SetReadDeadline", func(t *testing.T) { deadlineTest(t, true) }) t.Run("SetDeadline", func(t 
*testing.T) { deadlineTest(t, false) }) t.Run("Inbound during close", func(t *testing.T) { var conn *UDPConn var err error srcAddr := &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: 1234, } obs := &dummyObserver{ onOnClosed: func(net.Addr) {}, } for i := 0; i < 1000; i++ { // nolint:staticcheck // (false positive detection) conn, err = newUDPConn(srcAddr, nil, obs) assert.NoError(t, err, "should succeed") chDone := make(chan struct{}) go func() { time.Sleep(20 * time.Millisecond) assert.NoError(t, conn.Close()) close(chDone) }() tick := time.NewTicker(10 * time.Millisecond) for { defer tick.Stop() select { case <-chDone: return case <-tick.C: conn.onInboundChunk(nil) } } } }) } golang-github-pion-transport-v3-3.0.8/vnet/delay_filter.go000066400000000000000000000067141507057301700235570ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package vnet import ( "context" "sync" "sync/atomic" "time" ) // DelayFilter delays inbound packets by the given delay. Automatically starts // processing when created and runs until Close() is called. type DelayFilter struct { NIC delay atomic.Int64 // atomic field - stores time.Duration as int64 push chan struct{} queue *chunkQueue done chan struct{} wg sync.WaitGroup } type timedChunk struct { Chunk deadline time.Time } // NewDelayFilter creates and starts a new DelayFilter with the given nic and delay. func NewDelayFilter(nic NIC, delay time.Duration) (*DelayFilter, error) { delayFilter := &DelayFilter{ NIC: nic, push: make(chan struct{}), queue: newChunkQueue(0, 0), done: make(chan struct{}), } delayFilter.delay.Store(int64(delay)) // Start processing automatically delayFilter.wg.Add(1) go delayFilter.run() return delayFilter, nil } // SetDelay atomically updates the delay. 
func (f *DelayFilter) SetDelay(newDelay time.Duration) { f.delay.Store(int64(newDelay)) } func (f *DelayFilter) getDelay() time.Duration { return time.Duration(f.delay.Load()) } func (f *DelayFilter) onInboundChunk(c Chunk) { f.queue.push(timedChunk{ Chunk: c, deadline: time.Now().Add(f.getDelay()), }) f.push <- struct{}{} } // run processes the delayed packets queue until Close() is called. func (f *DelayFilter) run() { defer f.wg.Done() timer := time.NewTimer(0) defer timer.Stop() for { select { case <-f.done: f.drainRemainingPackets() return case <-f.push: f.updateTimerForNextPacket(timer) case now := <-timer.C: f.processReadyPackets(now) f.scheduleNextPacketTimer(timer) } } } // drainRemainingPackets sends all remaining packets immediately during shutdown. func (f *DelayFilter) drainRemainingPackets() { for { next, ok := f.queue.pop() if !ok { break } if chunk, ok := next.(timedChunk); ok { f.NIC.onInboundChunk(chunk.Chunk) } } } // updateTimerForNextPacket updates the timer when a new packet arrives. func (f *DelayFilter) updateTimerForNextPacket(timer *time.Timer) { next := f.queue.peek() if next != nil { if chunk, ok := next.(timedChunk); ok { if !timer.Stop() { <-timer.C } timer.Reset(time.Until(chunk.deadline)) } } } // processReadyPackets processes all packets that are ready to be sent. func (f *DelayFilter) processReadyPackets(now time.Time) { for { next := f.queue.peek() if next == nil { break } if chunk, ok := next.(timedChunk); ok && chunk.deadline.Before(now) { _, _ = f.queue.pop() // We already have the item from peek() f.NIC.onInboundChunk(chunk.Chunk) } else { break } } } // scheduleNextPacketTimer schedules the timer for the next packet to be processed. 
func (f *DelayFilter) scheduleNextPacketTimer(timer *time.Timer) { next := f.queue.peek() if next == nil { timer.Reset(time.Minute) // Long timeout when queue is empty } else if chunk, ok := next.(timedChunk); ok { timer.Reset(time.Until(chunk.deadline)) } } // Run is provided for backward compatibility. The DelayFilter now starts // automatically when created, so this method is a no-op. func (f *DelayFilter) Run(_ context.Context) { // DelayFilter now starts automatically in NewDelayFilter, so this is a no-op } // Close stops the DelayFilter and waits for graceful shutdown. func (f *DelayFilter) Close() error { close(f.done) f.wg.Wait() return nil } golang-github-pion-transport-v3-3.0.8/vnet/delay_filter_test.go000066400000000000000000000110441507057301700246060ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package vnet import ( "testing" "time" "github.com/stretchr/testify/assert" ) type TimestampedChunk struct { ts time.Time c Chunk } func initTest(t *testing.T) (*DelayFilter, chan TimestampedChunk) { t.Helper() nic := newMockNIC(t) delayFilter, err := NewDelayFilter(nic, 0) if !assert.NoError(t, err, "should succeed") { return nil, nil } receiveCh := make(chan TimestampedChunk) nic.mockOnInboundChunk = func(c Chunk) { receivedAt := time.Now() receiveCh <- TimestampedChunk{ ts: receivedAt, c: c, } } return delayFilter, receiveCh } func scheduleOnePacketAtATime( t *testing.T, delayFilter *DelayFilter, receiveCh chan TimestampedChunk, delay time.Duration, nrPackets int, ) { t.Helper() delayFilter.SetDelay(delay) lastNr := -1 for i := 0; i < nrPackets; i++ { sent := time.Now() delayFilter.onInboundChunk(&chunkUDP{ chunkIP: chunkIP{timestamp: sent}, userData: []byte{byte(i)}, }) select { case chunk := <-receiveCh: nr := int(chunk.c.UserData()[0]) assert.Greater(t, nr, lastNr) lastNr = nr assert.Greater(t, chunk.ts.Sub(sent), delay) // Use generous timing tolerance for CI environments with high system 
load // and virtualization overhead. Function call overhead from DelayFilter // refactoring also contributes to timing variability. assert.Less(t, chunk.ts.Sub(sent), delay+200*time.Millisecond) case <-time.After(time.Second): assert.Fail(t, "expected to receive next chunk") } } } func scheduleManyPackets( t *testing.T, delayFilter *DelayFilter, receiveCh chan TimestampedChunk, delay time.Duration, nrPackets int, //nolint:unparam ) { t.Helper() delayFilter.SetDelay(delay) sent := time.Now() for i := 0; i < nrPackets; i++ { delayFilter.onInboundChunk(&chunkUDP{ chunkIP: chunkIP{timestamp: sent}, userData: []byte{byte(i)}, }) } // receive nrPackets chunks with a minimum delay for i := 0; i < nrPackets; i++ { select { case chunk := <-receiveCh: nr := int(chunk.c.UserData()[0]) assert.Equal(t, i, nr) assert.Greater(t, chunk.ts.Sub(sent), delay) assert.Less(t, chunk.ts.Sub(sent), delay+200*time.Millisecond) case <-time.After(time.Second): assert.Fail(t, "expected to receive next chunk") } } } func TestDelayFilter(t *testing.T) { t.Run("schedulesOnePacketAtATime", func(t *testing.T) { delayFilter, receiveCh := initTest(t) if delayFilter == nil { return } scheduleOnePacketAtATime(t, delayFilter, receiveCh, 10*time.Millisecond, 100) assert.NoError(t, delayFilter.Close()) }) t.Run("schedulesSubsequentManyPackets", func(t *testing.T) { delayFilter, receiveCh := initTest(t) if delayFilter == nil { return } scheduleManyPackets(t, delayFilter, receiveCh, 10*time.Millisecond, 100) assert.NoError(t, delayFilter.Close()) }) t.Run("scheduleIncreasingDelayOnePacketAtATime", func(t *testing.T) { delayFilter, receiveCh := initTest(t) if delayFilter == nil { return } scheduleOnePacketAtATime(t, delayFilter, receiveCh, 10*time.Millisecond, 10) scheduleOnePacketAtATime(t, delayFilter, receiveCh, 50*time.Millisecond, 10) scheduleOnePacketAtATime(t, delayFilter, receiveCh, 100*time.Millisecond, 10) assert.NoError(t, delayFilter.Close()) }) t.Run("scheduleDecreasingDelayOnePacketAtATime", 
func(t *testing.T) { delayFilter, receiveCh := initTest(t) if delayFilter == nil { return } scheduleOnePacketAtATime(t, delayFilter, receiveCh, 100*time.Millisecond, 10) scheduleOnePacketAtATime(t, delayFilter, receiveCh, 50*time.Millisecond, 10) scheduleOnePacketAtATime(t, delayFilter, receiveCh, 10*time.Millisecond, 10) assert.NoError(t, delayFilter.Close()) }) t.Run("scheduleIncreasingDelayManyPackets", func(t *testing.T) { delayFilter, receiveCh := initTest(t) if delayFilter == nil { return } scheduleManyPackets(t, delayFilter, receiveCh, 10*time.Millisecond, 100) scheduleManyPackets(t, delayFilter, receiveCh, 50*time.Millisecond, 100) scheduleManyPackets(t, delayFilter, receiveCh, 100*time.Millisecond, 100) assert.NoError(t, delayFilter.Close()) }) t.Run("scheduleDecreasingDelayManyPackets", func(t *testing.T) { delayFilter, receiveCh := initTest(t) if delayFilter == nil { return } scheduleManyPackets(t, delayFilter, receiveCh, 100*time.Millisecond, 100) scheduleManyPackets(t, delayFilter, receiveCh, 50*time.Millisecond, 100) scheduleManyPackets(t, delayFilter, receiveCh, 10*time.Millisecond, 100) assert.NoError(t, delayFilter.Close()) }) } golang-github-pion-transport-v3-3.0.8/vnet/errors.go000066400000000000000000000005401507057301700224170ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package vnet type timeoutError struct { msg string } func newTimeoutError(msg string) error { return &timeoutError{ msg: msg, } } func (e *timeoutError) Error() string { return e.msg } func (e *timeoutError) Timeout() bool { return true } golang-github-pion-transport-v3-3.0.8/vnet/loss_filter.go000066400000000000000000000145021507057301700234330ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package vnet import ( "errors" "math/rand" "sync" "time" ) // Static errors for better error handling. 
var ( ErrInvalidChance = errors.New("chance must be between 0 and 100 inclusive") ErrInvalidShuffleBlockSize = errors.New("shuffleBlockSize must be greater than 0") ) type LossFilterHandler interface { shouldDrop() bool setLossRate(chance int, resetImmediately bool) } // LossFilter is a wrapper around NICs, that drops some of the packets passed to // onInboundChunk. type LossFilter struct { NIC LossFilterHandler } // RandomLossHandler drops packets randomly with a probability determined by the chance parameter. type RandomLossHandler struct { chance int mutex sync.RWMutex } // NewRandomLossHandler creates a new RandomLossHandler with the given drop chance. func NewRandomLossHandler(chance int) (*RandomLossHandler, error) { if !validateChance(chance) { return nil, ErrInvalidChance } return &RandomLossHandler{ chance: chance, }, nil } func (r *RandomLossHandler) shouldDrop() bool { r.mutex.RLock() chance := r.chance r.mutex.RUnlock() return rand.Intn(100) < chance //nolint:gosec } func (r *RandomLossHandler) setLossRate(chance int, _ bool) { r.mutex.Lock() defer r.mutex.Unlock() r.chance = chance } // RandomShuffleLossHandler drops packets with a deterministic probability for every 100 packets // That is, for every 100 packets, it guarantees that the number of packets dropped is equal to the chance parameter. type RandomShuffleLossHandler struct { blockIdx int shuffledBlock []bool currentChance int pendingChance int mutex sync.Mutex } // NewRandomShuffleLossHandler creates a new RandomShuffleLossHandler with the given drop chance and shuffle block size. // The default shuffle block size should be 100. 
func NewRandomShuffleLossHandler(chance int, shuffleBlockSize int) (*RandomShuffleLossHandler, error) { if !validateChance(chance) { return nil, ErrInvalidChance } if shuffleBlockSize < 1 { return nil, ErrInvalidShuffleBlockSize } filter := RandomShuffleLossHandler{ shuffledBlock: make([]bool, shuffleBlockSize), blockIdx: 0, currentChance: chance, pendingChance: chance, } for i := 0; i < filter.currentChance; i++ { filter.shuffledBlock[i] = true } filter.shuffleBlock() return &filter, nil } func (r *RandomShuffleLossHandler) setLossRate(chance int, resetImmediately bool) { r.mutex.Lock() defer r.mutex.Unlock() r.pendingChance = chance if resetImmediately { r.shuffleBlock() } } func (r *RandomShuffleLossHandler) shuffleBlock() { for idx := 0; idx < len(r.shuffledBlock); idx++ { switch { case r.pendingChance == r.currentChance: goto shuffleComplete case r.pendingChance > r.currentChance && !r.shuffledBlock[idx]: r.shuffledBlock[idx] = true r.currentChance++ case r.pendingChance < r.currentChance && r.shuffledBlock[idx]: r.shuffledBlock[idx] = false r.currentChance-- } } shuffleComplete: rand.Shuffle(len(r.shuffledBlock), func(i, j int) { r.shuffledBlock[i], r.shuffledBlock[j] = r.shuffledBlock[j], r.shuffledBlock[i] }) r.blockIdx = 0 } func (r *RandomShuffleLossHandler) shouldDrop() bool { r.mutex.Lock() defer r.mutex.Unlock() if r.blockIdx == len(r.shuffledBlock) { r.shuffleBlock() } res := r.shuffledBlock[r.blockIdx] r.blockIdx++ return res } // LossFilterOption represents a configuration option for LossFilter creation. type LossFilterOption func(nic NIC, chance int) (LossFilterHandler, error) // WithLossHandler sets a custom loss handler for the LossFilter. 
func WithLossHandler(handler LossFilterHandler) LossFilterOption { return func(_ NIC, chance int) (LossFilterHandler, error) { // Set the chance on the provided handler handler.setLossRate(chance, false) return handler, nil } } // WithShuffleLossHandler creates a LossFilter with a RandomShuffleLossHandler // with the specified block size for deterministic packet loss distribution. func WithShuffleLossHandler(blockSize int) LossFilterOption { return func(_ NIC, chance int) (LossFilterHandler, error) { return NewRandomShuffleLossHandler(chance, blockSize) } } // NewLossFilter creates a new LossFilter that drops every packet with a // probability of chance/100 using the default RandomLossHandler. // This maintains backward compatibility with the original API. func NewLossFilter(nic NIC, chance int) (*LossFilter, error) { return NewLossFilterWithOptions(nic, chance) } // NewLossFilterWithOptions creates a new LossFilter that drops every packet with a // probability of chance/100. You can provide custom options to override the // default behavior. This follows the Pion options pattern for extensibility. func NewLossFilterWithOptions(nic NIC, chance int, options ...LossFilterOption) (*LossFilter, error) { if !validateChance(chance) { return nil, ErrInvalidChance } var lossHandler LossFilterHandler var err error // If options are provided, use the first one to create the handler if len(options) > 0 { lossHandler, err = options[0](nic, chance) if err != nil { return nil, err } } else { // Create default handler lossHandler, err = NewRandomLossHandler(chance) if err != nil { return nil, err } } lossFilter := &LossFilter{ NIC: nic, LossFilterHandler: lossHandler, } //nolint:staticcheck rand.Seed(time.Now().UTC().UnixNano()) return lossFilter, nil } func (f *LossFilter) onInboundChunk(c Chunk) { if f.LossFilterHandler.shouldDrop() { return } f.NIC.onInboundChunk(c) } // SetLossRate sets the loss rate for the loss filter. // The chance parameter is an integer out of 100. 
// The resetImmediately parameter is a boolean that indicates whether to reset the loss rate immediately. // If resetImmediately is true, the loss rate will be reset immediately. // If resetImmediately is false, the loss rate will be reset after the next shuffle for RandomShuffleLossHandler // Note that for random loss handler, the loss rate will be reset immediately // regardless of the resetImmediately parameter. func (f *LossFilter) SetLossRate(chance int, resetImmediately bool) error { if !validateChance(chance) { return ErrInvalidChance } f.LossFilterHandler.setLossRate(chance, resetImmediately) return nil } func validateChance(chance int) bool { return chance >= 0 && chance <= 100 } golang-github-pion-transport-v3-3.0.8/vnet/loss_filter_test.go000066400000000000000000000157631507057301700245040ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package vnet import ( "net" "testing" "github.com/pion/transport/v3" "github.com/stretchr/testify/assert" ) type mockNIC struct { mockGetInterface func(ifName string) (*transport.Interface, error) mockOnInboundChunk func(c Chunk) mockGetStaticIPs func() []net.IP mockSetRouter func(r *Router) error } func (n *mockNIC) getInterface(ifName string) (*transport.Interface, error) { return n.mockGetInterface(ifName) } func (n *mockNIC) onInboundChunk(c Chunk) { n.mockOnInboundChunk(c) } func (n *mockNIC) getStaticIPs() []net.IP { return n.mockGetStaticIPs() } func (n *mockNIC) setRouter(r *Router) error { return n.mockSetRouter(r) } func newMockNIC(t *testing.T) *mockNIC { t.Helper() return &mockNIC{ mockGetInterface: func(string) (*transport.Interface, error) { assert.Fail(t, "unexpected call to mockGetInterface") return nil, nil }, mockOnInboundChunk: func(Chunk) { assert.Fail(t, "unexpected call to mockOnInboundChunk") }, mockGetStaticIPs: func() []net.IP { assert.Fail(t, "unexpected call to mockGetStaticIPs") return nil }, mockSetRouter: func(*Router) error { 
assert.Fail(t, "unexpected call to mockSetRouter") return nil }, } } func TestLossFilterFullLoss(t *testing.T) { mnic := newMockNIC(t) lossFilter, err := NewLossFilter(mnic, 100) if !assert.NoError(t, err, "should succeed") { return } lossFilter.onInboundChunk(&chunkUDP{}) } func TestLossFilterNoLoss(t *testing.T) { mnic := newMockNIC(t) lossFilter, err := NewLossFilter(mnic, 0) if !assert.NoError(t, err, "should succeed") { return } packets := 100 received := 0 mnic.mockOnInboundChunk = func(Chunk) { received++ } for i := 0; i < packets; i++ { lossFilter.onInboundChunk(&chunkUDP{}) } assert.Equal(t, packets, received) } func TestLossFilterSomeLoss(t *testing.T) { mnic := newMockNIC(t) lossFilter, err := NewLossFilter(mnic, 50) if !assert.NoError(t, err, "should succeed") { return } packets := 1000 received := 0 mnic.mockOnInboundChunk = func(Chunk) { received++ } for i := 0; i < packets; i++ { lossFilter.onInboundChunk(&chunkUDP{}) } // One of the following could technically fail, but very unlikely assert.Less(t, 0, received) assert.Greater(t, packets, received) } func TestLossFilterLossRateChangeRandomShuffleHandler(t *testing.T) { mnic := newMockNIC(t) lossHandler, err := NewRandomShuffleLossHandler(10, 100) if !assert.NoError(t, err, "should succeed") { return } lossFilter, err := NewLossFilterWithOptions(mnic, 0, WithLossHandler(lossHandler)) if !assert.NoError(t, err, "should succeed") { return } packets := 100 received := 0 mnic.mockOnInboundChunk = func(Chunk) { received++ } for i := 0; i < packets; i++ { lossFilter.onInboundChunk(&chunkUDP{}) } assert.Equal(t, 90, received) err = lossFilter.SetLossRate(50, true) if !assert.NoError(t, err, "should succeed") { return } received = 0 for i := 0; i < packets; i++ { lossFilter.onInboundChunk(&chunkUDP{}) } assert.Equal(t, 50, received) err = lossFilter.SetLossRate(99, true) if !assert.NoError(t, err, "should succeed") { return } received = 0 for i := 0; i < packets; i++ { lossFilter.onInboundChunk(&chunkUDP{}) } 
assert.Equal(t, 1, received) } func TestLossFilterImmediateLossRateChangeRandomShuffleHandler(t *testing.T) { mnic := newMockNIC(t) lossHandler, err := NewRandomShuffleLossHandler(10, 100) if !assert.NoError(t, err, "should succeed") { return } lossFilter, err := NewLossFilterWithOptions(mnic, 0, WithLossHandler(lossHandler)) if !assert.NoError(t, err, "should succeed") { return } packets := 100 received := 0 mnic.mockOnInboundChunk = func(Chunk) { received++ } // send 50 dummy packets to partially fill shuffle block for i := 0; i < 50; i++ { lossFilter.onInboundChunk(&chunkUDP{}) } // should trigger an immediate shuffle that sets the loss rate to 50% err = lossFilter.SetLossRate(50, true) if !assert.NoError(t, err, "should succeed") { return } received = 0 for i := 0; i < packets; i++ { lossFilter.onInboundChunk(&chunkUDP{}) } assert.Equal(t, 50, received) } func TestLossFilterNonImmediateLossRateChangeRandomShuffleHandler(t *testing.T) { mnic := newMockNIC(t) lossHandler, err := NewRandomShuffleLossHandler(10, 100) if !assert.NoError(t, err, "should succeed") { return } lossFilter, err := NewLossFilterWithOptions(mnic, 0, WithLossHandler(lossHandler)) if !assert.NoError(t, err, "should succeed") { return } received := 0 mnic.mockOnInboundChunk = func(Chunk) { received++ } // send 50 dummy packets to partially fill shuffle block for i := 0; i < 50; i++ { lossFilter.onInboundChunk(&chunkUDP{}) } _ = lossFilter.SetLossRate(100, false) // the loss rate should not be changed until the shuffle block is full for i := 0; i < 50; i++ { lossFilter.onInboundChunk(&chunkUDP{}) } assert.Equal(t, 90, received) received = 0 // the new loss rate should be applied to this block for i := 0; i < 100; i++ { lossFilter.onInboundChunk(&chunkUDP{}) } assert.Equal(t, 0, received) } func TestLossFilterOptionsPattern(t *testing.T) { t.Run("WithLossHandler option", func(t *testing.T) { mnic := newMockNIC(t) customHandler, err := NewRandomShuffleLossHandler(10, 100) if !assert.NoError(t, 
// Sentinel errors returned by the NAT implementation. Callers are expected to
// match them with errors.Is rather than by message text.
var (
	// errNATRequriesMapping is returned by newNAT when 1:1 mode is requested
	// without any mapped IP address.
	errNATRequriesMapping = errors.New("1:1 NAT requires at least one mapping")
	// errMismatchLengthIP is returned by newNAT when 1:1 mode is requested and
	// mappedIPs and localIPs do not pair up one-to-one.
	errMismatchLengthIP = errors.New("length mismatch between mappedIPs and localIPs")
	// errNonUDPTranslationNotSupported is returned when a non-UDP chunk is
	// handed to the translator.
	errNonUDPTranslationNotSupported = errors.New("non-udp translation is not supported yet")
	// errNoAssociatedLocalAddress is returned when an inbound chunk targets a
	// mapped IP that has no paired local IP (1:1 mode).
	errNoAssociatedLocalAddress = errors.New("no associated local address")
	// errNoNATBindingFound is returned when an inbound chunk targets an
	// address for which no (unexpired) binding exists.
	errNoNATBindingFound = errors.New("no NAT binding found")
	// errHasNoPermission is returned when the filtering behavior rejects the
	// remote endpoint of an inbound chunk.
	errHasNoPermission = errors.New("has no permission")
)
// mapping is a single NAPT binding: it ties a local transport address to the
// external (mapped) transport address allocated for it, together with the set
// of remote endpoints that are allowed to send traffic back through it.
type mapping struct {
	proto   string              // "udp" or "tcp"
	local   string              // "<local-ip>:<local-port>"
	mapped  string              // "<mapped-ip>:<mapped-port>"
	bound   string              // key: "[<remote-ip>[:<remote-port>]]"; empty for endpoint-independent mapping
	filters map[string]struct{} // key: "[<remote-ip>[:<remote-port>]]"; empty key for endpoint-independent filtering
	expires time.Time           // time to expire
}
n.mappedIPs[i] } } return nil } func (n *networkAddressTranslator) getPairedLocalIP(mappedIP net.IP) net.IP { for i, ip := range n.mappedIPs { if ip.Equal(mappedIP) { return n.localIPs[i] } } return nil } func (n *networkAddressTranslator) translateOutbound(from Chunk) (Chunk, error) { //nolint:cyclop n.mutex.Lock() defer n.mutex.Unlock() to := from.Clone() if from.Network() == udp { //nolint:nestif if n.natType.Mode == NATModeNAT1To1 { // 1:1 NAT behavior srcAddr := from.SourceAddr().(*net.UDPAddr) //nolint:forcetypeassert srcIP := n.getPairedMappedIP(srcAddr.IP) if srcIP == nil { n.log.Debugf("[%s] drop outbound chunk %s with not route", n.name, from.String()) return nil, nil // nolint:nilnil } srcPort := srcAddr.Port if err := to.setSourceAddr(fmt.Sprintf("%s:%d", srcIP.String(), srcPort)); err != nil { return nil, err } } else { // Normal (NAPT) behavior var bound, filterKey string switch n.natType.MappingBehavior { case EndpointIndependent: bound = "" case EndpointAddrDependent: bound = from.getDestinationIP().String() case EndpointAddrPortDependent: bound = from.DestinationAddr().String() } switch n.natType.FilteringBehavior { case EndpointIndependent: filterKey = "" case EndpointAddrDependent: filterKey = from.getDestinationIP().String() case EndpointAddrPortDependent: filterKey = from.DestinationAddr().String() } oKey := fmt.Sprintf("udp:%s:%s", from.SourceAddr().String(), bound) mapp := n.findOutboundMapping(oKey) if mapp == nil { // Create a new mapping mappedPort := 0xC000 + n.udpPortCounter n.udpPortCounter++ mapp = &mapping{ proto: from.SourceAddr().Network(), local: from.SourceAddr().String(), bound: bound, mapped: fmt.Sprintf("%s:%d", n.mappedIPs[0].String(), mappedPort), filters: map[string]struct{}{}, expires: time.Now().Add(n.natType.MappingLifeTime), } n.outboundMap[oKey] = mapp iKey := fmt.Sprintf("udp:%s", mapp.mapped) n.log.Debugf("[%s] created a new NAT binding oKey=%s iKey=%s", n.name, oKey, iKey) mapp.filters[filterKey] = struct{}{} 
n.log.Debugf("[%s] permit access from %s to %s", n.name, filterKey, mapp.mapped) n.inboundMap[iKey] = mapp } else if _, ok := mapp.filters[filterKey]; !ok { n.log.Debugf("[%s] permit access from %s to %s", n.name, filterKey, mapp.mapped) mapp.filters[filterKey] = struct{}{} } if err := to.setSourceAddr(mapp.mapped); err != nil { return nil, err } } n.log.Debugf("[%s] translate outbound chunk from %s to %s", n.name, from.String(), to.String()) return to, nil } return nil, errNonUDPTranslationNotSupported } func (n *networkAddressTranslator) translateInbound(from Chunk) (Chunk, error) { //nolint:cyclop n.mutex.Lock() defer n.mutex.Unlock() to := from.Clone() if from.Network() == udp { //nolint:nestif if n.natType.Mode == NATModeNAT1To1 { // 1:1 NAT behavior dstAddr := from.DestinationAddr().(*net.UDPAddr) //nolint:forcetypeassert dstIP := n.getPairedLocalIP(dstAddr.IP) if dstIP == nil { return nil, fmt.Errorf("drop %s as %w", from.String(), errNoAssociatedLocalAddress) } dstPort := from.DestinationAddr().(*net.UDPAddr).Port //nolint:forcetypeassert if err := to.setDestinationAddr(fmt.Sprintf("%s:%d", dstIP, dstPort)); err != nil { return nil, err } } else { // Normal (NAPT) behavior iKey := fmt.Sprintf("udp:%s", from.DestinationAddr().String()) mapping := n.findInboundMapping(iKey) if mapping == nil { return nil, fmt.Errorf("drop %s as %w", from.String(), errNoNATBindingFound) } var filterKey string switch n.natType.FilteringBehavior { case EndpointIndependent: filterKey = "" case EndpointAddrDependent: filterKey = from.getSourceIP().String() case EndpointAddrPortDependent: filterKey = from.SourceAddr().String() } if _, ok := mapping.filters[filterKey]; !ok { return nil, fmt.Errorf("drop %s as the remote %s %w", from.String(), filterKey, errHasNoPermission) } // See RFC 4847 Section 4.3. Mapping Refresh // a) Inbound refresh may be useful for applications with no outgoing // UDP traffic. 
However, allowing inbound refresh may allow an // external attacker or misbehaving application to keep a mapping // alive indefinitely. This may be a security risk. Also, if the // process is repeated with different ports, over time, it could // use up all the ports on the NAT. if err := to.setDestinationAddr(mapping.local); err != nil { return nil, err } } n.log.Debugf("[%s] translate inbound chunk from %s to %s", n.name, from.String(), to.String()) return to, nil } return nil, errNonUDPTranslationNotSupported } // caller must hold the mutex. func (n *networkAddressTranslator) findOutboundMapping(oKey string) *mapping { now := time.Now() m, ok := n.outboundMap[oKey] if ok { // check if this mapping is expired if now.After(m.expires) { n.removeMapping(m) m = nil // expired } else { m.expires = time.Now().Add(n.natType.MappingLifeTime) } } return m } // caller must hold the mutex. func (n *networkAddressTranslator) findInboundMapping(iKey string) *mapping { now := time.Now() m, ok := n.inboundMap[iKey] if !ok { return nil } // check if this mapping is expired if now.After(m.expires) { n.removeMapping(m) return nil } return m } // caller must hold the mutex. 
func (n *networkAddressTranslator) removeMapping(m *mapping) { oKey := fmt.Sprintf("%s:%s:%s", m.proto, m.local, m.bound) iKey := fmt.Sprintf("%s:%s", m.proto, m.mapped) delete(n.outboundMap, oKey) delete(n.inboundMap, iKey) } golang-github-pion-transport-v3-3.0.8/vnet/nat_test.go000066400000000000000000000530471507057301700227360ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package vnet import ( "net" "testing" "time" "github.com/pion/logging" "github.com/stretchr/testify/assert" ) // oic: outbound internal chunk // oec: outbound external chunk // iic: inbound internal chunk // iec: inbound external chunk const demoIP = "1.2.3.4" func TestNATTypeDefaults(t *testing.T) { loggerFactory := logging.NewDefaultLoggerFactory() nat, err := newNAT(&natConfig{ natType: NATType{}, mappedIPs: []net.IP{net.ParseIP(demoIP)}, loggerFactory: loggerFactory, }) assert.NoError(t, err, "should succeed") assert.Equal(t, EndpointIndependent, nat.natType.MappingBehavior, "should match") assert.Equal(t, EndpointIndependent, nat.natType.FilteringBehavior, "should match") assert.False(t, nat.natType.Hairpinning, "should be false") assert.False(t, nat.natType.PortPreservation, "should be false") assert.Equal(t, defaultNATMappingLifeTime, nat.natType.MappingLifeTime, "should be false") } func TestNATMappingBehavior(t *testing.T) { //nolint:maintidx loggerFactory := logging.NewDefaultLoggerFactory() log := loggerFactory.NewLogger("test") t.Run("full-cone NAT", func(t *testing.T) { nat, err := newNAT(&natConfig{ natType: NATType{ MappingBehavior: EndpointIndependent, FilteringBehavior: EndpointIndependent, Hairpinning: false, MappingLifeTime: 30 * time.Second, }, mappedIPs: []net.IP{net.ParseIP(demoIP)}, loggerFactory: loggerFactory, }) assert.NoError(t, err, "should succeed") src := &net.UDPAddr{ IP: net.ParseIP("192.168.0.2"), Port: 1234, } dst := &net.UDPAddr{ IP: net.ParseIP("5.6.7.8"), Port: 5678, } oic := newChunkUDP(src, 
dst) oec, err := nat.translateOutbound(oic) assert.Nil(t, err, "should succeed") assert.Equal(t, 1, len(nat.outboundMap), "should match") assert.Equal(t, 1, len(nat.inboundMap), "should match") log.Debugf("o-original : %s", oic.String()) log.Debugf("o-translated: %s", oec.String()) //nolint:forcetypeassert iec := newChunkUDP( &net.UDPAddr{ IP: dst.IP, Port: dst.Port, }, &net.UDPAddr{ IP: oec.(*chunkUDP).sourceIP, Port: oec.(*chunkUDP).sourcePort, }, ) log.Debugf("i-original : %s", iec.String()) iic, err := nat.translateInbound(iec) assert.Nil(t, err, "should succeed") log.Debugf("i-translated: %s", iic.String()) //nolint:forcetypeassert assert.Equal(t, oic.SourceAddr().String(), iic.(*chunkUDP).DestinationAddr().String(), "should match") // packet with dest addr that does not exist in the mapping table // will be dropped //nolint:forcetypeassert iec = newChunkUDP( &net.UDPAddr{ IP: dst.IP, Port: dst.Port, }, &net.UDPAddr{ IP: oec.(*chunkUDP).sourceIP, Port: oec.(*chunkUDP).sourcePort + 1, }, ) _, err = nat.translateInbound(iec) log.Debug(err.Error()) assert.NotNil(t, err, "should fail (dropped)") // packet from any addr will be accepted (full-cone) //nolint:forcetypeassert iec = newChunkUDP( &net.UDPAddr{ IP: dst.IP, Port: 7777, }, &net.UDPAddr{ IP: oec.(*chunkUDP).sourceIP, Port: oec.(*chunkUDP).sourcePort, }, ) _, err = nat.translateInbound(iec) assert.Nil(t, err, "should succeed") }) t.Run("addr-restricted-cone NAT", func(t *testing.T) { nat, err := newNAT(&natConfig{ natType: NATType{ MappingBehavior: EndpointIndependent, FilteringBehavior: EndpointAddrDependent, Hairpinning: false, MappingLifeTime: 30 * time.Second, }, mappedIPs: []net.IP{net.ParseIP(demoIP)}, loggerFactory: loggerFactory, }) assert.NoError(t, err, "should succeed") src := &net.UDPAddr{ IP: net.ParseIP("192.168.0.2"), Port: 1234, } dst := &net.UDPAddr{ IP: net.ParseIP("5.6.7.8"), Port: 5678, } oic := newChunkUDP(src, dst) log.Debugf("o-original : %s", oic.String()) oec, err := 
nat.translateOutbound(oic) assert.Nil(t, err, "should succeed") assert.Equal(t, 1, len(nat.outboundMap), "should match") assert.Equal(t, 1, len(nat.inboundMap), "should match") log.Debugf("o-translated: %s", oec.String()) // sending different (IP: 5.6.7.9) won't create a new mapping oic2 := newChunkUDP(&net.UDPAddr{ IP: net.ParseIP("192.168.0.2"), Port: 1234, }, &net.UDPAddr{ IP: net.ParseIP("5.6.7.9"), Port: 9000, }) oec2, err := nat.translateOutbound(oic2) assert.Nil(t, err, "should succeed") assert.Equal(t, 1, len(nat.outboundMap), "should match") assert.Equal(t, 1, len(nat.inboundMap), "should match") log.Debugf("o-translated: %s", oec2.String()) //nolint:forcetypeassert iec := newChunkUDP( &net.UDPAddr{ IP: dst.IP, Port: dst.Port, }, &net.UDPAddr{ IP: oec.(*chunkUDP).sourceIP, Port: oec.(*chunkUDP).sourcePort, }, ) log.Debugf("i-original : %s", iec.String()) iic, err := nat.translateInbound(iec) if !assert.NoError(t, err, "should succeed") { return } log.Debugf("i-translated: %s", iic.String()) //nolint:forcetypeassert assert.Equal(t, oic.SourceAddr().String(), iic.(*chunkUDP).DestinationAddr().String(), "should match") // packet with dest addr that does not exist in the mapping table // will be dropped //nolint:forcetypeassert iec = newChunkUDP( &net.UDPAddr{ IP: dst.IP, Port: dst.Port, }, &net.UDPAddr{ IP: oec.(*chunkUDP).sourceIP, Port: oec.(*chunkUDP).sourcePort + 1, }, ) _, err = nat.translateInbound(iec) log.Debug(err.Error()) assert.NotNil(t, err, "should fail (dropped)") // packet from any port will be accepted (restricted-cone) //nolint:forcetypeassert iec = newChunkUDP( &net.UDPAddr{ IP: dst.IP, Port: 7777, }, &net.UDPAddr{ IP: oec.(*chunkUDP).sourceIP, Port: oec.(*chunkUDP).sourcePort, }, ) _, err = nat.translateInbound(iec) assert.Nil(t, err, "should succeed") // packet from different addr will be dropped (restricted-cone) //nolint:forcetypeassert iec = newChunkUDP( &net.UDPAddr{ IP: net.ParseIP("6.6.6.6"), Port: dst.Port, }, &net.UDPAddr{ IP: 
oec.(*chunkUDP).sourceIP, Port: oec.(*chunkUDP).sourcePort, }, ) _, err = nat.translateInbound(iec) log.Debug(err.Error()) assert.NotNil(t, err, "should fail (dropped)") }) t.Run("port-restricted-cone NAT", func(t *testing.T) { nat, err := newNAT(&natConfig{ natType: NATType{ MappingBehavior: EndpointIndependent, FilteringBehavior: EndpointAddrPortDependent, Hairpinning: false, MappingLifeTime: 30 * time.Second, }, mappedIPs: []net.IP{net.ParseIP(demoIP)}, loggerFactory: loggerFactory, }) assert.NoError(t, err, "should succeed") src := &net.UDPAddr{ IP: net.ParseIP("192.168.0.2"), Port: 1234, } dst := &net.UDPAddr{ IP: net.ParseIP("5.6.7.8"), Port: 5678, } oic := newChunkUDP(src, dst) log.Debugf("o-original : %s", oic.String()) oec, err := nat.translateOutbound(oic) assert.Nil(t, err, "should succeed") assert.Equal(t, 1, len(nat.outboundMap), "should match") assert.Equal(t, 1, len(nat.inboundMap), "should match") log.Debugf("o-translated: %s", oec.String()) // sending different (IP: 5.6.7.9) won't create a new mapping oic2 := newChunkUDP(&net.UDPAddr{ IP: net.ParseIP("192.168.0.2"), Port: 1234, }, &net.UDPAddr{ IP: net.ParseIP("5.6.7.9"), Port: 9000, }) oec2, err := nat.translateOutbound(oic2) assert.Nil(t, err, "should succeed") assert.Equal(t, 1, len(nat.outboundMap), "should match") assert.Equal(t, 1, len(nat.inboundMap), "should match") log.Debugf("o-translated: %s", oec2.String()) //nolint:forcetypeassert iec := newChunkUDP( &net.UDPAddr{ IP: dst.IP, Port: dst.Port, }, &net.UDPAddr{ IP: oec.(*chunkUDP).sourceIP, Port: oec.(*chunkUDP).sourcePort, }, ) log.Debugf("i-original : %s", iec.String()) iic, err := nat.translateInbound(iec) assert.Nil(t, err, "should succeed") log.Debugf("i-translated: %s", iic.String()) //nolint:forcetypeassert assert.Equal(t, oic.SourceAddr().String(), iic.(*chunkUDP).DestinationAddr().String(), "should match") // packet with dest addr that does not exist in the mapping table // will be dropped //nolint:forcetypeassert iec = 
newChunkUDP( &net.UDPAddr{ IP: dst.IP, Port: dst.Port, }, &net.UDPAddr{ IP: oec.(*chunkUDP).sourceIP, Port: oec.(*chunkUDP).sourcePort + 1, }, ) _, err = nat.translateInbound(iec) assert.NotNil(t, err, "should fail (dropped)") // packet from different port will be dropped (port-restricted-cone) //nolint:forcetypeassert iec = newChunkUDP( &net.UDPAddr{ IP: dst.IP, Port: 7777, }, &net.UDPAddr{ IP: oec.(*chunkUDP).sourceIP, Port: oec.(*chunkUDP).sourcePort, }, ) _, err = nat.translateInbound(iec) assert.NotNil(t, err, "should fail (dropped)") // packet from different addr will be dropped (restricted-cone) //nolint:forcetypeassert iec = newChunkUDP( &net.UDPAddr{ IP: net.ParseIP("6.6.6.6"), Port: dst.Port, }, &net.UDPAddr{ IP: oec.(*chunkUDP).sourceIP, Port: oec.(*chunkUDP).sourcePort, }, ) _, err = nat.translateInbound(iec) assert.NotNil(t, err, "should fail (dropped)") }) t.Run("symmetric NAT addr dependent mapping", func(t *testing.T) { //nolint:dupl nat, err := newNAT(&natConfig{ natType: NATType{ MappingBehavior: EndpointAddrDependent, FilteringBehavior: EndpointAddrDependent, Hairpinning: false, MappingLifeTime: 30 * time.Second, }, mappedIPs: []net.IP{net.ParseIP(demoIP)}, loggerFactory: loggerFactory, }) assert.NoError(t, err, "should succeed") oic1 := newChunkUDP( &net.UDPAddr{ IP: net.ParseIP("192.168.0.2"), Port: 1234, }, &net.UDPAddr{ IP: net.ParseIP("5.6.7.8"), Port: 5678, }, ) oic2 := newChunkUDP( &net.UDPAddr{ IP: net.ParseIP("192.168.0.2"), Port: 1234, }, &net.UDPAddr{ IP: net.ParseIP("5.6.7.100"), Port: 5678, }, ) oic3 := newChunkUDP( &net.UDPAddr{ IP: net.ParseIP("192.168.0.2"), Port: 1234, }, &net.UDPAddr{ IP: net.ParseIP("5.6.7.8"), Port: 6000, }, ) log.Debugf("o-original : %s", oic1.String()) log.Debugf("o-original : %s", oic2.String()) log.Debugf("o-original : %s", oic3.String()) oec1, err := nat.translateOutbound(oic1) assert.Nil(t, err, "should succeed") oec2, err := nat.translateOutbound(oic2) assert.Nil(t, err, "should succeed") oec3, err := 
nat.translateOutbound(oic3) assert.Nil(t, err, "should succeed") assert.Equal(t, 2, len(nat.outboundMap), "should match") assert.Equal(t, 2, len(nat.inboundMap), "should match") log.Debugf("o-translated: %s", oec1.String()) log.Debugf("o-translated: %s", oec2.String()) log.Debugf("o-translated: %s", oec3.String()) assert.NotEqual( t, oec1.(*chunkUDP).sourcePort, //nolint:forcetypeassert oec2.(*chunkUDP).sourcePort, //nolint:forcetypeassert "should not match", ) assert.Equal( t, oec1.(*chunkUDP).sourcePort, //nolint:forcetypeassert oec3.(*chunkUDP).sourcePort, //nolint:forcetypeassert "should match", ) }) t.Run("symmetric NAT port dependent mapping", func(t *testing.T) { //nolint:dupl nat, err := newNAT(&natConfig{ natType: NATType{ MappingBehavior: EndpointAddrPortDependent, FilteringBehavior: EndpointAddrPortDependent, Hairpinning: false, MappingLifeTime: 30 * time.Second, }, mappedIPs: []net.IP{net.ParseIP(demoIP)}, loggerFactory: loggerFactory, }) assert.NoError(t, err, "should succeed") oic1 := newChunkUDP( &net.UDPAddr{ IP: net.ParseIP("192.168.0.2"), Port: 1234, }, &net.UDPAddr{ IP: net.ParseIP("5.6.7.8"), Port: 5678, }, ) oic2 := newChunkUDP( &net.UDPAddr{ IP: net.ParseIP("192.168.0.2"), Port: 1234, }, &net.UDPAddr{ IP: net.ParseIP("5.6.7.100"), Port: 5678, }, ) oic3 := newChunkUDP( &net.UDPAddr{ IP: net.ParseIP("192.168.0.2"), Port: 1234, }, &net.UDPAddr{ IP: net.ParseIP("5.6.7.8"), Port: 6000, }, ) log.Debugf("o-original : %s", oic1.String()) log.Debugf("o-original : %s", oic2.String()) log.Debugf("o-original : %s", oic3.String()) oec1, err := nat.translateOutbound(oic1) assert.Nil(t, err, "should succeed") oec2, err := nat.translateOutbound(oic2) assert.Nil(t, err, "should succeed") oec3, err := nat.translateOutbound(oic3) assert.Nil(t, err, "should succeed") assert.Equal(t, 3, len(nat.outboundMap), "should match") assert.Equal(t, 3, len(nat.inboundMap), "should match") log.Debugf("o-translated: %s", oec1.String()) log.Debugf("o-translated: %s", 
// TestNATMappingTimeout exercises binding expiry with a 100ms mapping
// lifetime. The sleeps are chosen relative to that lifetime, so the test is
// timing-sensitive.
func TestNATMappingTimeout(t *testing.T) {
	loggerFactory := logging.NewDefaultLoggerFactory()
	log := loggerFactory.NewLogger("test")

	// Outbound traffic inside the lifetime keeps the same mapped address;
	// outbound traffic after expiry gets a fresh binding with a new address.
	t.Run("refresh on outbound", func(t *testing.T) {
		nat, err := newNAT(&natConfig{
			natType: NATType{
				MappingBehavior:   EndpointIndependent,
				FilteringBehavior: EndpointIndependent,
				Hairpinning:       false,
				MappingLifeTime:   100 * time.Millisecond,
			},
			mappedIPs:     []net.IP{net.ParseIP(demoIP)},
			loggerFactory: loggerFactory,
		})
		assert.NoError(t, err, "should succeed")

		src := &net.UDPAddr{
			IP:   net.ParseIP("192.168.0.2"),
			Port: 1234,
		}
		dst := &net.UDPAddr{
			IP:   net.ParseIP("5.6.7.8"),
			Port: 5678,
		}

		oic := newChunkUDP(src, dst)

		// First outbound chunk creates the binding.
		oec, err := nat.translateOutbound(oic)
		assert.Nil(t, err, "should succeed")
		assert.Equal(t, 1, len(nat.outboundMap), "should match")
		assert.Equal(t, 1, len(nat.inboundMap), "should match")

		log.Debugf("o-original : %s", oic.String())
		log.Debugf("o-translated: %s", oec.String())

		// record mapped addr
		mapped := oec.(*chunkUDP).SourceAddr().String() //nolint:forcetypeassert

		// Still inside the 100ms lifetime.
		time.Sleep(75 * time.Millisecond)

		// refresh
		oec, err = nat.translateOutbound(oic)
		assert.Nil(t, err, "should succeed")
		assert.Equal(t, 1, len(nat.outboundMap), "should match")
		assert.Equal(t, 1, len(nat.inboundMap), "should match")

		log.Debugf("o-original : %s", oic.String())
		log.Debugf("o-translated: %s", oec.String())

		assert.Equal(t, mapped, oec.(*chunkUDP).SourceAddr().String(), "mapped addr should match") //nolint:forcetypeassert

		// sleep long enough for the mapping to expire
		time.Sleep(125 * time.Millisecond)

		// refresh after expiration
		// The expired binding is replaced, so the map sizes stay at 1.
		oec, err = nat.translateOutbound(oic)
		assert.Nil(t, err, "should succeed")
		assert.Equal(t, 1, len(nat.outboundMap), "should match")
		assert.Equal(t, 1, len(nat.inboundMap), "should match")

		log.Debugf("o-original : %s", oic.String())
		log.Debugf("o-translated: %s", oec.String())

		assert.NotEqual(
			t,
			mapped,
			oec.(*chunkUDP).SourceAddr().String(), //nolint:forcetypeassert
			"mapped addr should not match",
		)
	})

	// An inbound chunk that arrives after the binding expired is rejected and
	// the stale binding is removed from both lookup tables.
	t.Run("outbound detects timeout", func(t *testing.T) {
		nat, err := newNAT(&natConfig{
			natType: NATType{
				MappingBehavior:   EndpointIndependent,
				FilteringBehavior: EndpointIndependent,
				Hairpinning:       false,
				MappingLifeTime:   100 * time.Millisecond,
			},
			mappedIPs:     []net.IP{net.ParseIP(demoIP)},
			loggerFactory: loggerFactory,
		})
		assert.NoError(t, err, "should succeed")

		src := &net.UDPAddr{
			IP:   net.ParseIP("192.168.0.2"),
			Port: 1234,
		}
		dst := &net.UDPAddr{
			IP:   net.ParseIP("5.6.7.8"),
			Port: 5678,
		}

		oic := newChunkUDP(src, dst)

		oec, err := nat.translateOutbound(oic)
		assert.Nil(t, err, "should succeed")
		assert.Equal(t, 1, len(nat.outboundMap), "should match")
		assert.Equal(t, 1, len(nat.inboundMap), "should match")

		log.Debugf("o-original : %s", oic.String())
		log.Debugf("o-translated: %s", oec.String())

		// sleep long enough for the mapping to expire
		time.Sleep(125 * time.Millisecond)

		// Reply from the remote endpoint to the (now expired) mapped address.
		//nolint:forcetypeassert
		iec := newChunkUDP(
			&net.UDPAddr{
				IP:   dst.IP,
				Port: dst.Port,
			},
			&net.UDPAddr{
				IP:   oec.(*chunkUDP).sourceIP,
				Port: oec.(*chunkUDP).sourcePort,
			},
		)

		log.Debugf("i-original : %s", iec.String())

		_, err = nat.translateInbound(iec)
		assert.NotNil(t, err, "should drop")
		assert.Empty(t, nat.outboundMap, "should have no binding")
		assert.Empty(t, nat.inboundMap, "should have no binding")
	})
}
loggerFactory, }) if !assert.NoError(t, err, "should succeed") { return } src := &net.UDPAddr{ IP: net.ParseIP("10.0.0.1"), Port: 1234, } dst := &net.UDPAddr{ IP: net.ParseIP("5.6.7.8"), Port: 5678, } oic := newChunkUDP(src, dst) oec, err := nat.translateOutbound(oic) assert.Nil(t, err, "should succeed") assert.Empty(t, nat.outboundMap, "should match") assert.Empty(t, nat.inboundMap, "should match") log.Debugf("o-original : %s", oic.String()) log.Debugf("o-translated: %s", oec.String()) assert.Equal(t, "1.2.3.4:1234", oec.SourceAddr().String(), "should match") //nolint:forcetypeassert iec := newChunkUDP( &net.UDPAddr{ IP: dst.IP, Port: dst.Port, }, &net.UDPAddr{ IP: oec.(*chunkUDP).sourceIP, Port: oec.(*chunkUDP).sourcePort, }, ) log.Debugf("i-original : %s", iec.String()) iic, err := nat.translateInbound(iec) assert.Nil(t, err, "should succeed") log.Debugf("i-translated: %s", iic.String()) assert.Equal(t, oic.SourceAddr().String(), iic.DestinationAddr().String(), "should match") }) t.Run("1:1 NAT with more than one mapping", func(t *testing.T) { nat, err := newNAT(&natConfig{ natType: NATType{ Mode: NATModeNAT1To1, }, mappedIPs: []net.IP{ net.ParseIP(demoIP), net.ParseIP("1.2.3.5"), }, localIPs: []net.IP{ net.ParseIP("10.0.0.1"), net.ParseIP("10.0.0.2"), }, loggerFactory: loggerFactory, }) if !assert.NoError(t, err, "should succeed") { return } // outbound translation before := newChunkUDP( &net.UDPAddr{ IP: net.ParseIP("10.0.0.1"), Port: 1234, }, &net.UDPAddr{ IP: net.ParseIP("5.6.7.8"), Port: 5678, }) after, err := nat.translateOutbound(before) if !assert.NoError(t, err, "should succeed") { return } assert.Equal(t, "1.2.3.4:1234", after.SourceAddr().String(), "should match") before = newChunkUDP( &net.UDPAddr{ IP: net.ParseIP("10.0.0.2"), Port: 1234, }, &net.UDPAddr{ IP: net.ParseIP("5.6.7.8"), Port: 5678, }) after, err = nat.translateOutbound(before) if !assert.NoError(t, err, "should succeed") { return } assert.Equal(t, "1.2.3.5:1234", 
after.SourceAddr().String(), "should match") // inbound translation before = newChunkUDP( &net.UDPAddr{ IP: net.ParseIP("5.6.7.8"), Port: 5678, }, &net.UDPAddr{ IP: net.ParseIP(demoIP), Port: 2525, }) after, err = nat.translateInbound(before) if !assert.NoError(t, err, "should succeed") { return } assert.Equal(t, "10.0.0.1:2525", after.DestinationAddr().String(), "should match") before = newChunkUDP( &net.UDPAddr{ IP: net.ParseIP("5.6.7.8"), Port: 5678, }, &net.UDPAddr{ IP: net.ParseIP("1.2.3.5"), Port: 9847, }) after, err = nat.translateInbound(before) if !assert.NoError(t, err, "should succeed") { return } assert.Equal(t, "10.0.0.2:9847", after.DestinationAddr().String(), "should match") }) t.Run("1:1 NAT failure", func(t *testing.T) { // 1:1 NAT requires more than one mapping _, err := newNAT(&natConfig{ natType: NATType{ Mode: NATModeNAT1To1, }, loggerFactory: loggerFactory, }) assert.Error(t, err, "should fail") // 1:1 NAT requires the same number of mappedIPs and localIPs _, err = newNAT(&natConfig{ natType: NATType{ Mode: NATModeNAT1To1, }, mappedIPs: []net.IP{ net.ParseIP(demoIP), net.ParseIP("1.2.3.5"), }, localIPs: []net.IP{ net.ParseIP("10.0.0.1"), }, loggerFactory: loggerFactory, }) assert.Error(t, err, "should fail") // drop outbound or inbound chunk with no route in 1:1 NAT nat, err := newNAT(&natConfig{ natType: NATType{ Mode: NATModeNAT1To1, }, mappedIPs: []net.IP{ net.ParseIP(demoIP), }, localIPs: []net.IP{ net.ParseIP("10.0.0.1"), }, loggerFactory: loggerFactory, }) assert.NoError(t, err, "should succeed") before := newChunkUDP( &net.UDPAddr{ IP: net.ParseIP("10.0.0.2"), // no external mapping for this Port: 1234, }, &net.UDPAddr{ IP: net.ParseIP("5.6.7.8"), Port: 5678, }) after, err := nat.translateOutbound(before) if !assert.NoError(t, err, "should succeed") { return } if !assert.Nil(t, after, "should be nil") { return } before = newChunkUDP( &net.UDPAddr{ IP: net.ParseIP("5.6.7.8"), Port: 5678, }, &net.UDPAddr{ IP: net.ParseIP("10.0.0.2"), // 
no local mapping for this Port: 1234, }) _, err = nat.translateInbound(before) assert.Error(t, err, "should fail") }) } golang-github-pion-transport-v3-3.0.8/vnet/net.go000066400000000000000000000335111507057301700216750ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package vnet import ( "encoding/binary" "errors" "fmt" "math/rand" "net" "strconv" "strings" "sync" "github.com/pion/transport/v3" ) const ( lo0String = "lo0String" udp = "udp" udp4 = "udp4" ) var ( macAddrCounter uint64 = 0xBEEFED910200 //nolint:gochecknoglobals errNoInterface = errors.New("no interface is available") errUnexpectedNetwork = errors.New("unexpected network") errCantAssignRequestedAddr = errors.New("can't assign requested address") errUnknownNetwork = errors.New("unknown network") errNoRouterLinked = errors.New("no router linked") errInvalidPortNumber = errors.New("invalid port number") errUnexpectedTypeSwitchFailure = errors.New("unexpected type-switch failure") errBindFailedFor = errors.New("bind failed for") errEndPortLessThanStart = errors.New("end port is less than the start") errPortSpaceExhausted = errors.New("port space exhausted") ) func newMACAddress() net.HardwareAddr { b := make([]byte, 8) binary.BigEndian.PutUint64(b, macAddrCounter) macAddrCounter++ return b[2:] } // Net represents a local network stack equivalent to a set of layers from NIC // up to the transport (UDP / TCP) layer. type Net struct { interfaces []*transport.Interface // read-only staticIPs []net.IP // read-only router *Router // read-only udpConns *udpConnMap // read-only mutex sync.RWMutex } // Compile-time assertion. var _ transport.Net = &Net{} func (v *Net) _getInterfaces() ([]*transport.Interface, error) { if len(v.interfaces) == 0 { return nil, errNoInterface } return v.interfaces, nil } // Interfaces returns a list of the system's network interfaces. 
func (v *Net) Interfaces() ([]*transport.Interface, error) { v.mutex.RLock() defer v.mutex.RUnlock() return v._getInterfaces() } // caller must hold the mutex (read). func (v *Net) _getInterface(ifName string) (*transport.Interface, error) { ifs, err := v._getInterfaces() if err != nil { return nil, err } for _, ifc := range ifs { if ifc.Name == ifName { return ifc, nil } } return nil, fmt.Errorf("%w: %s", transport.ErrInterfaceNotFound, ifName) } func (v *Net) getInterface(ifName string) (*transport.Interface, error) { v.mutex.RLock() defer v.mutex.RUnlock() return v._getInterface(ifName) } // InterfaceByIndex returns the interface specified by index. // // On Solaris, it returns one of the logical network interfaces // sharing the logical data link; for more precision use // InterfaceByName. func (v *Net) InterfaceByIndex(index int) (*transport.Interface, error) { for _, ifc := range v.interfaces { if ifc.Index == index { return ifc, nil } } return nil, fmt.Errorf("%w: index=%d", transport.ErrInterfaceNotFound, index) } // InterfaceByName returns the interface specified by name. func (v *Net) InterfaceByName(ifName string) (*transport.Interface, error) { return v.getInterface(ifName) } // caller must hold the mutex. func (v *Net) getAllIPAddrs(ipv6 bool) []net.IP { ips := []net.IP{} for _, ifc := range v.interfaces { addrs, err := ifc.Addrs() if err != nil { continue } for _, addr := range addrs { var ip net.IP if ipNet, ok := addr.(*net.IPNet); ok { ip = ipNet.IP } else if ipAddr, ok := addr.(*net.IPAddr); ok { ip = ipAddr.IP } else { continue } if !ipv6 { if ip.To4() != nil { ips = append(ips, ip) } } } } return ips } func (v *Net) setRouter(r *Router) error { v.mutex.Lock() defer v.mutex.Unlock() v.router = r return nil } func (v *Net) onInboundChunk(c Chunk) { v.mutex.Lock() defer v.mutex.Unlock() if c.Network() == udp { if conn, ok := v.udpConns.find(c.DestinationAddr()); ok { conn.onInboundChunk(c) } } } // caller must hold the mutex. 
func (v *Net) _dialUDP(network string, locAddr, remAddr *net.UDPAddr) (transport.UDPConn, error) { //nolint:cyclop // validate network if network != udp && network != udp4 { return nil, fmt.Errorf("%w: %s", errUnexpectedNetwork, network) } if locAddr == nil { locAddr = &net.UDPAddr{ IP: net.IPv4zero, } } else if locAddr.IP == nil { locAddr.IP = net.IPv4zero } // validate address. do we have that address? if !v.hasIPAddr(locAddr.IP) { return nil, &net.OpError{ Op: "listen", Net: network, Addr: locAddr, Err: fmt.Errorf("bind: %w", errCantAssignRequestedAddr), } } if locAddr.Port == 0 { // choose randomly from the range between 5000 and 5999 port, err := v.assignPort(locAddr.IP, 5000, 5999) if err != nil { return nil, &net.OpError{ Op: "listen", Net: network, Addr: locAddr, Err: err, } } locAddr.Port = port } else if _, ok := v.udpConns.find(locAddr); ok { return nil, &net.OpError{ Op: "listen", Net: network, Addr: locAddr, Err: fmt.Errorf("bind: %w", errAddressAlreadyInUse), } } conn, err := newUDPConn(locAddr, remAddr, v) if err != nil { return nil, err } err = v.udpConns.insert(conn) if err != nil { return nil, err } return conn, nil } // ListenPacket announces on the local network address. func (v *Net) ListenPacket(network string, address string) (net.PacketConn, error) { v.mutex.Lock() defer v.mutex.Unlock() locAddr, err := v.ResolveUDPAddr(network, address) if err != nil { return nil, err } return v._dialUDP(network, locAddr, nil) } // ListenUDP acts like ListenPacket for UDP networks. func (v *Net) ListenUDP(network string, locAddr *net.UDPAddr) (transport.UDPConn, error) { v.mutex.Lock() defer v.mutex.Unlock() return v._dialUDP(network, locAddr, nil) } // DialUDP acts like Dial for UDP networks. func (v *Net) DialUDP(network string, locAddr, remAddr *net.UDPAddr) (transport.UDPConn, error) { v.mutex.Lock() defer v.mutex.Unlock() return v._dialUDP(network, locAddr, remAddr) } // Dial connects to the address on the named network. 
func (v *Net) Dial(network string, address string) (net.Conn, error) { v.mutex.Lock() defer v.mutex.Unlock() remAddr, err := v.ResolveUDPAddr(network, address) if err != nil { return nil, err } // Determine source address srcIP := v.determineSourceIP(nil, remAddr.IP) locAddr := &net.UDPAddr{IP: srcIP, Port: 0} return v._dialUDP(network, locAddr, remAddr) } // ResolveIPAddr returns an address of IP end point. func (v *Net) ResolveIPAddr(_, address string) (*net.IPAddr, error) { var err error // Check if host is a domain name ip := net.ParseIP(address) if ip == nil { //nolint:nestif address = strings.ToLower(address) if address == "localhost" { ip = net.IPv4(127, 0, 0, 1) } else { // host is a domain name. resolve IP address by the name if v.router == nil { return nil, errNoRouterLinked } ip, err = v.router.resolver.lookUp(address) if err != nil { return nil, err } } } return &net.IPAddr{ IP: ip, }, nil } // ResolveUDPAddr returns an address of UDP end point. func (v *Net) ResolveUDPAddr(network, address string) (*net.UDPAddr, error) { if network != udp && network != udp4 { return nil, fmt.Errorf("%w %s", errUnknownNetwork, network) } host, sPort, err := net.SplitHostPort(address) if err != nil { return nil, err } ipAddress, err := v.ResolveIPAddr("ip", host) if err != nil { return nil, err } port, err := strconv.Atoi(sPort) if err != nil { return nil, errInvalidPortNumber } udpAddr := &net.UDPAddr{ IP: ipAddress.IP, Zone: ipAddress.Zone, Port: port, } return udpAddr, nil } // ResolveTCPAddr returns an address of TCP end point. 
func (v *Net) ResolveTCPAddr(network, address string) (*net.TCPAddr, error) { if network != udp && network != "udp4" { return nil, fmt.Errorf("%w %s", errUnknownNetwork, network) } host, sPort, err := net.SplitHostPort(address) if err != nil { return nil, err } ipAddr, err := v.ResolveIPAddr("ip", host) if err != nil { return nil, err } port, err := strconv.Atoi(sPort) if err != nil { return nil, errInvalidPortNumber } udpAddr := &net.TCPAddr{ IP: ipAddr.IP, Zone: ipAddr.Zone, Port: port, } return udpAddr, nil } func (v *Net) write(chunk Chunk) error { if chunk.Network() == udp { //nolint:nestif if udp, ok := chunk.(*chunkUDP); ok { if chunk.getDestinationIP().IsLoopback() { if conn, ok := v.udpConns.find(udp.DestinationAddr()); ok { conn.onInboundChunk(udp) } return nil } } else { return errUnexpectedTypeSwitchFailure } } if v.router == nil { return errNoRouterLinked } v.router.push(chunk) return nil } func (v *Net) onClosed(addr net.Addr) { if addr.Network() == udp { //nolint:errcheck v.udpConns.delete(addr) // #nosec } } // This method determines the srcIP based on the dstIP when locIP // is any IP address ("0.0.0.0" or "::"). If locIP is a non-any addr, // this method simply returns locIP. // caller must hold the mutex. func (v *Net) determineSourceIP(locIP, dstIP net.IP) net.IP { //nolint:cyclop if locIP != nil && !locIP.IsUnspecified() { return locIP } var srcIP net.IP if dstIP.IsLoopback() { //nolint:nestif srcIP = net.ParseIP("127.0.0.1") } else { ifc, err2 := v._getInterface("eth0") if err2 != nil { return nil } addrs, err2 := ifc.Addrs() if err2 != nil { return nil } if len(addrs) == 0 { return nil } var findIPv4 bool if locIP != nil { findIPv4 = (locIP.To4() != nil) } else { findIPv4 = (dstIP.To4() != nil) } for _, addr := range addrs { ip := addr.(*net.IPNet).IP //nolint:forcetypeassert if findIPv4 { if ip.To4() != nil { srcIP = ip break } } else { if ip.To4() == nil { srcIP = ip break } } } } return srcIP } // caller must hold the mutex. 
func (v *Net) hasIPAddr(ip net.IP) bool { //nolint:gocognit,cyclop for _, ifc := range v.interfaces { if addrs, err := ifc.Addrs(); err == nil { //nolint:nestif for _, addr := range addrs { var locIP net.IP if ipNet, ok := addr.(*net.IPNet); ok { locIP = ipNet.IP } else if ipAddr, ok := addr.(*net.IPAddr); ok { locIP = ipAddr.IP } else { continue } switch ip.String() { case "0.0.0.0": if locIP.To4() != nil { return true } case "::": if locIP.To4() == nil { return true } default: if locIP.Equal(ip) { return true } } } } } return false } // caller must hold the mutex. func (v *Net) allocateLocalAddr(ip net.IP, port int) error { // gather local IP addresses to bind var ips []net.IP if ip.IsUnspecified() { ips = v.getAllIPAddrs(ip.To4() == nil) } else if v.hasIPAddr(ip) { ips = []net.IP{ip} } if len(ips) == 0 { return fmt.Errorf("%w %s", errBindFailedFor, ip.String()) } // check if all these transport addresses are not in use for _, ip2 := range ips { addr := &net.UDPAddr{ IP: ip2, Port: port, } if _, ok := v.udpConns.find(addr); ok { return &net.OpError{ Op: "bind", Net: udp, Addr: addr, Err: fmt.Errorf("bind: %w", errAddressAlreadyInUse), } } } return nil } // caller must hold the mutex. func (v *Net) assignPort(ip net.IP, start, end int) (int, error) { // choose randomly from the range between start and end (inclusive) if end < start { return -1, errEndPortLessThanStart } space := end + 1 - start offset := rand.Intn(space) //nolint:gosec for i := 0; i < space; i++ { port := ((offset + i) % space) + start err := v.allocateLocalAddr(ip, port) if err == nil { return port, nil } } return -1, errPortSpaceExhausted } func (v *Net) getStaticIPs() []net.IP { return v.staticIPs } // NetConfig is a bag of configuration parameters passed to NewNet(). type NetConfig struct { // StaticIPs is an array of static IP addresses to be assigned for this Net. // If no static IP address is given, the router will automatically assign // an IP address. 
StaticIPs []string // StaticIP is deprecated. Use StaticIPs. StaticIP string } // NewNet creates an instance of a virtual network. // // By design, it always have lo0 and eth0 interfaces. // The lo0 has the address 127.0.0.1 assigned by default. // IP address for eth0 will be assigned when this Net is added to a router. func NewNet(config *NetConfig) (*Net, error) { lo0 := transport.NewInterface(net.Interface{ Index: 1, MTU: 16384, Name: lo0String, HardwareAddr: nil, Flags: net.FlagUp | net.FlagLoopback | net.FlagMulticast, }) lo0.AddAddress(&net.IPNet{ IP: net.ParseIP("127.0.0.1"), Mask: net.CIDRMask(8, 32), }) eth0 := transport.NewInterface(net.Interface{ Index: 2, MTU: 1500, Name: "eth0", HardwareAddr: newMACAddress(), Flags: net.FlagUp | net.FlagMulticast, }) var staticIPs []net.IP for _, ipStr := range config.StaticIPs { if ip := net.ParseIP(ipStr); ip != nil { staticIPs = append(staticIPs, ip) } } if len(config.StaticIP) > 0 { if ip := net.ParseIP(config.StaticIP); ip != nil { staticIPs = append(staticIPs, ip) } } return &Net{ interfaces: []*transport.Interface{lo0, eth0}, staticIPs: staticIPs, udpConns: newUDPConnMap(), }, nil } // DialTCP acts like Dial for TCP networks. func (v *Net) DialTCP(string, *net.TCPAddr, *net.TCPAddr) (transport.TCPConn, error) { return nil, transport.ErrNotSupported } // ListenTCP acts like Listen for TCP networks. func (v *Net) ListenTCP(string, *net.TCPAddr) (transport.TCPListener, error) { return nil, transport.ErrNotSupported } // CreateDialer creates an instance of vnet.Dialer. 
func (v *Net) CreateDialer(d *net.Dialer) transport.Dialer { return &dialer{ dialer: d, net: v, } } type dialer struct { dialer *net.Dialer net *Net } func (d *dialer) Dial(network, address string) (net.Conn, error) { return d.net.Dial(network, address) } golang-github-pion-transport-v3-3.0.8/vnet/net_test.go000066400000000000000000000524111507057301700227340ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package vnet import ( "fmt" "net" "testing" "github.com/pion/logging" "github.com/pion/transport/v3" "github.com/stretchr/testify/assert" ) func TestNetVirtual(t *testing.T) { //nolint:gocyclo,cyclop,maintidx loggerFactory := logging.NewDefaultLoggerFactory() log := logging.NewDefaultLoggerFactory().NewLogger("test") t.Run("tnet.Interfaces", func(t *testing.T) { nw, err := NewNet(&NetConfig{}) if !assert.NoError(t, err, "should succeed") { return } intfs, err := nw.Interfaces() assert.Equal(t, 2, len(intfs), "should be one tnet.Interface") assert.NoError(t, err, "should succeed") for _, ifc := range intfs { switch ifc.Name { case lo0String: assert.Equal(t, 1, ifc.Index, "Index mismatch") assert.Equal(t, 16384, ifc.MTU, "MTU mismatch") assert.Equal(t, net.HardwareAddr(nil), ifc.HardwareAddr, "HardwareAddr mismatch") assert.Equal(t, net.FlagUp|net.FlagLoopback|net.FlagMulticast, ifc.Flags, "Flags mismatch") addrs, err := ifc.Addrs() assert.NoError(t, err, "should succeed") assert.Equal(t, 1, len(addrs), "should be one address") case "eth0": assert.Equal(t, 2, ifc.Index, "Index mismatch") assert.Equal(t, 1500, ifc.MTU, "MTU mismatch") assert.Equal(t, 6, len(ifc.HardwareAddr), "HardwareAddr length mismatch") assert.Equal(t, net.FlagUp|net.FlagMulticast, ifc.Flags, "Flags mismatch") _, err := ifc.Addrs() assert.NotNil(t, err, "should fail") default: assert.Fail(t, "unknown tnet.Interface: %v", ifc.Name) } if addrs, err := ifc.Addrs(); err == nil { for _, addr := range addrs { log.Debugf("[%d] %s:%s", 
ifc.Index, addr.Network(), addr.String()) } } } }) t.Run("tnet.InterfaceByName", func(t *testing.T) { nw, err := NewNet(&NetConfig{}) if !assert.NoError(t, err, "should succeed") { return } intfs, err := nw.Interfaces() assert.Equal(t, 2, len(intfs), "should be one tnet.Interface") assert.NoError(t, err, "should succeed") var ifc *transport.Interface ifc, err = nw.InterfaceByName(lo0String) assert.NoError(t, err, "should succeed") if ifc.Name == lo0String { assert.Equal(t, 1, ifc.Index, "Index mismatch") assert.Equal(t, 16384, ifc.MTU, "MTU mismatch") assert.Equal(t, net.HardwareAddr(nil), ifc.HardwareAddr, "HardwareAddr mismatch") assert.Equal(t, net.FlagUp|net.FlagLoopback|net.FlagMulticast, ifc.Flags, "Flags mismatch") addrs, err2 := ifc.Addrs() assert.NoError(t, err2, "should succeed") assert.Equal(t, 1, len(addrs), "should be one address") } ifc, err = nw.InterfaceByName("eth0") assert.NoError(t, err, "should succeed") assert.Equal(t, 2, ifc.Index, "Index mismatch") assert.Equal(t, 1500, ifc.MTU, "MTU mismatch") assert.Equal(t, 6, len(ifc.HardwareAddr), "HardwareAddr length mismatch") assert.Equal(t, net.FlagUp|net.FlagMulticast, ifc.Flags, "Flags mismatch") _, err = ifc.Addrs() assert.NotNil(t, err, "should fail") _, err = nw.InterfaceByName("foo0") assert.NotNil(t, err, "should fail") }) t.Run("hasIPAddr", func(t *testing.T) { nw, err := NewNet(&NetConfig{}) if !assert.NoError(t, err, "should succeed") { return } intfs, err := nw.Interfaces() assert.Equal(t, 2, len(intfs), "should be one tnet.Interface") assert.NoError(t, err, "should succeed") var ifc *transport.Interface ifc, err = nw.InterfaceByName("eth0") assert.NoError(t, err, "should succeed") ifc.AddAddress(&net.IPNet{ IP: net.ParseIP("10.1.2.3"), Mask: net.CIDRMask(24, 32), }) _, err = ifc.Addrs() assert.NoError(t, err, "should succeed") assert.True(t, nw.hasIPAddr(net.ParseIP("127.0.0.1")), "the IP addr should exist") assert.True(t, nw.hasIPAddr(net.ParseIP("10.1.2.3")), "the IP addr should exist") 
assert.False(t, nw.hasIPAddr(net.ParseIP("192.168.1.1")), "the IP addr should NOT exist") }) t.Run("getAllIPAddrs", func(t *testing.T) { nw, err := NewNet(&NetConfig{}) if !assert.NoError(t, err, "should succeed") { return } intfs, err := nw.Interfaces() assert.Equal(t, 2, len(intfs), "should be one tnet.Interface") assert.NoError(t, err, "should succeed") var ifc *transport.Interface ifc, err = nw.InterfaceByName("eth0") assert.NoError(t, err, "should succeed") ifc.AddAddress(&net.IPNet{ IP: net.ParseIP("10.1.2.3"), Mask: net.CIDRMask(24, 32), }) ips := nw.getAllIPAddrs(false) assert.Equal(t, 2, len(ips), "should match") for _, ip := range ips { log.Debugf("ip: %s", ip.String()) } }) t.Run("assignPort()", func(t *testing.T) { nw, err := NewNet(&NetConfig{}) if !assert.NoError(t, err, "should succeed") { return } addr := demoIP start := 1000 end := 1002 space := end + 1 - start intfs, err := nw.Interfaces() assert.Equal(t, 2, len(intfs), "should be one tnet.Interface") assert.NoError(t, err, "should succeed") var ifc *transport.Interface ifc, err = nw.InterfaceByName("eth0") assert.NoError(t, err, "should succeed") ifc.AddAddress(&net.IPNet{ IP: net.ParseIP(addr), Mask: net.CIDRMask(24, 32), }) // attempt to assign port with start > end should fail _, err = nw.assignPort(net.ParseIP(addr), 3000, 2999) assert.NotNil(t, err, "should fail") for i := 0; i < space; i++ { port, err2 := nw.assignPort(net.ParseIP(addr), start, end) assert.NoError(t, err2, "should succeed") log.Debugf("[%d] got port: %d", i, port) conn, err2 := newUDPConn(&net.UDPAddr{ IP: net.ParseIP(addr), Port: port, }, nil, &myConnObserver{}) assert.NoError(t, err2, "should succeed") err2 = nw.udpConns.insert(conn) assert.NoError(t, err2, "should succeed") } assert.Equal(t, space, nw.udpConns.size(), "should match") // attempt to assign again should fail _, err = nw.assignPort(net.ParseIP(addr), start, end) assert.NotNil(t, err, "should fail") }) t.Run("determineSourceIP()", func(t *testing.T) { nw, err 
:= NewNet(&NetConfig{}) if !assert.NoError(t, err, "should succeed") { return } intfs, err := nw.Interfaces() assert.Equal(t, 2, len(intfs), "should be one tnet.Interface") assert.NoError(t, err, "should succeed") var ifc *transport.Interface ifc, err = nw.InterfaceByName("eth0") assert.NoError(t, err, "should succeed") ifc.AddAddress(&net.IPNet{ IP: net.ParseIP(demoIP), Mask: net.CIDRMask(24, 32), }) // Any IP turned into non-loopback IP anyIP := net.ParseIP("0.0.0.0") dstIP := net.ParseIP("27.1.7.135") srcIP := nw.determineSourceIP(anyIP, dstIP) log.Debugf("anyIP: %s => %s", anyIP.String(), srcIP.String()) assert.NotNil(t, srcIP, "shouldn't be nil") assert.Equal(t, srcIP.String(), demoIP, "use non-loopback IP") // Any IP turned into loopback IP anyIP = net.ParseIP("0.0.0.0") dstIP = net.ParseIP("127.0.0.2") srcIP = nw.determineSourceIP(anyIP, dstIP) log.Debugf("anyIP: %s => %s", anyIP.String(), srcIP.String()) assert.NotNil(t, srcIP, "shouldn't be nil") assert.Equal(t, srcIP.String(), "127.0.0.1", "use loopback IP") // Non any IP won't change anyIP = net.ParseIP(demoIP) dstIP = net.ParseIP("127.0.0.2") srcIP = nw.determineSourceIP(anyIP, dstIP) log.Debugf("anyIP: %s => %s", anyIP.String(), srcIP.String()) assert.NotNil(t, srcIP, "shouldn't be nil") assert.True(t, srcIP.Equal(anyIP), "IP change") }) t.Run("ResolveUDPAddr", func(t *testing.T) { nw, err := NewNet(&NetConfig{}) if !assert.NoError(t, err, "should succeed") { return } udpAddr, err := nw.ResolveUDPAddr(udp, "localhost:1234") if !assert.NoError(t, err, "should succeed") { return } assert.Equal(t, "127.0.0.1", udpAddr.IP.String(), "should match") assert.Equal(t, 1234, udpAddr.Port, "should match") }) t.Run("UDPLoopback", func(t *testing.T) { nw, err := NewNet(&NetConfig{}) if !assert.NoError(t, err, "should succeed") { return } conn, err := nw.ListenPacket(udp, "127.0.0.1:0") assert.NoError(t, err, "should succeed") laddr := conn.LocalAddr() msg := "PING!" 
n, err := conn.WriteTo([]byte(msg), laddr) assert.NoError(t, err, "should succeed") assert.Equal(t, len(msg), n, "should match") buf := make([]byte, 1000) n, addr, err := conn.ReadFrom(buf) assert.NoError(t, err, "should succeed") assert.Equal(t, len(msg), n, "should match") assert.Equal(t, msg, string(buf[:n]), "should match") assert.Equal(t, laddr.(*net.UDPAddr).String(), addr.(*net.UDPAddr).String(), "should match") //nolint:forcetypeassert assert.Equal(t, 1, nw.udpConns.size(), "should match") assert.NoError(t, conn.Close(), "should succeed") assert.Empty(t, nw.udpConns.size(), "should match") }) t.Run("ListenPacket random port", func(t *testing.T) { nw, err := NewNet(&NetConfig{}) if !assert.NoError(t, err, "should succeed") { return } conn, err := nw.ListenPacket(udp, "127.0.0.1:0") assert.NoError(t, err, "should succeed") laddr := conn.LocalAddr().String() log.Debugf("laddr: %s", laddr) assert.Equal(t, 1, nw.udpConns.size(), "should match") assert.NoError(t, conn.Close(), "should succeed") assert.Empty(t, nw.udpConns.size(), "should match") }) t.Run("ListenPacket specific port", func(t *testing.T) { nw, err := NewNet(&NetConfig{}) if !assert.NoError(t, err, "should succeed") { return } conn, err := nw.ListenPacket(udp, "127.0.0.1:50916") assert.NoError(t, err, "should succeed") laddr := conn.LocalAddr().String() assert.Equal(t, "127.0.0.1:50916", laddr, "should match") assert.Equal(t, 1, nw.udpConns.size(), "should match") assert.NoError(t, conn.Close(), "should succeed") assert.Empty(t, nw.udpConns.size(), "should match") }) t.Run("ListenUDP random port", func(t *testing.T) { nw, err := NewNet(&NetConfig{}) if !assert.NoError(t, err, "should succeed") { return } srcAddr := &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), } conn, err := nw.ListenUDP(udp, srcAddr) assert.NoError(t, err, "should succeed") laddr := conn.LocalAddr().String() log.Debugf("laddr: %s", laddr) assert.Equal(t, 1, nw.udpConns.size(), "should match") assert.NoError(t, conn.Close(), "should 
succeed") assert.Empty(t, nw.udpConns.size(), "should match") }) t.Run("ListenUDP specific port", func(t *testing.T) { nw, err := NewNet(&NetConfig{}) if !assert.NoError(t, err, "should succeed") { return } srcAddr := &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: 60916, } conn, err := nw.ListenUDP(udp, srcAddr) assert.NoError(t, err, "should succeed") laddr := conn.LocalAddr().String() assert.Equal(t, "127.0.0.1:60916", laddr, "should match") assert.Equal(t, 1, nw.udpConns.size(), "should match") assert.NoError(t, conn.Close(), "should succeed") assert.Empty(t, nw.udpConns.size(), "should match") }) t.Run("Dial (UDP) lo0", func(t *testing.T) { nw, err := NewNet(&NetConfig{}) if !assert.NoError(t, err, "should succeed") { return } conn, err := nw.Dial(udp, "127.0.0.1:1234") assert.NoError(t, err, "should succeed") laddr := conn.LocalAddr() log.Debugf("laddr: %s", laddr.String()) raddr := conn.RemoteAddr() log.Debugf("raddr: %s", raddr.String()) assert.Equal(t, "127.0.0.1", laddr.(*net.UDPAddr).IP.String(), "should match") //nolint:forcetypeassert assert.True(t, laddr.(*net.UDPAddr).Port != 0, "should match") //nolint:forcetypeassert assert.Equal(t, "127.0.0.1:1234", raddr.String(), "should match") assert.Equal(t, 1, nw.udpConns.size(), "should match") assert.NoError(t, conn.Close(), "should succeed") assert.Empty(t, nw.udpConns.size(), "should match") }) t.Run("Dial (UDP) eth0", func(t *testing.T) { wan, err := NewRouter(&RouterConfig{ CIDR: "1.2.3.0/24", LoggerFactory: loggerFactory, }) assert.NoError(t, err, "should succeed") nw, err := NewNet(&NetConfig{}) if !assert.NoError(t, err, "should succeed") { return } assert.NoError(t, wan.AddNet(nw), "should succeed") conn, err := nw.Dial(udp, "27.3.4.5:1234") assert.NoError(t, err, "should succeed") laddr := conn.LocalAddr() log.Debugf("laddr: %s", laddr.String()) raddr := conn.RemoteAddr() log.Debugf("raddr: %s", raddr.String()) assert.Equal(t, "1.2.3.1", laddr.(*net.UDPAddr).IP.String(), "should match") 
//nolint:forcetypeassert assert.True(t, laddr.(*net.UDPAddr).Port != 0, "should match") //nolint:forcetypeassert assert.Equal(t, "27.3.4.5:1234", raddr.String(), "should match") assert.Equal(t, 1, nw.udpConns.size(), "should match") assert.NoError(t, conn.Close(), "should succeed") assert.Empty(t, nw.udpConns.size(), "should match") }) t.Run("DialUDP", func(t *testing.T) { nw, err := NewNet(&NetConfig{}) if !assert.NoError(t, err, "should succeed") { return } locAddr := &net.UDPAddr{ IP: net.IPv4(127, 0, 0, 1), Port: 0, } remAddr := &net.UDPAddr{ IP: net.IPv4(127, 0, 0, 1), Port: 1234, } conn, err := nw.DialUDP(udp, locAddr, remAddr) assert.NoError(t, err, "should succeed") laddr := conn.LocalAddr() log.Debugf("laddr: %s", laddr.String()) raddr := conn.RemoteAddr() log.Debugf("raddr: %s", raddr.String()) assert.Equal(t, "127.0.0.1", laddr.(*net.UDPAddr).IP.String(), "should match") //nolint:forcetypeassert assert.True(t, laddr.(*net.UDPAddr).Port != 0, "should match") //nolint:forcetypeassert assert.Equal(t, "127.0.0.1:1234", raddr.String(), "should match") assert.Equal(t, 1, nw.udpConns.size(), "should match") assert.NoError(t, conn.Close(), "should succeed") assert.Empty(t, nw.udpConns.size(), "should match") }) t.Run("Resolver", func(t *testing.T) { wan, err := NewRouter(&RouterConfig{ CIDR: "1.2.3.0/24", LoggerFactory: loggerFactory, }) if !assert.NoError(t, err, "should succeed") { return } err = wan.AddHost("test.pion.ly", "30.31.32.33") assert.NoError(t, err, "should succeed") nw, err := NewNet(&NetConfig{}) if !assert.NoError(t, err, "should succeed") { return } assert.NoError(t, wan.AddNet(nw), "should succeed") conn, err := nw.Dial(udp, "test.pion.ly:1234") assert.NoError(t, err, "should succeed") laddr := conn.LocalAddr() log.Debugf("laddr: %s", laddr.String()) raddr := conn.RemoteAddr() log.Debugf("raddr: %s", raddr.String()) assert.Equal(t, "1.2.3.1", laddr.(*net.UDPAddr).IP.String(), "should match") //nolint:forcetypeassert assert.True(t, 
laddr.(*net.UDPAddr).Port != 0, "should match") //nolint:forcetypeassert assert.Equal(t, "30.31.32.33:1234", raddr.String(), "should match") assert.Equal(t, 1, nw.udpConns.size(), "should match") assert.NoError(t, conn.Close(), "should succeed") assert.Empty(t, nw.udpConns.size(), "should match") }) t.Run("Loopback", func(t *testing.T) { nw, err := NewNet(&NetConfig{}) if !assert.NoError(t, err, "should succeed") { return } conn, err := nw.ListenPacket(udp, "127.0.0.1:50916") assert.NoError(t, err, "should succeed") laddr := conn.LocalAddr() assert.Equal(t, "127.0.0.1:50916", laddr.String(), "should match") chunk := newChunkUDP(&net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: 4000, }, &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: 50916, }) chunk.userData = []byte("Hello!") var hasReceived bool recvdCh := make(chan bool) doneCh := make(chan struct{}) go func() { var err error var n int var addr net.Addr buf := make([]byte, 1500) for { n, addr, err = conn.ReadFrom(buf) if err != nil { log.Debugf("ReadFrom returned: %v", err) break } assert.Equal(t, 6, len(chunk.userData), "should match") assert.Equal(t, "127.0.0.1:4000", addr.String(), "should match") assert.Equal(t, "Hello!", string(buf[:n]), "should match") recvdCh <- true } close(doneCh) }() nw.onInboundChunk(chunk) loop: for { select { case <-recvdCh: hasReceived = true assert.NoError(t, conn.Close(), "should succeed") case <-doneCh: break loop } } assert.Empty(t, nw.udpConns.size(), "should match") assert.True(t, hasReceived, "should have received data") }) t.Run("End-to-End", func(t *testing.T) { doneCh := make(chan struct{}) // WAN wan, err := NewRouter(&RouterConfig{ CIDR: "1.2.3.0/24", LoggerFactory: loggerFactory, }) assert.NoError(t, err, "should succeed") net1, err := NewNet(&NetConfig{}) if !assert.NoError(t, err, "should succeed") { return } err = wan.AddNet(net1) assert.NoError(t, err, "should succeed") ip1, err := getIPAddr(net1) assert.NoError(t, err, "should succeed") net2, err := 
NewNet(&NetConfig{}) if !assert.NoError(t, err, "should succeed") { return } err = wan.AddNet(net2) assert.NoError(t, err, "should succeed") ip2, err := getIPAddr(net2) assert.NoError(t, err, "should succeed") conn1, err := net1.ListenPacket( udp, fmt.Sprintf("%s:%d", ip1, 1234), ) assert.NoError(t, err, "should succeed") conn2, err := net2.ListenPacket( udp, fmt.Sprintf("%s:%d", ip2, 5678), ) assert.NoError(t, err, "should succeed") // start the router err = wan.Start() assert.NoError(t, err, "should succeed") conn1RcvdCh := make(chan bool) // conn1 go func() { buf := make([]byte, 1500) for { log.Debug("conn1: wait for a message..") n, _, err2 := conn1.ReadFrom(buf) if err2 != nil { log.Debugf("ReadFrom returned: %v", err2) break } log.Debugf("conn1 received %s", string(buf[:n])) conn1RcvdCh <- true } close(doneCh) }() // conn2 go func() { buf := make([]byte, 1500) for { log.Debug("conn2: wait for a message..") n, addr, err2 := conn2.ReadFrom(buf) if err2 != nil { log.Debugf("ReadFrom returned: %v", err2) break } log.Debugf("conn2 received %s", string(buf[:n])) // echo back to conn1 nSent, err2 := conn2.WriteTo([]byte("Good-bye!"), addr) assert.NoError(t, err2, "should succeed") assert.Equal(t, 9, nSent, "should match") } }() log.Debug("conn1: sending") nSent, err := conn1.WriteTo( []byte("Hello!"), conn2.LocalAddr(), ) assert.NoError(t, err, "should succeed") assert.Equal(t, 6, nSent, "should match") loop: for { select { case <-conn1RcvdCh: assert.NoError(t, conn1.Close(), "should succeed") assert.NoError(t, conn2.Close(), "should succeed") case <-doneCh: break loop } } assert.NoError(t, wan.Stop(), "should succeed") }) t.Run("Dialer", func(t *testing.T) { nw, err := NewNet(&NetConfig{}) if !assert.NoError(t, err, "should succeed") { return } dialer := nw.CreateDialer(&net.Dialer{ LocalAddr: &net.UDPAddr{ IP: net.ParseIP("127.0.0.1"), Port: 0, }, }) conn, err := dialer.Dial(udp, "127.0.0.1:1234") assert.NoError(t, err, "should succeed") laddr := conn.LocalAddr() 
log.Debugf("laddr: %s", laddr.String()) raddr := conn.RemoteAddr() log.Debugf("raddr: %s", raddr.String()) assert.Equal(t, "127.0.0.1", laddr.(*net.UDPAddr).IP.String(), "should match") //nolint:forcetypeassert assert.True(t, laddr.(*net.UDPAddr).Port != 0, "should match") //nolint:forcetypeassert assert.Equal(t, "127.0.0.1:1234", raddr.String(), "should match") assert.Equal(t, 1, nw.udpConns.size(), "should match") assert.NoError(t, conn.Close(), "should succeed") assert.Empty(t, nw.udpConns.size(), "should match") }) t.Run("Two IPs on a NIC", func(t *testing.T) { doneCh := make(chan struct{}) // WAN wan, err := NewRouter(&RouterConfig{ CIDR: "1.2.3.0/24", LoggerFactory: loggerFactory, }) if !assert.NoError(t, err, "should succeed") { return } net1, err := NewNet(&NetConfig{ StaticIPs: []string{ demoIP, "1.2.3.5", }, }) if !assert.NoError(t, err, "should succeed") { return } err = wan.AddNet(net1) assert.NoError(t, err, "should succeed") // start the router err = wan.Start() assert.NoError(t, err, "should succeed") conn1, err := net1.ListenPacket(udp, "1.2.3.4:1234") assert.NoError(t, err, "should succeed") conn2, err := net1.ListenPacket(udp, "1.2.3.5:1234") assert.NoError(t, err, "should succeed") conn1RcvdCh := make(chan bool) // conn1 go func() { buf := make([]byte, 1500) for { log.Debug("conn1: wait for a message..") n, _, err2 := conn1.ReadFrom(buf) if err2 != nil { log.Debugf("ReadFrom returned: %v", err2) break } log.Debugf("conn1 received %s", string(buf[:n])) conn1RcvdCh <- true } close(doneCh) }() // conn2 go func() { buf := make([]byte, 1500) for { log.Debug("conn2: wait for a message..") n, addr, err2 := conn2.ReadFrom(buf) if err2 != nil { log.Debugf("ReadFrom returned: %v", err2) break } log.Debugf("conn2 received %s", string(buf[:n])) // echo back to conn1 nSent, err2 := conn2.WriteTo([]byte("Good-bye!"), addr) assert.NoError(t, err2, "should succeed") assert.Equal(t, 9, nSent, "should match") } }() log.Debug("conn1: sending") nSent, err := 
conn1.WriteTo( []byte("Hello!"), conn2.LocalAddr(), ) assert.NoError(t, err, "should succeed") assert.Equal(t, 6, nSent, "should match") loop: for { select { case <-conn1RcvdCh: assert.NoError(t, conn1.Close(), "should succeed") assert.NoError(t, conn2.Close(), "should succeed") case <-doneCh: break loop } } assert.NoError(t, wan.Stop(), "should succeed") }) } golang-github-pion-transport-v3-3.0.8/vnet/resolver.go000066400000000000000000000034551507057301700227540ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package vnet import ( "errors" "fmt" "net" "sync" "github.com/pion/logging" ) var ( errHostnameEmpty = errors.New("host name must not be empty") errFailedToParseIPAddr = errors.New("failed to parse IP address") ) type resolverConfig struct { LoggerFactory logging.LoggerFactory } type resolver struct { parent *resolver // read-only hosts map[string]net.IP // requires mutex mutex sync.RWMutex // thread-safe log logging.LeveledLogger // read-only } func newResolver(config *resolverConfig) *resolver { r := &resolver{ hosts: map[string]net.IP{}, log: config.LoggerFactory.NewLogger("vnet"), } if err := r.addHost("localhost", "127.0.0.1"); err != nil { r.log.Warn("failed to add localhost to resolver") } return r } func (r *resolver) setParent(parent *resolver) { r.mutex.Lock() defer r.mutex.Unlock() r.parent = parent } func (r *resolver) addHost(name string, ipAddr string) error { r.mutex.Lock() defer r.mutex.Unlock() if len(name) == 0 { return errHostnameEmpty } ip := net.ParseIP(ipAddr) if ip == nil { return fmt.Errorf("%w \"%s\"", errFailedToParseIPAddr, ipAddr) } r.hosts[name] = ip return nil } func (r *resolver) lookUp(hostName string) (net.IP, error) { ip := func() net.IP { r.mutex.RLock() defer r.mutex.RUnlock() if ip2, ok := r.hosts[hostName]; ok { return ip2 } return nil }() if ip != nil { return ip, nil } // mutex must be unlocked before calling into parent resolver if r.parent != nil { return 
r.parent.lookUp(hostName) } return nil, &net.DNSError{ Err: "host not found", Name: hostName, Server: "vnet resolver", IsTimeout: false, IsTemporary: false, } } golang-github-pion-transport-v3-3.0.8/vnet/resolver_test.go000066400000000000000000000036431507057301700240120ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package vnet import ( "net" "testing" "github.com/pion/logging" "github.com/stretchr/testify/assert" ) func TestResolver(t *testing.T) { loggerFactory := logging.NewDefaultLoggerFactory() log := loggerFactory.NewLogger("test") t.Run("Standalone", func(t *testing.T) { resolver := newResolver(&resolverConfig{ LoggerFactory: loggerFactory, }) // should have localhost by default name := "localhost" ipAddr := "127.0.0.1" ip := net.ParseIP(ipAddr) resolved, err := resolver.lookUp(name) assert.NoError(t, err, "should succeed") assert.True(t, resolved.Equal(ip), "should match") name = "abc.com" ipAddr = demoIP ip = net.ParseIP(ipAddr) log.Debugf("adding %s %s", name, ipAddr) err = resolver.addHost(name, ipAddr) assert.NoError(t, err, "should succeed") resolved, err = resolver.lookUp(name) assert.NoError(t, err, "should succeed") assert.True(t, resolved.Equal(ip), "should match") }) t.Run("Cascaded", func(t *testing.T) { r0 := newResolver(&resolverConfig{ LoggerFactory: loggerFactory, }) name0 := "abc.com" ipAddr0 := demoIP ip0 := net.ParseIP(ipAddr0) err := r0.addHost(name0, ipAddr0) assert.NoError(t, err, "should succeed") r1 := newResolver(&resolverConfig{ LoggerFactory: loggerFactory, }) name1 := "myserver.local" ipAddr1 := "10.1.2.5" ip1 := net.ParseIP(ipAddr1) err = r1.addHost(name1, ipAddr1) assert.NoError(t, err, "should succeed") r1.setParent(r0) resolved, err := r1.lookUp(name0) assert.NoError(t, err, "should succeed") assert.True(t, resolved.Equal(ip0), "should match") resolved, err = r1.lookUp(name1) assert.NoError(t, err, "should succeed") assert.True(t, resolved.Equal(ip1), "should 
match") // should fail if the name does not exist _, err = r1.lookUp("bad.com") assert.NotNil(t, err, "should fail") }) } golang-github-pion-transport-v3-3.0.8/vnet/router.go000066400000000000000000000350041507057301700224260ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package vnet import ( "errors" "fmt" "math/rand" "net" "strings" "sync" "sync/atomic" "time" "github.com/pion/logging" "github.com/pion/transport/v3" ) const ( defaultRouterQueueSize = 0 // unlimited ) var ( errInvalidLocalIPinStaticIPs = errors.New("invalid local IP in StaticIPs") errLocalIPBeyondStaticIPsSubset = errors.New("mapped in StaticIPs is beyond subnet") errLocalIPNoStaticsIPsAssociated = errors.New("all StaticIPs must have associated local IPs") errRouterAlreadyStarted = errors.New("router already started") errRouterAlreadyStopped = errors.New("router already stopped") errStaticIPisBeyondSubnet = errors.New("static IP is beyond subnet") errAddressSpaceExhausted = errors.New("address space exhausted") errNoIPAddrEth0 = errors.New("no IP address is assigned for eth0") ) // Generate a unique router name. var assignRouterName = func() func() string { //nolint:gochecknoglobals var routerIDCtr uint64 return func() string { n := atomic.AddUint64(&routerIDCtr, 1) return fmt.Sprintf("router%d", n) } }() // RouterConfig ... type RouterConfig struct { // Name of router. If not specified, a unique name will be assigned. Name string // CIDR notation, like "192.0.2.0/24" CIDR string // StaticIPs is an array of static IP addresses to be assigned for this router. // If no static IP address is given, the router will automatically assign // an IP address. // This will be ignored if this router is the root. StaticIPs []string // StaticIP is deprecated. Use StaticIPs. 
StaticIP string // Internal queue size QueueSize int // Effective only when this router has a parent router NATType *NATType // Minimum Delay MinDelay time.Duration // Max Jitter MaxJitter time.Duration // Logger factory LoggerFactory logging.LoggerFactory } // NIC is a network interface controller that interfaces Router. type NIC interface { getInterface(ifName string) (*transport.Interface, error) onInboundChunk(c Chunk) getStaticIPs() []net.IP setRouter(r *Router) error } // ChunkFilter is a handler users can add to filter chunks. // If the filter returns false, the packet will be dropped. type ChunkFilter func(c Chunk) bool // Router ... type Router struct { name string // read-only interfaces []*transport.Interface // read-only ipv4Net *net.IPNet // read-only staticIPs []net.IP // read-only staticLocalIPs map[string]net.IP // read-only, lastID byte // requires mutex [x], used to assign the last digit of IPv4 address queue *chunkQueue // read-only parent *Router // read-only children []*Router // read-only natType *NATType // read-only nat *networkAddressTranslator // read-only nics map[string]NIC // read-only stopFunc func() // requires mutex [x] resolver *resolver // read-only chunkFilters []ChunkFilter // requires mutex [x] minDelay time.Duration // requires mutex [x] maxJitter time.Duration // requires mutex [x] mutex sync.RWMutex // thread-safe pushCh chan struct{} // writer requires mutex loggerFactory logging.LoggerFactory // read-only log logging.LeveledLogger // read-only } // NewRouter ... 
func NewRouter(config *RouterConfig) (*Router, error) { //nolint:cyclop loggerFactory := config.LoggerFactory log := loggerFactory.NewLogger("vnet") _, ipv4Net, err := net.ParseCIDR(config.CIDR) if err != nil { return nil, err } queueSize := defaultRouterQueueSize if config.QueueSize > 0 { queueSize = config.QueueSize } // set up network interface, lo0 lo0 := transport.NewInterface(net.Interface{ Index: 1, MTU: 16384, Name: lo0String, HardwareAddr: nil, Flags: net.FlagUp | net.FlagLoopback | net.FlagMulticast, }) lo0.AddAddress(&net.IPAddr{IP: net.ParseIP("127.0.0.1"), Zone: ""}) // set up network interface, eth0 eth0 := transport.NewInterface(net.Interface{ Index: 2, MTU: 1500, Name: "eth0", HardwareAddr: newMACAddress(), Flags: net.FlagUp | net.FlagMulticast, }) // local host name resolver resolver := newResolver(&resolverConfig{ LoggerFactory: config.LoggerFactory, }) name := config.Name if len(name) == 0 { name = assignRouterName() } var staticIPs []net.IP staticLocalIPs := map[string]net.IP{} for _, ipStr := range config.StaticIPs { ipPair := strings.Split(ipStr, "/") if ip := net.ParseIP(ipPair[0]); ip != nil { //nolint:nestif if len(ipPair) > 1 { locIP := net.ParseIP(ipPair[1]) if locIP == nil { return nil, errInvalidLocalIPinStaticIPs } if !ipv4Net.Contains(locIP) { return nil, fmt.Errorf("local IP %s %w", locIP.String(), errLocalIPBeyondStaticIPsSubset) } staticLocalIPs[ip.String()] = locIP } staticIPs = append(staticIPs, ip) } } if len(config.StaticIP) > 0 { log.Warn("StaticIP is deprecated. 
Use StaticIPs instead") if ip := net.ParseIP(config.StaticIP); ip != nil { staticIPs = append(staticIPs, ip) } } if nStaticLocal := len(staticLocalIPs); nStaticLocal > 0 { if nStaticLocal != len(staticIPs) { return nil, errLocalIPNoStaticsIPsAssociated } } return &Router{ name: name, interfaces: []*transport.Interface{lo0, eth0}, ipv4Net: ipv4Net, staticIPs: staticIPs, staticLocalIPs: staticLocalIPs, queue: newChunkQueue(queueSize, 0), natType: config.NATType, nics: map[string]NIC{}, resolver: resolver, minDelay: config.MinDelay, maxJitter: config.MaxJitter, pushCh: make(chan struct{}, 1), loggerFactory: loggerFactory, log: log, }, nil } // caller must hold the mutex. func (r *Router) getInterfaces() ([]*transport.Interface, error) { if len(r.interfaces) == 0 { return nil, fmt.Errorf("%w is available", errNoInterface) } return r.interfaces, nil } func (r *Router) getInterface(ifName string) (*transport.Interface, error) { r.mutex.RLock() defer r.mutex.RUnlock() ifs, err := r.getInterfaces() if err != nil { return nil, err } for _, ifc := range ifs { if ifc.Name == ifName { return ifc, nil } } return nil, fmt.Errorf("%w: %s", transport.ErrInterfaceNotFound, ifName) } // Start ... func (r *Router) Start() error { //nolint:cyclop r.mutex.Lock() defer r.mutex.Unlock() if r.stopFunc != nil { return errRouterAlreadyStarted } cancelCh := make(chan struct{}) go func() { loop: for { duration, err := r.processChunks() if err != nil { r.log.Errorf("[%s] %s", r.name, err.Error()) break } if duration <= 0 { select { case <-r.pushCh: case <-cancelCh: break loop } } else { t := time.NewTimer(duration) select { case <-t.C: case <-cancelCh: break loop } } } }() r.stopFunc = func() { close(cancelCh) } for _, child := range r.children { if err := child.Start(); err != nil { return err } } return nil } // Stop ... 
func (r *Router) Stop() error { r.mutex.Lock() defer r.mutex.Unlock() if r.stopFunc == nil { return errRouterAlreadyStopped } for _, router := range r.children { r.mutex.Unlock() err := router.Stop() r.mutex.Lock() if err != nil { return err } } r.stopFunc() r.stopFunc = nil return nil } // caller must hold the mutex. func (r *Router) addNIC(nic NIC) error { ifc, err := nic.getInterface("eth0") if err != nil { return err } var ips []net.IP if ips = nic.getStaticIPs(); len(ips) == 0 { // assign an IP address ip, err2 := r.assignIPAddress() if err2 != nil { return err2 } ips = append(ips, ip) } for _, ip := range ips { if !r.ipv4Net.Contains(ip) { return fmt.Errorf("%w: %s", errStaticIPisBeyondSubnet, r.ipv4Net.String()) } ifc.AddAddress(&net.IPNet{ IP: ip, Mask: r.ipv4Net.Mask, }) r.nics[ip.String()] = nic } return nic.setRouter(r) } // AddRouter adds a child Router. func (r *Router) AddRouter(router *Router) error { r.mutex.Lock() defer r.mutex.Unlock() // Router is a NIC. Add it as a NIC so that packets are routed to this child // router. err := r.addNIC(router) if err != nil { return err } if err = router.setRouter(r); err != nil { return err } r.children = append(r.children, router) return nil } // AddChildRouter is like AddRouter, but does not add the child routers NIC to // the parent. This has to be done manually by calling AddNet, which allows to // use a wrapper around the subrouters NIC. // AddNet MUST be called before AddChildRouter. func (r *Router) AddChildRouter(router *Router) error { r.mutex.Lock() defer r.mutex.Unlock() if err := router.setRouter(r); err != nil { return err } r.children = append(r.children, router) return nil } // AddNet ... func (r *Router) AddNet(nic NIC) error { r.mutex.Lock() defer r.mutex.Unlock() return r.addNIC(nic) } // AddHost adds a mapping of hostname and an IP address to the local resolver. 
// The call is forwarded to the router's resolver and its error is returned
// unchanged.
func (r *Router) AddHost(hostName string, ipAddr string) error {
	return r.resolver.addHost(hostName, ipAddr)
}

// AddChunkFilter adds a filter for chunks traversing this router.
// You may add more than one filter. The filters are called in the order of this method call.
// If a chunk is dropped by a filter, subsequent filter will not receive the chunk.
func (r *Router) AddChunkFilter(filter ChunkFilter) {
	r.mutex.Lock()
	defer r.mutex.Unlock()

	r.chunkFilters = append(r.chunkFilters, filter)
}

// assignIPAddress returns the next unused address of this router's subnet.
// The first three bytes are copied from the subnet address and the last byte
// is a counter that is incremented on every call; once the counter has
// reached 0xfe, errAddressSpaceExhausted is returned.
// NOTE(review): only the last byte varies regardless of the mask width, and
// ipv4Net.IP is presumably a 4-byte IPv4 address here - confirm.
// caller should hold the mutex.
func (r *Router) assignIPAddress() (net.IP, error) {
	// See: https://stackoverflow.com/questions/14915188/ip-address-ending-with-zero

	if r.lastID == 0xfe {
		return nil, errAddressSpaceExhausted
	}

	ip := make(net.IP, 4)
	copy(ip, r.ipv4Net.IP[:3])
	r.lastID++
	ip[3] = r.lastID

	return ip, nil
}

// push timestamps the chunk, enqueues it and wakes up the processing
// goroutine. Nothing is enqueued while the router is not started
// (stopFunc == nil), and a chunk rejected by the queue is dropped with a
// warning.
func (r *Router) push(c Chunk) {
	r.mutex.Lock()
	defer r.mutex.Unlock()

	r.log.Debugf("[%s] route %s", r.name, c.String())
	if r.stopFunc != nil {
		c.setTimestamp()
		if r.queue.push(c) {
			// Non-blocking send: if a wake-up is already pending on pushCh
			// there is no need to queue another one.
			select {
			case r.pushCh <- struct{}{}:
			default:
			}
		} else {
			r.log.Warnf("[%s] queue was full. dropped a chunk", r.name)
		}
	}
}

// processChunks drains every chunk that is due (older than minDelay) from the
// queue and routes it: to a NIC of this subnet, or to the parent router after
// outbound NAT translation. Chunks rejected by a filter, or without a
// destination NIC or route, are discarded.
//
// The returned duration is the time until the chunk at the head of the queue
// becomes due, or 0 when the queue was emptied. An error is returned only
// when the outbound NAT translation fails.
func (r *Router) processChunks() (time.Duration, error) { //nolint:cyclop
	r.mutex.Lock()
	defer r.mutex.Unlock()

	// Introduce jitter by delaying the processing of chunks.
	// Note that the mutex is held during this sleep, so callers that need the
	// mutex (e.g. push) wait until it is over.
	if r.maxJitter > 0 {
		jitter := time.Duration(rand.Int63n(int64(r.maxJitter))) //nolint:gosec
		time.Sleep(jitter)
	}

	//      cutOff
	//         v min delay
	//         |<--->|
	//  +------------:--
	//  |OOOOOOXXXXX :   --> time
	//  +------------:--
	//  |<--->|     now
	//    due

	enteredAt := time.Now()
	cutOff := enteredAt.Add(-r.minDelay)

	var duration time.Duration // the next sleep duration

	for {
		duration = 0

		chunk := r.queue.peek()
		if chunk == nil {
			break // no more chunk in the queue
		}

		// check timestamp to find if the chunk is due
		if chunk.getTimestamp().After(cutOff) {
			// There is one or more chunk in the queue but none of them are due.
			// Calculate the next sleep duration here.
			nextExpire := chunk.getTimestamp().Add(r.minDelay)
			duration = nextExpire.Sub(enteredAt)

			break
		}

		var ok bool
		if chunk, ok = r.queue.pop(); !ok {
			break // no more chunk in the queue
		}

		// Run the filters in registration order; the first one that returns
		// false blocks the chunk and the remaining filters are skipped.
		blocked := false
		for i := 0; i < len(r.chunkFilters); i++ {
			filter := r.chunkFilters[i]
			if !filter(chunk) {
				blocked = true

				break
			}
		}
		if blocked {
			continue // discard
		}

		dstIP := chunk.getDestinationIP()

		// check if the destination is in our subnet
		if r.ipv4Net.Contains(dstIP) {
			// search for the destination NIC
			var nic NIC
			if nic, ok = r.nics[dstIP.String()]; !ok {
				// NIC not found. drop it.
				r.log.Debugf("[%s] %s unreachable", r.name, chunk.String())

				continue
			}

			// found the NIC, forward the chunk to the NIC.
			// call to NIC must unlock mutex
			r.mutex.Unlock()
			nic.onInboundChunk(chunk)
			r.mutex.Lock()

			continue
		}

		// the destination is outside of this subnet
		// is this WAN?
		if r.parent == nil {
			// this is WAN. No route for this chunk
			r.log.Debugf("[%s] no route found for %s", r.name, chunk.String())

			continue
		}

		// Pass it to the parent via NAT
		toParent, err := r.nat.translateOutbound(chunk)
		if err != nil {
			return 0, err
		}

		// A nil result without an error means there is nothing to forward.
		if toParent == nil {
			continue
		}

		//nolint:godox
		/* FIXME: this implementation would introduce a duplicate packet!
		if r.nat.natType.Hairpinning {
			hairpinned, err := r.nat.translateInbound(toParent)
			if err != nil {
				r.log.Warnf("[%s] %s", r.name, err.Error())
			} else {
				go func() {
					r.push(hairpinned)
				}()
			}
		}
		*/

		// call to parent router must unlock mutex
		r.mutex.Unlock()
		r.parent.push(toParent)
		r.mutex.Lock()
	}

	return duration, nil
}

// setRouter makes parent the parent of this router: it links the resolvers,
// collects the addresses already assigned to eth0 as the NAT's mapped IPs
// (with their local IPs from staticLocalIPs, when configured) and creates the
// NAT. When no NAT type was configured, a default one is installed.
// It returns errNoIPAddrEth0 when eth0 has no address yet.
// NOTE(review): this method itself takes r.mutex (read lock) through
// getInterface, so "the mutex" below presumably refers to the parent's mutex
// - confirm.
// caller must hold the mutex.
func (r *Router) setRouter(parent *Router) error { //nolint:cyclop
	r.parent = parent
	r.resolver.setParent(parent.resolver)

	// when this method is called, one or more IP address has already been assigned by
	// the parent router.
	ifc, err := r.getInterface("eth0")
	if err != nil {
		return err
	}

	// The error from Addrs is ignored; an empty result is reported below.
	addrs, _ := ifc.Addrs()
	if len(addrs) == 0 {
		return errNoIPAddrEth0
	}

	mappedIPs := []net.IP{}
	localIPs := []net.IP{}

	for _, ifcAddr := range addrs {
		var ip net.IP
		switch addr := ifcAddr.(type) {
		case *net.IPNet:
			ip = addr.IP
		case *net.IPAddr: // Do we really need this case?
			ip = addr.IP
		default:
		}

		// Addresses of any other type (or without an IP) are skipped.
		if ip == nil {
			continue
		}

		mappedIPs = append(mappedIPs, ip)

		if locIP := r.staticLocalIPs[ip.String()]; locIP != nil {
			localIPs = append(localIPs, locIP)
		}
	}

	// Set up NAT here
	if r.natType == nil {
		r.natType = &NATType{
			MappingBehavior:   EndpointIndependent,
			FilteringBehavior: EndpointAddrPortDependent,
			Hairpinning:       false,
			PortPreservation:  false,
			MappingLifeTime:   30 * time.Second,
		}
	}
	r.nat, err = newNAT(&natConfig{
		name:          r.name,
		natType:       *r.natType,
		mappedIPs:     mappedIPs,
		localIPs:      localIPs,
		loggerFactory: r.loggerFactory,
	})
	if err != nil {
		return err
	}

	return nil
}

// onInboundChunk receives a chunk for this router acting as a NIC: the chunk
// is translated inbound by the NAT and pushed onto this router's queue. A
// translation error is logged and the chunk is dropped.
// NOTE(review): r.nat is only assigned in setRouter, so this presumably is
// never called on a router without a parent - confirm.
func (r *Router) onInboundChunk(c Chunk) {
	fromParent, err := r.nat.translateInbound(c)
	if err != nil {
		r.log.Warnf("[%s] %s", r.name, err.Error())

		return
	}

	r.push(fromParent)
}

// getStaticIPs returns the static IP addresses collected by NewRouter.
func (r *Router) getStaticIPs() []net.IP {
	return r.staticIPs
}
golang-github-pion-transport-v3-3.0.8/vnet/router_test.go000066400000000000000000000441741507057301700234750ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community
// SPDX-License-Identifier: MIT

package vnet

import (
	"errors"
	"net"
	"sync/atomic"
	"testing"
	"time"

	"github.com/pion/logging"
	"github.com/stretchr/testify/assert"
)

// errNoAddress is returned by getIPAddr when eth0 does not have exactly one
// address.
var errNoAddress = errors.New("there must be one address")

// dummyNIC wraps a Net and redirects inbound chunks to a test-supplied
// handler.
type dummyNIC struct {
	*Net
	onInboundChunkHandler func(Chunk)
}

// hijack onInboundChunk.
func (v *dummyNIC) onInboundChunk(c Chunk) { v.onInboundChunkHandler(c) } func getIPAddr(n NIC) (string, error) { eth0, err := n.getInterface("eth0") if err != nil { return "", err } addrs, err := eth0.Addrs() if err != nil { return "", err } if len(addrs) != 1 { return "", errNoAddress } return addrs[0].(*net.IPNet).IP.String(), nil //nolint:forcetypeassert } func TestRouterStandalone(t *testing.T) { //nolint:cyclop,maintidx loggerFactory := logging.NewDefaultLoggerFactory() log := loggerFactory.NewLogger("test") t.Run("CIDR parsing", func(t *testing.T) { r, err := NewRouter(&RouterConfig{ CIDR: "1.2.3.0/24", LoggerFactory: loggerFactory, }) if !assert.Nil(t, err, "should succeed") { return } assert.Equal(t, "1.2.3.0", r.ipv4Net.IP.String(), "ip should match") assert.Equal(t, "ffffff00", r.ipv4Net.Mask.String(), "mask should match") }) t.Run("assignIPAddress", func(t *testing.T) { router, err := NewRouter(&RouterConfig{ CIDR: "1.2.3.0/24", LoggerFactory: loggerFactory, }) if !assert.Nil(t, err, "should succeed") { return } for i := 1; i < 255; i++ { ip, err2 := router.assignIPAddress() assert.Nil(t, err2, "should succeed") assert.Equal(t, byte(1), ip[0], "should match") assert.Equal(t, byte(2), ip[1], "should match") assert.Equal(t, byte(3), ip[2], "should match") assert.Equal(t, byte(i), ip[3], "should match") } _, err = router.assignIPAddress() assert.NotNil(t, err, "should fail") }) t.Run("AddNet", func(t *testing.T) { router, err := NewRouter(&RouterConfig{ CIDR: "1.2.3.0/24", LoggerFactory: loggerFactory, }) if !assert.Nil(t, err, "should succeed") { return } nic, err := NewNet(&NetConfig{}) if !assert.NoError(t, err, "should succeed") { return } err = router.AddNet(nic) assert.Nil(t, err, "should succeed") // Now, eth0 must have one address assigned eth0, err := nic.getInterface("eth0") assert.Nil(t, err, "should succeed") addrs, err := eth0.Addrs() assert.Nil(t, err, "should succeed") assert.Equal(t, 1, len(addrs), "should match") assert.Equal(t, "ip+net", 
addrs[0].Network(), "should match") assert.Equal(t, "1.2.3.1/24", addrs[0].String(), "should match") assert.Equal(t, "1.2.3.1", addrs[0].(*net.IPNet).IP.String(), "should match") //nolint:forcetypeassert }) t.Run("AddChildRouter", func(t *testing.T) { r1, err := NewRouter(&RouterConfig{ CIDR: "0.0.0.0/0", LoggerFactory: loggerFactory, }) if !assert.Nil(t, err, "should succeed") { return } r2, err := NewRouter(&RouterConfig{ CIDR: "192.168.0.0/24", StaticIPs: []string{ "192.168.0.1", }, LoggerFactory: loggerFactory, }) if !assert.Nil(t, err, "should succeed") { return } err = r1.AddNet(r2) assert.Nil(t, err, "should succeed") err = r1.AddChildRouter(r2) assert.Nil(t, err, "should succeed") }) t.Run("routing", func(t *testing.T) { var nCbs0 int32 doneCh := make(chan struct{}) router, err := NewRouter(&RouterConfig{ CIDR: "1.2.3.0/24", LoggerFactory: loggerFactory, }) if !assert.Nil(t, err, "should succeed") { return } nic := make([]*dummyNIC, 2) ip := make([]*net.UDPAddr, 2) for i := 0; i < 2; i++ { anic, netErr := NewNet(&NetConfig{}) if !assert.NoError(t, netErr, "should succeed") { return } nic[i] = &dummyNIC{ Net: anic, } err2 := router.AddNet(nic[i]) assert.Nil(t, err2, "should succeed") // Now, eth0 must have one address assigned eth0, err2 := nic[i].getInterface("eth0") assert.Nil(t, err2, "should succeed") addrs, err2 := eth0.Addrs() assert.Nil(t, err2, "should succeed") assert.Equal(t, 1, len(addrs), "should match") //nolint:forcetypeassert ip[i] = &net.UDPAddr{ IP: addrs[0].(*net.IPNet).IP, Port: 1111 * (i + 1), } } nic[0].onInboundChunkHandler = func(c Chunk) { log.Debugf("nic[0] received: %s", c.String()) atomic.AddInt32(&nCbs0, 1) } nic[1].onInboundChunkHandler = func(c Chunk) { log.Debugf("nic[1] received: %s", c.String()) close(doneCh) } err = router.Start() assert.Nil(t, err, "should succeed") c := newChunkUDP(ip[0], ip[1]) router.push(c) <-doneCh err = router.Stop() assert.Nil(t, err, "should succeed") assert.Equal(t, int32(0), 
atomic.LoadInt32(&nCbs0), "should be zero") }) t.Run("AddChunkFilter", func(t *testing.T) { var nCbs0 int32 var nCbs1 int32 router, err := NewRouter(&RouterConfig{ CIDR: "1.2.3.0/24", LoggerFactory: loggerFactory, }) if !assert.Nil(t, err, "should succeed") { return } nic := make([]*dummyNIC, 2) ip := make([]*net.UDPAddr, 2) for i := 0; i < 2; i++ { anic, netErr := NewNet(&NetConfig{}) if !assert.NoError(t, netErr, "should succeed") { return } nic[i] = &dummyNIC{ Net: anic, } err2 := router.AddNet(nic[i]) assert.Nil(t, err2, "should succeed") // Now, eth0 must have one address assigned eth0, err2 := nic[i].getInterface("eth0") assert.Nil(t, err2, "should succeed") addrs, err2 := eth0.Addrs() assert.Nil(t, err2, "should succeed") assert.Equal(t, 1, len(addrs), "should match") //nolint:forcetypeassert ip[i] = &net.UDPAddr{ IP: addrs[0].(*net.IPNet).IP, Port: 1111 * (i + 1), } } nic[0].onInboundChunkHandler = func(c Chunk) { log.Debugf("nic[0] received: %s", c.String()) atomic.AddInt32(&nCbs0, 1) } var seq byte nic[1].onInboundChunkHandler = func(c Chunk) { log.Debugf("nic[1] received: %s", c.String()) seq = c.UserData()[0] atomic.AddInt32(&nCbs1, 1) } // this creates a filter that block the first chunk makeFilter := func(name string) func(c Chunk) bool { n := 0 return func(c Chunk) bool { pass := (n > 0) if pass { log.Debugf("%s passed %s", name, c.String()) } else { log.Debugf("%s blocked %s", name, c.String()) } n++ return pass } } // filter 1: block first one router.AddChunkFilter(makeFilter("filter1")) // filter 2: block first one router.AddChunkFilter(makeFilter("filter2")) err = router.Start() assert.Nil(t, err, "should succeed") // send 3 packets for i := 0; i < 3; i++ { c := newChunkUDP(ip[0], ip[1]) c.userData = make([]byte, 1) c.userData[0] = byte(i) // 1-byte seq num router.push(c) } time.Sleep(50 * time.Millisecond) err = router.Stop() assert.Nil(t, err, "should succeed") assert.Equal(t, int32(0), atomic.LoadInt32(&nCbs0), "should be zero") 
assert.Equal(t, int32(1), atomic.LoadInt32(&nCbs1), "should be zero") assert.Equal(t, byte(2), seq, "should be the last chunk") }) } func TestRouterDelay(t *testing.T) { loggerFactory := logging.NewDefaultLoggerFactory() log := loggerFactory.NewLogger("test") subTest := func(t *testing.T, title string, minDelay, maxJitter time.Duration) { t.Helper() t.Run(title, func(t *testing.T) { const margin = 8 * time.Millisecond var nCBs int32 doneCh := make(chan struct{}) router, err := NewRouter(&RouterConfig{ CIDR: "1.2.3.0/24", MinDelay: minDelay, MaxJitter: maxJitter, LoggerFactory: loggerFactory, }) if !assert.Nil(t, err, "should succeed") { return } nic := make([]*dummyNIC, 2) ip := make([]*net.UDPAddr, 2) for i := 0; i < 2; i++ { anic, netErr := NewNet(&NetConfig{}) if !assert.NoError(t, netErr, "should succeed") { return } nic[i] = &dummyNIC{ Net: anic, } err2 := router.AddNet(nic[i]) assert.Nil(t, err2, "should succeed") // Now, eth0 must have one address assigned eth0, err2 := nic[i].getInterface("eth0") assert.Nil(t, err2, "should succeed") addrs, err2 := eth0.Addrs() assert.Nil(t, err2, "should succeed") assert.Equal(t, 1, len(addrs), "should match") //nolint:forcetypeassert ip[i] = &net.UDPAddr{ IP: addrs[0].(*net.IPNet).IP, Port: 1111 * (i + 1), } } var delayRes []time.Duration nPkts := 1 nic[0].onInboundChunkHandler = func(Chunk) {} nic[1].onInboundChunkHandler = func(c Chunk) { delay := time.Since(c.getTimestamp()) delayRes = append(delayRes, delay) n := atomic.AddInt32(&nCBs, 1) if n == int32(nPkts) { //nolint:gosec // nPkts is a constant close(doneCh) } } err = router.Start() assert.Nil(t, err, "should succeed") for i := 0; i < nPkts; i++ { c := newChunkUDP(ip[0], ip[1]) router.push(c) time.Sleep(50 * time.Millisecond) } <-doneCh err = router.Stop() assert.Nil(t, err, "should succeed") // Validate the amount of delays for _, d := range delayRes { log.Infof("min delay : %v", minDelay) log.Infof("max jitter: %v", maxJitter) log.Infof("actual delay: %v", d) 
assert.True(t, d >= minDelay, "should delay >= 20ms") assert.True(t, d <= (minDelay+maxJitter+margin), "should delay <= minDelay + maxJitter") // Note: actual delay should be within 30ms but giving a 8ms // margin for possible extra delay // (e.g. wakeup delay, debug logs, etc) } }) } subTest(t, "Delay only", 20*time.Millisecond, 0) subTest(t, "Jitter only", 0, 10*time.Millisecond) subTest(t, "Delay and Jitter", 20*time.Millisecond, 10*time.Millisecond) } func TestRouterOneChild(t *testing.T) { loggerFactory := logging.NewDefaultLoggerFactory() log := loggerFactory.NewLogger("test") t.Run("lan to wan", func(t *testing.T) { doneCh := make(chan struct{}) // WAN wan, err := NewRouter(&RouterConfig{ CIDR: "1.2.3.0/24", LoggerFactory: loggerFactory, }) if !assert.Nil(t, err, "should succeed") { return } nw, err := NewNet(&NetConfig{}) if !assert.NoError(t, err, "should succeed") { return } wanNet := &dummyNIC{ Net: nw, } err = wan.AddNet(wanNet) assert.Nil(t, err, "should succeed") // Now, eth0 must have one address assigned wanIP, err := getIPAddr(wanNet) assert.Nil(t, err, "should succeed") log.Debugf("wanIP: %s", wanIP) // LAN lan, err := NewRouter(&RouterConfig{ CIDR: "192.168.0.0/24", LoggerFactory: loggerFactory, }) if !assert.Nil(t, err, "should succeed") { return } lnw, err := NewNet(&NetConfig{}) if !assert.NoError(t, err, "should succeed") { return } lanNet := &dummyNIC{ Net: lnw, } err = lan.AddNet(lanNet) assert.Nil(t, err, "should succeed") // Now, eth0 must have one address assigned lanIP, err := getIPAddr(lanNet) assert.Nil(t, err, "should succeed") log.Debugf("lanIP: %s", lanIP) err = wan.AddRouter(lan) if !assert.Nil(t, err, "should succeed") { return } lanNet.onInboundChunkHandler = func(c Chunk) { log.Debugf("lanNet received: %s", c.String()) close(doneCh) } wanNet.onInboundChunkHandler = func(c Chunk) { log.Debugf("wanNet received: %s", c.String()) // echo the chunk echo := c.Clone().(*chunkUDP) //nolint:forcetypeassert err = 
echo.setSourceAddr(c.(*chunkUDP).DestinationAddr().String()) //nolint:forcetypeassert assert.NoError(t, err, "should succeed") err = echo.setDestinationAddr(c.(*chunkUDP).SourceAddr().String()) //nolint:forcetypeassert assert.NoError(t, err, "should succeed") log.Debug("wan.push being called..") wan.push(echo) log.Debug("wan.push called!") } err = wan.Start() assert.Nil(t, err, "should succeed") chunk := newChunkUDP( &net.UDPAddr{ IP: net.ParseIP(lanIP), Port: 1234, }, &net.UDPAddr{ IP: net.ParseIP(wanIP), Port: 5678, }, ) log.Debugf("sending %s", chunk.String()) lan.push(chunk) <-doneCh err = wan.Stop() assert.Nil(t, err, "should succeed") }) } func TestRouterStaticIPs(t *testing.T) { loggerFactory := logging.NewDefaultLoggerFactory() // log := loggerFactory.NewLogger("test") t.Run("more than one static IP", func(t *testing.T) { lan, err := NewRouter(&RouterConfig{ CIDR: "192.168.0.0/24", StaticIPs: []string{ "1.2.3.1", "1.2.3.2", "1.2.3.3", }, LoggerFactory: loggerFactory, }) if !assert.Nil(t, err, "should succeed") { return } assert.Equal(t, 3, len(lan.staticIPs), "should be 3") assert.Equal(t, "1.2.3.1", lan.staticIPs[0].String(), "should match") assert.Equal(t, "1.2.3.2", lan.staticIPs[1].String(), "should match") assert.Equal(t, "1.2.3.3", lan.staticIPs[2].String(), "should match") }) t.Run("StaticIPs and StaticIP in the mix", func(t *testing.T) { lan, err := NewRouter(&RouterConfig{ CIDR: "192.168.0.0/24", StaticIPs: []string{ "1.2.3.1", "1.2.3.2", "1.2.3.3", }, StaticIP: demoIP, LoggerFactory: loggerFactory, }) if !assert.Nil(t, err, "should succeed") { return } assert.Equal(t, 4, len(lan.staticIPs), "should be 4") assert.Equal(t, "1.2.3.1", lan.staticIPs[0].String(), "should match") assert.Equal(t, "1.2.3.2", lan.staticIPs[1].String(), "should match") assert.Equal(t, "1.2.3.3", lan.staticIPs[2].String(), "should match") assert.Equal(t, demoIP, lan.staticIPs[3].String(), "should match") }) t.Run("Static IP and local IP mapping", func(t *testing.T) { lan, 
err := NewRouter(&RouterConfig{ CIDR: "192.168.0.0/24", StaticIPs: []string{ "1.2.3.1/192.168.0.1", "1.2.3.2/192.168.0.2", "1.2.3.3/192.168.0.3", }, LoggerFactory: loggerFactory, }) assert.NoError(t, err, "should succeed") assert.Equal(t, 3, len(lan.staticIPs), "should be 3") assert.Equal(t, "1.2.3.1", lan.staticIPs[0].String(), "should match") assert.Equal(t, "1.2.3.2", lan.staticIPs[1].String(), "should match") assert.Equal(t, "1.2.3.3", lan.staticIPs[2].String(), "should match") assert.Equal(t, 3, len(lan.staticLocalIPs), "should be 3") localIPs := []string{"192.168.0.1", "192.168.0.2", "192.168.0.3"} for i, extIPStr := range []string{"1.2.3.1", "1.2.3.2", "1.2.3.3"} { locIP, ok := lan.staticLocalIPs[extIPStr] assert.True(t, ok, "should have the external IP") assert.Equal(t, localIPs[i], locIP.String(), "should match") } // bad local IP _, err = NewRouter(&RouterConfig{ CIDR: "192.168.0.0/24", StaticIPs: []string{ "1.2.3.1/192.168.0.1", "1.2.3.2/bad", // <-- invalid local IP }, LoggerFactory: loggerFactory, }) assert.Error(t, err, "should fail") // local IP out of CIDR _, err = NewRouter(&RouterConfig{ CIDR: "192.168.0.0/24", StaticIPs: []string{ "1.2.3.1/192.168.0.1", "1.2.3.2/172.16.1.2", // <-- out of CIDR }, LoggerFactory: loggerFactory, }) assert.Error(t, err, "should fail") // num of local IPs mismatch _, err = NewRouter(&RouterConfig{ CIDR: "192.168.0.0/24", StaticIPs: []string{ "1.2.3.1/192.168.0.1", "1.2.3.2", // <-- lack of local IP }, LoggerFactory: loggerFactory, }) assert.Error(t, err, "should fail") }) t.Run("1:1 NAT configuration", func(t *testing.T) { wan, err := NewRouter(&RouterConfig{ CIDR: "0.0.0.0/0", LoggerFactory: loggerFactory, }) assert.NoError(t, err, "should succeed") lan, err := NewRouter(&RouterConfig{ CIDR: "192.168.0.0/24", StaticIPs: []string{ "1.2.3.1/192.168.0.1", "1.2.3.2/192.168.0.2", "1.2.3.3/192.168.0.3", }, NATType: &NATType{ Mode: NATModeNAT1To1, }, LoggerFactory: loggerFactory, }) assert.NoError(t, err, "should succeed") 
err = wan.AddRouter(lan) if !assert.NoError(t, err, "should succeed") { return } if !assert.NotNil(t, lan.nat, "should not be nil") { return } assert.Equal(t, 3, len(lan.nat.mappedIPs), "should match") assert.Equal(t, "1.2.3.1", lan.nat.mappedIPs[0].String(), "should match") assert.Equal(t, "1.2.3.2", lan.nat.mappedIPs[1].String(), "should match") assert.Equal(t, "1.2.3.3", lan.nat.mappedIPs[2].String(), "should match") assert.Equal(t, 3, len(lan.nat.localIPs), "should match") assert.Equal(t, "192.168.0.1", lan.nat.localIPs[0].String(), "should match") assert.Equal(t, "192.168.0.2", lan.nat.localIPs[1].String(), "should match") assert.Equal(t, "192.168.0.3", lan.nat.localIPs[2].String(), "should match") }) } func TestRouterFailures(t *testing.T) { loggerFactory := logging.NewDefaultLoggerFactory() // log := loggerFactory.NewLogger("test") t.Run("Stop when router is stopped", func(t *testing.T) { r, err := NewRouter(&RouterConfig{ CIDR: "1.2.3.0/24", LoggerFactory: loggerFactory, }) if !assert.Nil(t, err, "should succeed") { return } err = r.Stop() assert.Error(t, err, "should fail") }) t.Run("AddNet", func(t *testing.T) { router, err := NewRouter(&RouterConfig{ CIDR: "1.2.3.0/24", LoggerFactory: loggerFactory, }) if !assert.Nil(t, err, "should succeed") { return } nic, err := NewNet(&NetConfig{ StaticIPs: []string{ "5.6.7.8", // out of parent router'c CIDR }, }) if !assert.NoError(t, err, "should succeed") { return } err = router.AddNet(nic) assert.Error(t, err, "should fail") }) t.Run("AddRouter", func(t *testing.T) { r1, err := NewRouter(&RouterConfig{ CIDR: "1.2.3.0/24", LoggerFactory: loggerFactory, }) if !assert.Nil(t, err, "should succeed") { return } r2, err := NewRouter(&RouterConfig{ CIDR: "192.168.0.0/24", StaticIPs: []string{ "5.6.7.8", // out of parent router'c CIDR }, LoggerFactory: loggerFactory, }) if !assert.Nil(t, err, "should succeed") { return } err = r1.AddRouter(r2) assert.Error(t, err, "should fail") }) t.Run("AddChildRouterWithoutAddNet", 
func(t *testing.T) { r1, err := NewRouter(&RouterConfig{ CIDR: "1.2.3.0/24", LoggerFactory: loggerFactory, }) if !assert.Nil(t, err, "should succeed") { return } r2, err := NewRouter(&RouterConfig{ CIDR: "192.168.0.0/24", StaticIPs: []string{ "5.6.7.8", // out of parent router'c CIDR }, LoggerFactory: loggerFactory, }) if !assert.Nil(t, err, "should succeed") { return } err = r1.AddChildRouter(r2) assert.Error(t, err, "should fail") }) } golang-github-pion-transport-v3-3.0.8/vnet/stress_test.go000066400000000000000000000101111507057301700234600ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package vnet import ( "fmt" "net" "sync" "testing" "time" "github.com/pion/logging" "github.com/pion/transport/v3/test" "github.com/stretchr/testify/assert" ) func TestStressTestUDP(t *testing.T) { //nolint:cyclop loggerFactory := logging.NewDefaultLoggerFactory() log := loggerFactory.NewLogger("test") t.Run("lan to wan", func(t *testing.T) { tt := test.TimeOut(30 * time.Second) defer tt.Stop() // WAN with a nic (net0) wan, err := NewRouter(&RouterConfig{ CIDR: "1.2.3.0/24", QueueSize: 1000, LoggerFactory: loggerFactory, }) assert.NoError(t, err, "should succeed") net0, err := NewNet(&NetConfig{ StaticIPs: []string{demoIP}, }) if !assert.NoError(t, err, "should succeed") { return } err = wan.AddNet(net0) assert.NoError(t, err, "should succeed") // LAN with a nic (net1) lan, err := NewRouter(&RouterConfig{ CIDR: "192.168.0.0/24", QueueSize: 1000, LoggerFactory: loggerFactory, }) assert.NoError(t, err, "should succeed") net1, err := NewNet(&NetConfig{}) if !assert.NoError(t, err, "should succeed") { return } err = lan.AddNet(net1) assert.NoError(t, err, "should succeed") err = wan.AddRouter(lan) assert.NoError(t, err, "should succeed") err = wan.Start() assert.NoError(t, err, "should succeed") defer func() { err = wan.Stop() assert.NoError(t, err, "should succeed") }() // Find IP address for net0 ifs, err := 
net0.Interfaces() if !assert.NoError(t, err, "should succeed") { return } log.Debugf("num ifs: %d", len(ifs)) var echoServerIP net.IP loop: for _, ifc := range ifs { log.Debugf("flags: %v", ifc.Flags) if ifc.Flags&net.FlagUp == 0 { continue } if ifc.Flags&net.FlagLoopback != 0 { continue } addrs, err2 := ifc.Addrs() if !assert.NoError(t, err2, "should succeed") { return } log.Debugf("num addrs: %d", len(addrs)) for _, addr := range addrs { log.Debugf("addr: %s", addr.String()) switch addr := addr.(type) { case *net.IPNet: echoServerIP = addr.IP break loop case *net.IPAddr: echoServerIP = addr.IP break loop } } } if !assert.NotNil(t, echoServerIP, "should have IP address") { return } log.Debugf("echo server IP: %s", echoServerIP.String()) // Set up an echo server on WAN conn0, err := net0.ListenPacket( "udp4", fmt.Sprintf("%s:0", echoServerIP)) if !assert.NoError(t, err, "should succeed") { return } doneCh0 := make(chan struct{}) go func() { buf := make([]byte, 1500) for { n, from, err2 := conn0.ReadFrom(buf) if err2 != nil { break } // echo back _, err2 = conn0.WriteTo(buf[:n], from) if err2 != nil { break } } close(doneCh0) }() var wg sync.WaitGroup runEchoTest := func() { // Set up a client var numRecvd int const numToSend int = 400 const pktSize int = 1200 conn1, err2 := net0.ListenPacket("udp4", "0.0.0.0:0") if !assert.NoError(t, err2, "should succeed") { return } doneCh1 := make(chan struct{}) go func() { buf := make([]byte, 1500) for { n, _, err3 := conn1.ReadFrom(buf) if err3 != nil { break } if n != pktSize { break } numRecvd++ } close(doneCh1) }() buf := make([]byte, pktSize) to := conn0.LocalAddr() for i := 0; i < numToSend; i++ { _, err3 := conn1.WriteTo(buf, to) assert.NoError(t, err3, "should succeed") time.Sleep(10 * time.Millisecond) } time.Sleep(time.Second) err2 = conn1.Close() assert.NoError(t, err2, "should succeed") <-doneCh1 // allow some packet loss assert.True(t, numRecvd >= numToSend*8/10, "majority should received") if numRecvd < numToSend 
{ log.Infof("lost %d packets", numToSend-numRecvd) } wg.Done() } // Run echo tests concurrently for i := 0; i < 20; i++ { wg.Add(1) go runEchoTest() } wg.Wait() err = conn0.Close() assert.NoError(t, err, "should succeed") }) } golang-github-pion-transport-v3-3.0.8/vnet/tbf.go000066400000000000000000000103041507057301700216550ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package vnet import ( "math" "sync" "time" "github.com/pion/logging" ) const ( // Bit is a single bit. Bit = 1 // KBit is a kilobit. KBit = 1000 * Bit // MBit is a Megabit. MBit = 1000 * KBit ) // TokenBucketFilter implements a token bucket rate limit algorithm. type TokenBucketFilter struct { NIC currentTokensInBucket float64 c chan Chunk queue *chunkQueue queueSize int // in bytes mutex sync.Mutex rate int maxBurst int minRefillDuration time.Duration wg sync.WaitGroup done chan struct{} log logging.LeveledLogger } // TBFOption is the option type to configure a TokenBucketFilter. type TBFOption func(*TokenBucketFilter) TBFOption // TBFQueueSizeInBytes sets the max number of bytes waiting in the queue. Can // only be set in constructor before using the TBF. func TBFQueueSizeInBytes(bytes int) TBFOption { return func(t *TokenBucketFilter) TBFOption { prev := t.queueSize t.queueSize = bytes return TBFQueueSizeInBytes(prev) } } // TBFRate sets the bit rate of a TokenBucketFilter. func TBFRate(rate int) TBFOption { return func(t *TokenBucketFilter) TBFOption { t.mutex.Lock() defer t.mutex.Unlock() previous := t.rate t.rate = rate return TBFRate(previous) } } // TBFMaxBurst sets the bucket size of the token bucket filter. This is the // maximum size that can instantly leave the filter, if the bucket is full. 
func TBFMaxBurst(size int) TBFOption {
	return func(t *TokenBucketFilter) TBFOption {
		// rate/maxBurst may be changed through Set while run() is refilling,
		// so the write happens under the same mutex refillTokens takes.
		t.mutex.Lock()
		defer t.mutex.Unlock()
		previous := t.maxBurst
		t.maxBurst = size

		// Hand back an option that restores the value we just replaced.
		return TBFMaxBurst(previous)
	}
}

// Set updates a setting on the token bucket filter.
//
// The options are applied in the order given. The returned option is the
// "undo" option produced by the last one applied; it is nil when opts is empty.
func (t *TokenBucketFilter) Set(opts ...TBFOption) (previous TBFOption) {
	for _, opt := range opts {
		previous = opt(t)
	}

	return previous
}

// NewTokenBucketFilter creates and starts a new TokenBucketFilter.
//
// Chunks handed to the filter are forwarded to n once enough tokens are
// available. Defaults are set first and opts applied on top of them. The
// returned error is always nil in this implementation.
func NewTokenBucketFilter(n NIC, opts ...TBFOption) (*TokenBucketFilter, error) {
	tbf := &TokenBucketFilter{
		NIC:                   n,
		currentTokensInBucket: 0,
		c:                     make(chan Chunk),
		queue:                 nil,
		queueSize:             50000,
		mutex:                 sync.Mutex{},
		rate:                  1 * MBit,
		maxBurst:              8 * KBit,
		minRefillDuration:     100 * time.Millisecond,
		wg:                    sync.WaitGroup{},
		done:                  make(chan struct{}),
		log:                   logging.NewDefaultLoggerFactory().NewLogger("tbf"),
	}
	tbf.Set(opts...)
	// The queue is built only after the options ran, so that a queue size
	// passed via TBFQueueSizeInBytes is the one that is used.
	tbf.queue = newChunkQueue(0, tbf.queueSize)
	tbf.wg.Add(1)
	go tbf.run()

	return tbf, nil
}

// onInboundChunk hands c to the filter goroutine. It blocks until run()
// receives the chunk or the filter is closed, in which case c is dropped.
func (t *TokenBucketFilter) onInboundChunk(c Chunk) {
	select {
	case t.c <- c:
	case <-t.done:
	}
}

// run is the filter goroutine started by NewTokenBucketFilter. It queues
// incoming chunks and forwards as many as the current tokens allow.
func (t *TokenBucketFilter) run() {
	defer t.wg.Done()

	// Start with the tokens of one refill interval.
	t.refillTokens(t.minRefillDuration)
	lastRefill := time.Now()

	for {
		select {
		case <-t.done:
			// Forward what the remaining tokens still pay for, then stop.
			t.drainQueue()

			return
		case chunk := <-t.c:
			// Tokens are only refilled here, i.e. when a chunk arrives and at
			// least minRefillDuration has passed since the previous refill.
			if time.Since(lastRefill) > t.minRefillDuration {
				t.refillTokens(time.Since(lastRefill))
				lastRefill = time.Now()
			}
			t.queue.push(chunk)
			t.drainQueue()
		}
	}
}

// refillTokens adds the tokens earned during dt to the bucket, capped at
// maxBurst. dt is truncated to whole milliseconds.
func (t *TokenBucketFilter) refillTokens(dt time.Duration) {
	t.mutex.Lock()
	defer t.mutex.Unlock()
	// m is the number of dt-sized intervals per second, so rate/m is the
	// number of bits for this interval and the division by 8 turns it into
	// bytes. A dt below 1ms truncates to 0, making m +Inf and add 0.
	m := 1000.0 / float64(dt.Milliseconds())
	add := (float64(t.rate) / m) / 8.0
	t.currentTokensInBucket = math.Min(float64(t.maxBurst), t.currentTokensInBucket+add)
	t.log.Tracef(
		"add=(%v / %v) / 8 = %v, currentTokensInBucket=%v, maxBurst=%v",
		t.rate, m, add, t.currentTokensInBucket, t.maxBurst,
	)
}

// drainQueue forwards queued chunks to the wrapped NIC, in order, until the
// queue is empty or the next chunk costs more tokens than are available.
// A chunk costs one token per byte of user data.
func (t *TokenBucketFilter) drainQueue() {
	for {
		next := t.queue.peek()
		if next == nil {
			break
		}
		tokens := float64(len(next.UserData()))
		if t.currentTokensInBucket < tokens {
			t.log.Tracef("currentTokensInBucket=%v, tokens=%v, stop drain", t.currentTokensInBucket, tokens)

			break
		}
		t.log.Tracef("currentTokensInBucket=%v, tokens=%v, pop chunk", t.currentTokensInBucket, tokens)
		t.queue.pop()
		t.NIC.onInboundChunk(next)
		t.currentTokensInBucket -= tokens
	}
}

// Close closes and stops the token bucket filter queue.
//
// It signals run() to stop and waits for it to return. The returned error is
// always nil. Close must be called at most once: a second call would close
// the already closed done channel and panic.
func (t *TokenBucketFilter) Close() error {
	close(t.done)
	t.wg.Wait()

	return nil
}
golang-github-pion-transport-v3-3.0.8/vnet/tbf_test.go000066400000000000000000000075141507057301700227250ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community
// SPDX-License-Identifier: MIT

//go:build !wasm
// +build !wasm

package vnet

import (
	"context"
	"runtime"
	"sync"
	"testing"
	"time"

	"github.com/pion/logging"
	"github.com/stretchr/testify/assert"
)

// TestTokenBucketFilter checks that the filter forwards everything when the
// offered load is below capacity, and that it limits the throughput to the
// configured rate when the sender pushes as fast as it can.
func TestTokenBucketFilter(t *testing.T) {
	t.Run("bitrateBelowCapacity", func(t *testing.T) {
		mnic := newMockNIC(t)

		tbf, err := NewTokenBucketFilter(mnic, TBFRate(10*MBit), TBFMaxBurst(10*MBit))
		assert.NoError(t, err, "should succeed")

		received := 0
		mnic.mockOnInboundChunk = func(Chunk) {
			received++
		}

		// Let time pass so that the refill triggered by the first chunk
		// covers the whole burst sent below.
		time.Sleep(1 * time.Second)

		sent := 100
		for i := 0; i < sent; i++ {
			tbf.onInboundChunk(&chunkUDP{
				userData: make([]byte, 1200),
			})
		}

		// Close waits for the filter goroutine, so reading received is safe.
		assert.NoError(t, tbf.Close())

		assert.Equal(t, sent, received)
	})

	// subTest floods a filter configured with the given capacity (bit/s) and
	// maxBurst for the given duration and checks the measured receive rate.
	subTest := func(t *testing.T, capacity int, maxBurst int, duration time.Duration) {
		t.Helper()

		log := logging.NewDefaultLoggerFactory().NewLogger("test")

		mnic := newMockNIC(t)

		tbf, err := NewTokenBucketFilter(mnic, TBFRate(capacity), TBFMaxBurst(maxBurst))
		assert.NoError(t, err, "should succeed")

		// Unbuffered: the filter goroutine blocks until the receiver below
		// takes the chunk.
		chunkChan := make(chan Chunk)
		mnic.mockOnInboundChunk = func(c Chunk) {
			chunkChan <- c
		}

		var wg sync.WaitGroup
		wg.Add(1)

		ctx, cancel := context.WithCancel(context.Background())

		// Receiver: counts what leaves the filter and checks the rate once
		// per second and once more over the whole run.
		go func() {
			defer wg.Done()

			totalBytesReceived := 0
			totalPacketsReceived := 0
			bytesReceived := 0
			packetsReceived := 0
			start := time.Now()
			last := time.Now()

			ticker := time.NewTicker(time.Second)
			defer ticker.Stop()

			for {
				select {
				case <-ctx.Done():
					bits := float64(totalBytesReceived) * 8.0
					rate := bits / time.Since(start).Seconds()
					mBitPerSecond := rate / float64(MBit)
					// Allow 5% more than capacity due to max bursts
					assert.Less(t, rate, 1.05*float64(capacity))
					assert.Greater(t, rate, 0.9*float64(capacity))

					// NOTE(review): this logs the counters of the last interval,
					// not the totals the rate above was computed from - confirm
					// whether that is intended.
					log.Infof(
						"duration=%v, bytesReceived=%v, packetsReceived=%v throughput=%.2f Mb/s",
						time.Since(start),
						bytesReceived,
						packetsReceived,
						mBitPerSecond,
					)

					return
				case now := <-ticker.C:
					delta := now.Sub(last)
					last = now
					bits := float64(bytesReceived) * 8.0
					rate := bits / delta.Seconds()
					mBitPerSecond := rate / float64(MBit)
					log.Infof(
						"duration=%v, bytesReceived=%v, packetsReceived=%v throughput=%.2f Mb/s",
						delta,
						bytesReceived,
						packetsReceived,
						mBitPerSecond,
					)
					// Allow 10% more than capacity due to max bursts
					assert.Less(t, rate, 1.10*float64(capacity))
					assert.Greater(t, rate, 0.9*float64(capacity))
					// Start a fresh measurement interval.
					bytesReceived = 0
					packetsReceived = 0
				case c := <-chunkChan:
					bytesReceived += len(c.UserData())
					packetsReceived++
					totalBytesReceived += len(c.UserData())
					totalPacketsReceived++
				}
			}
		}()

		// Sender: pushes 1200 byte chunks as fast as the filter accepts them
		// for the requested duration, then closes the filter. The deferred
		// cancel runs after Close, so the receiver is still consuming while
		// the filter drains.
		go func() {
			defer cancel()

			bytesSent := 0
			packetsSent := 0
			var start time.Time
			for start = time.Now(); time.Since(start) < duration; {
				c := &chunkUDP{
					userData: make([]byte, 1200),
				}
				tbf.onInboundChunk(c)
				bytesSent += len(c.UserData())
				packetsSent++
				runtime.Gosched()
			}
			bits := float64(bytesSent) * 8.0
			rate := bits / time.Since(start).Seconds()
			mBitPerSecond := rate / float64(MBit)
			log.Infof(
				"duration=%v, bytesSent=%v, packetsSent=%v throughput=%.2f Mb/s",
				time.Since(start),
				bytesSent,
				packetsSent,
				mBitPerSecond,
			)

			assert.NoError(t, tbf.Close())
		}()

		wg.Wait()
	}

	t.Run("500Kbit-s", func(t *testing.T) {
		subTest(t, 500*KBit, 10*KBit, 10*time.Second)
	})

	t.Run("1Mbit-s", func(t *testing.T) {
		subTest(t, 1*MBit, 20*KBit, 10*time.Second)
	})

	t.Run("2Mbit-s", func(t *testing.T) {
		subTest(t, 2*MBit, 40*KBit, 10*time.Second)
	})

	t.Run("8Mbit-s", func(t *testing.T) {
		subTest(t, 8*MBit, 160*KBit, 10*time.Second)
	})
}
golang-github-pion-transport-v3-3.0.8/vnet/udpproxy.go000066400000000000000000000140401507057301700227750ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package vnet import ( "context" "net" "sync" "time" ) // UDPProxy is a proxy between real server(net.UDPConn) and vnet.UDPConn. // // High level design: // // .............................................. // : Virtual Network (vnet) : // : : // +-------+ * 1 +----+ +--------+ : // | :App |------------>|:Net|--o<-----|:Router | ............................. // +-------+ +----+ | | : UDPProxy : // : | | +----+ +---------+ +---------+ +--------+ // : | |--->o--|:Net|-->o-| vnet. |-->o-| net. |--->-| :Real | // : | | +----+ | UDPConn | | UDPConn | | Server | // : | | : +---------+ +---------+ +--------+ // : | | ............................: // : +--------+ : // ............................................... type UDPProxy struct { // The router bind to. router *Router // Each vnet source, bind to a real socket to server. // key is real server addr, which is net.Addr // value is *aUDPProxyWorker workers sync.Map // For each endpoint, we never know when to start and stop proxy, // so we stop the endpoint when timeout. timeout time.Duration // For utest, to mock the target real server. // Optional, use the address of received client packet. mockRealServerAddr *net.UDPAddr } // NewProxy create a proxy, the router for this proxy belongs/bind to. If need to proxy for // please create a new proxy for each router. For all addresses we proxy, we will create a // vnet.Net in this router and proxy all packets. func NewProxy(router *Router) (*UDPProxy, error) { v := &UDPProxy{router: router, timeout: 2 * time.Minute} return v, nil } // Close the proxy, stop all workers. 
func (v *UDPProxy) Close() error { v.workers.Range(func(_, value any) bool { _ = value.(*aUDPProxyWorker).Close() //nolint:forcetypeassert return true }) return nil } // Proxy starts a worker for server, ignore if already started. func (v *UDPProxy) Proxy(client *Net, server *net.UDPAddr) error { // Note that even if the worker exists, it's also ok to create a same worker, // because the router will use the last one, and the real server will see a address // change event after we switch to the next worker. if _, ok := v.workers.Load(server.String()); ok { // nolint:godox // TODO: Need to restart the stopped worker? return nil } // Not exists, create a new one. worker := &aUDPProxyWorker{ router: v.router, mockRealServerAddr: v.mockRealServerAddr, } // Create context for cleanup. var ctx context.Context ctx, worker.ctxDisposeCancel = context.WithCancel(context.Background()) v.workers.Store(server.String(), worker) return worker.Proxy(ctx, client, server) } // A proxy worker for a specified proxy server. type aUDPProxyWorker struct { router *Router mockRealServerAddr *net.UDPAddr // Each vnet source, bind to a real socket to server. // key is vnet client addr, which is net.Addr // value is *net.UDPConn endpoints sync.Map // For cleanup. ctxDisposeCancel context.CancelFunc wg sync.WaitGroup } func (v *aUDPProxyWorker) Close() error { // Notify all goroutines to dispose. v.ctxDisposeCancel() // Wait for all goroutines quit. v.wg.Wait() return nil } func (v *aUDPProxyWorker) Proxy(ctx context.Context, _ *Net, serverAddr *net.UDPAddr) error { // nolint:gocognit,cyclop // Create vnet for real server by serverAddr. nw, err := NewNet(&NetConfig{ StaticIP: serverAddr.IP.String(), }) if err != nil { return err } if err = v.router.AddNet(nw); err != nil { return err } // We must create a "same" vnet.UDPConn as the net.UDPConn, // which has the same ip:port, to copy packets between them. 
vnetSocket, err := nw.ListenUDP("udp4", serverAddr) if err != nil { return err } // User stop proxy, we should close the socket. go func() { <-ctx.Done() _ = vnetSocket.Close() }() // Got new vnet client, start a new endpoint. findEndpointBy := func(addr net.Addr) (*net.UDPConn, error) { // Exists binding. if value, ok := v.endpoints.Load(addr.String()); ok { // Exists endpoint, reuse it. return value.(*net.UDPConn), nil //nolint:forcetypeassert } // The real server we proxy to, for utest to mock it. realAddr := serverAddr if v.mockRealServerAddr != nil { realAddr = v.mockRealServerAddr } // Got new vnet client, create new endpoint. realSocket, err := net.DialUDP("udp4", nil, realAddr) if err != nil { return nil, err } // User stop proxy, we should close the socket. go func() { <-ctx.Done() _ = realSocket.Close() }() // Bind address. v.endpoints.Store(addr.String(), realSocket) // Got packet from real serverAddr, we should proxy it to vnet. v.wg.Add(1) go func(vnetClientAddr net.Addr) { defer v.wg.Done() buf := make([]byte, 1500) for { n, _, err := realSocket.ReadFrom(buf) if err != nil { return } if n <= 0 { continue // Drop packet } if _, err := vnetSocket.WriteTo(buf[:n], vnetClientAddr); err != nil { return } } }(addr) return realSocket, nil } // Start a proxy goroutine. v.wg.Add(1) go func() { defer v.wg.Done() buf := make([]byte, 1500) for { n, addr, err := vnetSocket.ReadFrom(buf) if err != nil { return } if n <= 0 || addr == nil { continue // Drop packet } realSocket, err := findEndpointBy(addr) if err != nil { continue // Drop packet. } if _, err := realSocket.Write(buf[:n]); err != nil { return } } }() return nil } golang-github-pion-transport-v3-3.0.8/vnet/udpproxy_direct.go000066400000000000000000000024331507057301700243320ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package vnet import ( "fmt" "net" ) // Deliver directly send packet to vnet or real-server. 
// For example, we can use this API to simulate the REPLAY ATTACK. func (v *UDPProxy) Deliver(sourceAddr, destAddr net.Addr, b []byte) (nn int, err error) { v.workers.Range(func(_, value any) bool { worker, ok := value.(*aUDPProxyWorker) if !ok { return false } if nn, err = worker.Deliver(sourceAddr, destAddr, b); err != nil { return false // Fail, abort. } else if nn == len(b) { return false // Done. } return true // Deliver by next worker. }) return } func (v *aUDPProxyWorker) Deliver(sourceAddr, _ net.Addr, b []byte) (nn int, err error) { addr, ok := sourceAddr.(*net.UDPAddr) if !ok { return 0, fmt.Errorf("invalid addr %v", sourceAddr) // nolint:err113 } // nolint:godox // TODO: Support deliver packet from real server to vnet. // If packet is from vnet, proxy to real server. var realSocket *net.UDPConn value, ok := v.endpoints.Load(addr.String()) if !ok { return 0, nil } realSocket = value.(*net.UDPConn) // nolint:forcetypeassert // Send to real server. if _, err := realSocket.Write(b); err != nil { return 0, err } return len(b), nil } golang-github-pion-transport-v3-3.0.8/vnet/udpproxy_direct_test.go000066400000000000000000000206151507057301700253730ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT //go:build !wasm // +build !wasm package vnet import ( "context" "errors" "fmt" "net" "sync" "testing" "time" "github.com/pion/logging" "github.com/stretchr/testify/assert" ) // The vnet client: // // 10.0.0.11:5787 // // which proxy to real server: // // 192.168.1.10:8000 // // We should get a reply if directly deliver to proxy. 
func TestUDPProxyDirectDeliverTypical(t *testing.T) { //nolint:cyclop ctx, cancel := context.WithCancel(context.Background()) var r0, r1, r2 error defer func() { assert.NoErrorf(t, r0, "fail for ctx=%v, r0=%v", ctx.Err(), r0) assert.NoErrorf(t, r1, "fail for ctx=%v, r1=%v", ctx.Err(), r1) assert.NoErrorf(t, r2, "fail for ctx=%v, r2=%v", ctx.Err(), r2) }() var wg sync.WaitGroup defer wg.Wait() // Timeout, fail wg.Add(1) go func() { defer wg.Done() defer cancel() select { case <-ctx.Done(): case <-time.After(time.Duration(*testTimeout) * time.Millisecond): r2 = fmt.Errorf("timeout") // nolint:err113 } }() // For utest, we always proxy vnet packets to the random port we listen to. mockServer := NewMockUDPEchoServer() wg.Add(1) go func() { defer wg.Done() defer cancel() if err := mockServer.doMockUDPServer(ctx); err != nil { r0 = err } }() // Create a vent and proxy. wg.Add(1) go func() { defer wg.Done() defer cancel() // When real server is ready, start the vnet test. select { case <-ctx.Done(): return case <-mockServer.realServerReady.Done(): } doVnetProxy := func() error { router, err := NewRouter(&RouterConfig{ CIDR: "0.0.0.0/0", LoggerFactory: logging.NewDefaultLoggerFactory(), }) if err != nil { return err } clientNetwork, err := NewNet(&NetConfig{ StaticIP: "10.0.0.11", }) if err != nil { return err } if err = router.AddNet(clientNetwork); err != nil { return err } if err = router.Start(); err != nil { return err } defer router.Stop() // nolint:errcheck proxy, err := NewProxy(router) if err != nil { return err } defer proxy.Close() // nolint:errcheck // For utest, mock the target real server. proxy.mockRealServerAddr = mockServer.realServerAddr // The real server address to proxy to. // Note that for utest, we will proxy to a local address. 
serverAddr, err := net.ResolveUDPAddr("udp4", "192.168.1.10:8000") if err != nil { return err } if err = proxy.Proxy(clientNetwork, serverAddr); err != nil { //nolint:contextcheck return err } // Now, all packets from client, will be proxy to real server, vice versa. client, err := clientNetwork.ListenPacket("udp4", "10.0.0.11:5787") if err != nil { return err } // When system quit, interrupt client. selfKill, selfKillCancel := context.WithCancel(context.Background()) go func() { <-ctx.Done() selfKillCancel() _ = client.Close() }() // Write by vnet client. if _, err = client.WriteTo([]byte("Hello"), serverAddr); err != nil { return err } buf := make([]byte, 1500) if n, addr, err := client.ReadFrom(buf); err != nil { // nolint:gocritic,govet if errors.Is(selfKill.Err(), context.Canceled) { return nil } return err } else if n != 5 || addr == nil { return fmt.Errorf("n=%v, addr=%v", n, addr) // nolint:err113 } else if string(buf[:n]) != "Hello" { // nolint:goconst return fmt.Errorf("data %v", buf[:n]) // nolint:err113 } // Directly write, simulate the ARQ packet. // We should got the echo packet also. if _, err = proxy.Deliver(client.LocalAddr(), serverAddr, []byte("Hello")); err != nil { return err } if n, addr, err := client.ReadFrom(buf); err != nil { // nolint:gocritic,govet if errors.Is(selfKill.Err(), context.Canceled) { return nil } return err } else if n != 5 || addr == nil { return fmt.Errorf("n=%v, addr=%v", n, addr) // nolint:err113 } else if string(buf[:n]) != "Hello" { return fmt.Errorf("data %v", buf[:n]) // nolint:err113 } return err } if err := doVnetProxy(); err != nil { r1 = err } }() } // Error if deliver to invalid address. 
func TestUDPProxyDirectDeliverBadCase(t *testing.T) { //nolint:cyclop ctx, cancel := context.WithCancel(context.Background()) var r0, r1, r2 error defer func() { assert.NoErrorf(t, r0, "fail for ctx=%v, r0=%v", ctx.Err(), r0) assert.NoErrorf(t, r1, "fail for ctx=%v, r1=%v", ctx.Err(), r1) assert.NoErrorf(t, r2, "fail for ctx=%v, r2=%v", ctx.Err(), r2) }() var wg sync.WaitGroup defer wg.Wait() // Timeout, fail wg.Add(1) go func() { defer wg.Done() defer cancel() select { case <-ctx.Done(): case <-time.After(time.Duration(*testTimeout) * time.Millisecond): r2 = fmt.Errorf("timeout") // nolint:err113 } }() // For utest, we always proxy vnet packets to the random port we listen to. mockServer := NewMockUDPEchoServer() wg.Add(1) go func() { defer wg.Done() defer cancel() if err := mockServer.doMockUDPServer(ctx); err != nil { r0 = err } }() // Create a vent and proxy. wg.Add(1) go func() { defer wg.Done() defer cancel() // When real server is ready, start the vnet test. select { case <-ctx.Done(): return case <-mockServer.realServerReady.Done(): } doVnetProxy := func() error { router, err := NewRouter(&RouterConfig{ CIDR: "0.0.0.0/0", LoggerFactory: logging.NewDefaultLoggerFactory(), }) if err != nil { return err } clientNetwork, err := NewNet(&NetConfig{ StaticIP: "10.0.0.11", }) if err != nil { return err } if err = router.AddNet(clientNetwork); err != nil { return err } if err = router.Start(); err != nil { return err } defer router.Stop() // nolint:errcheck proxy, err := NewProxy(router) if err != nil { return err } defer proxy.Close() // nolint:errcheck // For utest, mock the target real server. proxy.mockRealServerAddr = mockServer.realServerAddr // The real server address to proxy to. // Note that for utest, we will proxy to a local address. 
serverAddr, err := net.ResolveUDPAddr("udp4", "192.168.1.10:8000") if err != nil { return err } if err = proxy.Proxy(clientNetwork, serverAddr); err != nil { //nolint:contextcheck return err } // Now, all packets from client, will be proxy to real server, vice versa. client, err := clientNetwork.ListenPacket("udp4", "10.0.0.11:5787") if err != nil { return err } // When system quit, interrupt client. selfKill, selfKillCancel := context.WithCancel(context.Background()) go func() { <-ctx.Done() selfKillCancel() _ = client.Close() }() // Write by vnet client. if _, err = client.WriteTo([]byte("Hello"), serverAddr); err != nil { return err } buf := make([]byte, 1500) if n, addr, err := client.ReadFrom(buf); err != nil { // nolint:gocritic,govet if errors.Is(selfKill.Err(), context.Canceled) { return nil } return err } else if n != 5 || addr == nil { return fmt.Errorf("n=%v, addr=%v", n, addr) // nolint:err113 } else if string(buf[:n]) != "Hello" { // nolint:goconst return fmt.Errorf("data %v", buf[:n]) // nolint:err113 } // BadCase: Invalid address, error and ignore. tcpAddr, err := net.ResolveTCPAddr("tcp4", "192.168.1.10:8000") if err != nil { return err } if _, err = proxy.Deliver(tcpAddr, serverAddr, []byte("Hello")); err == nil { return fmt.Errorf("should err") // nolint:err113 } // BadCase: Invalid target address, ignore. udpAddr, err := net.ResolveUDPAddr("udp4", "10.0.0.12:5788") if err != nil { return err } if nn, err := proxy.Deliver(udpAddr, serverAddr, []byte("Hello")); err != nil { // nolint:govet return err } else if nn != 0 { return fmt.Errorf("invalid %v", nn) // nolint:err113 } // BadCase: Write on closed socket, error and ignore. 
proxy.workers.Range(func(_, value any) bool { //nolint:forcetypeassert value.(*aUDPProxyWorker).endpoints.Range(func(_, value any) bool { _ = value.(*net.UDPConn).Close() //nolint:forcetypeassert return true }) return true }) if _, err = proxy.Deliver(client.LocalAddr(), serverAddr, []byte("Hello")); err == nil { return fmt.Errorf("should error") // nolint:err113 } return nil } if err := doVnetProxy(); err != nil { r1 = err } }() } golang-github-pion-transport-v3-3.0.8/vnet/udpproxy_test.go000066400000000000000000000355271507057301700240510ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT //go:build !wasm // +build !wasm package vnet import ( "context" "errors" "flag" "fmt" "net" "os" "sync" "testing" "time" "github.com/pion/logging" "github.com/stretchr/testify/assert" ) type MockUDPEchoServer struct { realServerAddr *net.UDPAddr realServerReady context.Context //nolint:containedctx // this is a test context... realServerReadyCancel context.CancelFunc } func NewMockUDPEchoServer() *MockUDPEchoServer { v := &MockUDPEchoServer{} v.realServerReady, v.realServerReadyCancel = context.WithCancel(context.Background()) return v } func (v *MockUDPEchoServer) doMockUDPServer(ctx context.Context) error { //nolint:cyclop // Listen to a random port. laddr, err := net.ResolveUDPAddr("udp4", "127.0.0.1:0") if err != nil { return err } conn, err := net.ListenUDP("udp4", laddr) if err != nil { return err } v.realServerAddr = conn.LocalAddr().(*net.UDPAddr) //nolint:forcetypeassert v.realServerReadyCancel() // When system quit, interrupt client. selfKill, selfKillCancel := context.WithCancel(context.Background()) go func() { <-ctx.Done() selfKillCancel() _ = conn.Close() }() // Note that if they has the same ID, the address should not changed. addrs := make(map[string]net.Addr) // Start an echo UDP server. 
buf := make([]byte, 1500) for ctx.Err() == nil { n, addr, err := conn.ReadFrom(buf) if err != nil { if errors.Is(selfKill.Err(), context.Canceled) { return nil } return err } else if n == 0 || addr == nil { return fmt.Errorf("n=%v, addr=%v", n, addr) // nolint:err113 } else if nn, err := conn.WriteTo(buf[:n], addr); err != nil { return err } else if nn != n { return fmt.Errorf("nn=%v, n=%v", nn, n) // nolint:err113 } // Check the address, should not change, use content as ID. clientID := string(buf[:n]) if oldAddr, ok := addrs[clientID]; ok && oldAddr.String() != addr.String() { return fmt.Errorf("address change %v to %v", oldAddr.String(), addr.String()) // nolint:err113 } addrs[clientID] = addr } return nil } var testTimeout = flag.Int("timeout", 5000, "For each case, the timeout in ms") // nolint:gochecknoglobals func TestMain(m *testing.M) { flag.Parse() os.Exit(m.Run()) //nolint:forbidigo } // vnet client: // // 10.0.0.11:5787 // // proxy to real server: // // 192.168.1.10:8000 // // . func TestUDPProxyOne2One(t *testing.T) { //nolint:gocyclo,cyclop ctx, cancel := context.WithCancel(context.Background()) var r0, r1, r2 error defer func() { assert.NoErrorf(t, r0, "fail for ctx=%v, r0=%v", ctx.Err(), r0) assert.NoErrorf(t, r1, "fail for ctx=%v, r1=%v", ctx.Err(), r1) assert.NoErrorf(t, r2, "fail for ctx=%v, r2=%v", ctx.Err(), r2) }() var wg sync.WaitGroup defer wg.Wait() // Timeout, fail wg.Add(1) go func() { defer wg.Done() defer cancel() select { case <-ctx.Done(): case <-time.After(time.Duration(*testTimeout) * time.Millisecond): r2 = fmt.Errorf("timeout") // nolint:err113 } }() // For utest, we always proxy vnet packets to the random port we listen to. mockServer := NewMockUDPEchoServer() wg.Add(1) go func() { defer wg.Done() defer cancel() if err := mockServer.doMockUDPServer(ctx); err != nil { r0 = err } }() // Create a vent and proxy. wg.Add(1) go func() { defer wg.Done() defer cancel() // When real server is ready, start the vnet test. 
select { case <-ctx.Done(): return case <-mockServer.realServerReady.Done(): } doVnetProxy := func() error { router, err := NewRouter(&RouterConfig{ CIDR: "0.0.0.0/0", LoggerFactory: logging.NewDefaultLoggerFactory(), }) if err != nil { return err } clientNetwork, err := NewNet(&NetConfig{ StaticIP: "10.0.0.11", }) if err != nil { return err } if err = router.AddNet(clientNetwork); err != nil { return err } if err = router.Start(); err != nil { return err } defer router.Stop() // nolint:errcheck proxy, err := NewProxy(router) if err != nil { return err } defer proxy.Close() // nolint:errcheck // For utest, mock the target real server. proxy.mockRealServerAddr = mockServer.realServerAddr // The real server address to proxy to. // Note that for utest, we will proxy to a local address. serverAddr, err := net.ResolveUDPAddr("udp4", "192.168.1.10:8000") if err != nil { return err } if err = proxy.Proxy(clientNetwork, serverAddr); err != nil { //nolint:contextcheck return err } // Now, all packets from client, will be proxy to real server, vice versa. client, err := clientNetwork.ListenPacket("udp4", "10.0.0.11:5787") if err != nil { return err } // When system quit, interrupt client. selfKill, selfKillCancel := context.WithCancel(context.Background()) go func() { <-ctx.Done() selfKillCancel() _ = client.Close() // nolint:errcheck }() for i := 0; i < 10; i++ { if _, err = client.WriteTo([]byte("Hello"), serverAddr); err != nil { return err } var n int var addr net.Addr buf := make([]byte, 1500) if n, addr, err = client.ReadFrom(buf); err != nil { // nolint:gocritic if errors.Is(selfKill.Err(), context.Canceled) { return nil } return err } else if n != 5 || addr == nil { return fmt.Errorf("n=%v, addr=%v", n, addr) // nolint:err113 } else if string(buf[:n]) != "Hello" { return fmt.Errorf("data %v", buf[:n]) // nolint:err113 } // Wait for awhile for each UDP packet, to simulate real network. 
select { case <-ctx.Done(): return nil case <-time.After(30 * time.Millisecond): } } return err } if err := doVnetProxy(); err != nil { r1 = err } }() } // vnet client: // // 10.0.0.11:5787 // 10.0.0.11:5788 // // proxy to real server: // // 192.168.1.10:8000 func TestUDPProxyTwo2One(t *testing.T) { //nolint:gocyclo,cyclop ctx, cancel := context.WithCancel(context.Background()) var r0, r1, r2, r3 error defer func() { assert.NoErrorf(t, r0, "fail for ctx=%v, r0=%v", ctx.Err(), r0) assert.NoErrorf(t, r1, "fail for ctx=%v, r1=%v", ctx.Err(), r1) assert.NoErrorf(t, r2, "fail for ctx=%v, r2=%v", ctx.Err(), r2) assert.NoErrorf(t, r3, "fail for ctx=%v, r3=%v", ctx.Err(), r3) }() var wg sync.WaitGroup defer wg.Wait() // Timeout, fail wg.Add(1) go func() { defer wg.Done() defer cancel() select { case <-ctx.Done(): case <-time.After(time.Duration(*testTimeout) * time.Millisecond): r2 = fmt.Errorf("timeout") // nolint:err113 } }() // For utest, we always proxy vnet packets to the random port we listen to. mockServer := NewMockUDPEchoServer() wg.Add(1) go func() { defer wg.Done() defer cancel() if err := mockServer.doMockUDPServer(ctx); err != nil { r0 = err } }() // Create a vent and proxy. wg.Add(1) go func() { defer wg.Done() defer cancel() // When real server is ready, start the vnet test. select { case <-ctx.Done(): return case <-mockServer.realServerReady.Done(): } doVnetProxy := func() error { router, err := NewRouter(&RouterConfig{ CIDR: "0.0.0.0/0", LoggerFactory: logging.NewDefaultLoggerFactory(), }) if err != nil { return err } clientNetwork, err := NewNet(&NetConfig{ StaticIP: "10.0.0.11", }) if err != nil { return err } if err = router.AddNet(clientNetwork); err != nil { return err } if err = router.Start(); err != nil { return err } defer router.Stop() // nolint:errcheck proxy, err := NewProxy(router) if err != nil { return err } defer proxy.Close() // nolint:errcheck // For utest, mock the target real server. 
proxy.mockRealServerAddr = mockServer.realServerAddr // The real server address to proxy to. // Note that for utest, we will proxy to a local address. serverAddr, err := net.ResolveUDPAddr("udp4", "192.168.1.10:8000") if err != nil { return err } if err = proxy.Proxy(clientNetwork, serverAddr); err != nil { //nolint:contextcheck return err } handClient := func(address, echoData string) error { // Now, all packets from client, will be proxy to real server, vice versa. client, err := clientNetwork.ListenPacket("udp4", address) // nolint:govet if err != nil { return err } // When system quit, interrupt client. selfKill, selfKillCancel := context.WithCancel(context.Background()) go func() { <-ctx.Done() selfKillCancel() _ = client.Close() }() for i := 0; i < 10; i++ { if _, err := client.WriteTo([]byte(echoData), serverAddr); err != nil { // nolint:govet return err } var n int var addr net.Addr buf := make([]byte, 1400) if n, addr, err = client.ReadFrom(buf); err != nil { // nolint:gocritic if errors.Is(selfKill.Err(), context.Canceled) { return nil } return err } else if n != len(echoData) || addr == nil { return fmt.Errorf("n=%v, addr=%v", n, addr) // nolint:err113 } else if string(buf[:n]) != echoData { return fmt.Errorf("check data %v", buf[:n]) // nolint:err113 } // Wait for awhile for each UDP packet, to simulate real network. 
select { case <-ctx.Done(): return nil case <-time.After(30 * time.Millisecond): } } return nil } client0, client0Cancel := context.WithCancel(context.Background()) go func() { defer client0Cancel() address := "10.0.0.11:5787" if handClientErr := handClient(address, "Hello"); handClientErr != nil { r3 = fmt.Errorf("client %v err %w", address, handClientErr) } }() client1, client1Cancel := context.WithCancel(context.Background()) go func() { defer client1Cancel() address := "10.0.0.11:5788" if handClientErr := handClient(address, "World"); handClientErr != nil { r3 = fmt.Errorf("client %v err %w", address, handClientErr) } }() select { case <-ctx.Done(): case <-client0.Done(): case <-client1.Done(): } return err } if err := doVnetProxy(); err != nil { r1 = err } }() } // vnet client: // // 10.0.0.11:5787 // // proxy to real server: // // 192.168.1.10:8000 // // vnet client: // // 10.0.0.11:5788 // // proxy to real server: // // 192.168.1.10:8000 func TestUDPProxyProxyTwice(t *testing.T) { //nolint:gocyclo,cyclop ctx, cancel := context.WithCancel(context.Background()) var r0, r1, r2, r3 error defer func() { assert.NoErrorf(t, r0, "fail for ctx=%v, r0=%v", ctx.Err(), r0) assert.NoErrorf(t, r1, "fail for ctx=%v, r1=%v", ctx.Err(), r1) assert.NoErrorf(t, r2, "fail for ctx=%v, r2=%v", ctx.Err(), r2) assert.NoErrorf(t, r3, "fail for ctx=%v, r3=%v", ctx.Err(), r3) }() var wg sync.WaitGroup defer wg.Wait() // Timeout, fail wg.Add(1) go func() { defer wg.Done() defer cancel() select { case <-ctx.Done(): case <-time.After(time.Duration(*testTimeout) * time.Millisecond): r2 = fmt.Errorf("timeout") // nolint:err113 } }() // For utest, we always proxy vnet packets to the random port we listen to. mockServer := NewMockUDPEchoServer() wg.Add(1) go func() { defer wg.Done() defer cancel() if err := mockServer.doMockUDPServer(ctx); err != nil { r0 = err } }() // Create a vent and proxy. 
wg.Add(1) go func() { defer wg.Done() defer cancel() // When real server is ready, start the vnet test. select { case <-ctx.Done(): return case <-mockServer.realServerReady.Done(): } doVnetProxy := func() error { router, err := NewRouter(&RouterConfig{ CIDR: "0.0.0.0/0", LoggerFactory: logging.NewDefaultLoggerFactory(), }) if err != nil { return err } clientNetwork, err := NewNet(&NetConfig{ StaticIP: "10.0.0.11", }) if err != nil { return err } if err = router.AddNet(clientNetwork); err != nil { return err } if err = router.Start(); err != nil { return err } defer router.Stop() // nolint:errcheck proxy, err := NewProxy(router) if err != nil { return err } defer proxy.Close() // nolint:errcheck // For utest, mock the target real server. proxy.mockRealServerAddr = mockServer.realServerAddr // The real server address to proxy to. // Note that for utest, we will proxy to a local address. serverAddr, err := net.ResolveUDPAddr("udp4", "192.168.1.10:8000") if err != nil { return err } handClient := func(address, echoData string) error { // We proxy multiple times, for example, in publisher and player, both call // the proxy when got answer. if handClientErr := proxy.Proxy(clientNetwork, serverAddr); handClientErr != nil { //nolint:contextcheck return handClientErr } // Now, all packets from client, will be proxy to real server, vice versa. client, handClientErr := clientNetwork.ListenPacket("udp4", address) // nolint:govet if handClientErr != nil { return handClientErr } // When system quit, interrupt client. 
selfKill, selfKillCancel := context.WithCancel(context.Background()) go func() { <-ctx.Done() selfKillCancel() _ = client.Close() // nolint:errcheck }() for i := 0; i < 10; i++ { if _, handClientErr = client.WriteTo([]byte(echoData), serverAddr); handClientErr != nil { return handClientErr } buf := make([]byte, 1500) if n, addr, handClientErr := client.ReadFrom(buf); handClientErr != nil { // nolint:gocritic,govet if errors.Is(selfKill.Err(), context.Canceled) { return nil } return handClientErr } else if n != len(echoData) || addr == nil { return fmt.Errorf("n=%v, addr=%v", n, addr) // nolint:err113 } else if string(buf[:n]) != echoData { return fmt.Errorf("verify data %v", buf[:n]) // nolint:err113 } // Wait for awhile for each UDP packet, to simulate real network. select { case <-ctx.Done(): return nil case <-time.After(30 * time.Millisecond): } } return nil } client0, client0Cancel := context.WithCancel(context.Background()) go func() { defer client0Cancel() address := "10.0.0.11:5787" if err = handClient(address, "Hello"); err != nil { r3 = fmt.Errorf("client %v err %w", address, err) } }() client1, client1Cancel := context.WithCancel(context.Background()) go func() { defer client1Cancel() // Slower than client0, 60ms. // To simulate the real player or publisher, might not start at the same time. select { case <-ctx.Done(): return case <-time.After(150 * time.Millisecond): } address := "10.0.0.11:5788" if err = handClient(address, "World"); err != nil { r3 = fmt.Errorf("client %v err %w", address, err) // nolint:err113 } }() select { case <-ctx.Done(): case <-client0.Done(): case <-client1.Done(): } return err } if err := doVnetProxy(); err != nil { r1 = err } }() } golang-github-pion-transport-v3-3.0.8/vnet/vnet.go000066400000000000000000000002551507057301700220620ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT // Package vnet provides a virtual network layer for pion package vnet