pax_global_header00006660000000000000000000000064150705746030014521gustar00rootroot0000000000000052 comment=b855a5eebc2ae03835354d803bedc5983813a815 golang-github-pion-dtls-v3-3.0.7/000077500000000000000000000000001507057460300165145ustar00rootroot00000000000000golang-github-pion-dtls-v3-3.0.7/.editorconfig000066400000000000000000000006321507057460300211720ustar00rootroot00000000000000# http://editorconfig.org/ # SPDX-FileCopyrightText: 2023 The Pion community # SPDX-License-Identifier: MIT root = true [*] charset = utf-8 insert_final_newline = true trim_trailing_whitespace = true end_of_line = lf [*.go] indent_style = tab indent_size = 4 [{*.yml,*.yaml}] indent_style = space indent_size = 2 # Makefiles always use tabs for indentation [Makefile] indent_style = tab golang-github-pion-dtls-v3-3.0.7/.gitignore000066400000000000000000000006321507057460300205050ustar00rootroot00000000000000# SPDX-FileCopyrightText: 2023 The Pion community # SPDX-License-Identifier: MIT ### JetBrains IDE ### ##################### .idea/ ### Emacs Temporary Files ### ############################# *~ ### Folders ### ############### bin/ vendor/ node_modules/ ### Files ### ############# *.ivf *.ogg tags cover.out *.sw[poe] *.wasm examples/sfu-ws/cert.pem examples/sfu-ws/key.pem wasm_exec.js golang-github-pion-dtls-v3-3.0.7/.golangci.yml000066400000000000000000000205631507057460300211060ustar00rootroot00000000000000# SPDX-FileCopyrightText: 2023 The Pion community # SPDX-License-Identifier: MIT run: timeout: 5m linters-settings: govet: enable: - shadow misspell: locale: US exhaustive: default-signifies-exhaustive: true gomodguard: blocked: modules: - github.com/pkg/errors: recommendations: - errors forbidigo: analyze-types: true forbid: - ^fmt.Print(f|ln)?$ - ^log.(Panic|Fatal|Print)(f|ln)?$ - ^os.Exit$ - ^panic$ - ^print(ln)?$ - p: ^testing.T.(Error|Errorf|Fatal|Fatalf|Fail|FailNow)$ pkg: ^testing$ msg: "use testify/assert instead" varnamelen: max-distance: 12 min-name-length: 2 
ignore-type-assert-ok: true ignore-map-index-ok: true ignore-chan-recv-ok: true ignore-decls: - i int - n int - w io.Writer - r io.Reader - b []byte revive: rules: # Prefer 'any' type alias over 'interface{}' for Go 1.18+ compatibility - name: use-any severity: warning disabled: false linters: enable: - asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers - bidichk # Checks for dangerous unicode character sequences - bodyclose # checks whether HTTP response body is closed successfully - containedctx # containedctx is a linter that detects struct contained context.Context field - contextcheck # check the function whether use a non-inherited context - cyclop # checks function and package cyclomatic complexity - decorder # check declaration order and count of types, constants, variables and functions - dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f()) - dupl # Tool for code clone detection - durationcheck # check for two durations multiplied together - err113 # Golang linter to check the errors handling expressions - errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases - errchkjson # Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occations, where the check for the returned error can be omitted. - errname # Checks that sentinel errors are prefixed with the `Err` and error types are suffixed with the `Error`. - errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13. - exhaustive # check exhaustiveness of enum switch statements - exportloopref # checks for pointers to enclosing loop variables - forbidigo # Forbids identifiers - forcetypeassert # finds forced type assertions - gci # Gci control golang package import order and make it always deterministic. 
- gochecknoglobals # Checks that no globals are present in Go code - gocognit # Computes and checks the cognitive complexity of functions - goconst # Finds repeated strings that could be replaced by a constant - gocritic # The most opinionated Go source code linter - gocyclo # Computes and checks the cyclomatic complexity of functions - godot # Check if comments end in a period - godox # Tool for detection of FIXME, TODO and other comment keywords - gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification - gofumpt # Gofumpt checks whether code was gofumpt-ed. - goheader # Checks is file header matches to pattern - goimports # Goimports does everything that gofmt does. Additionally it checks unused imports - gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod. - goprintffuncname # Checks that printf-like functions are named with `f` at the end - gosec # Inspects source code for security problems - gosimple # Linter for Go source code that specializes in simplifying a code - govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string - grouper # An analyzer to analyze expression groups. - importas # Enforces consistent import aliases - ineffassign # Detects when assignments to existing variables are not used - lll # Reports long lines - maintidx # maintidx measures the maintainability index of each function. - makezero # Finds slice declarations with non-zero initial length - misspell # Finds commonly misspelled English words in comments - nakedret # Finds naked returns in functions greater than a specified function length - nestif # Reports deeply nested if statements - nilerr # Finds the code that returns nil even if it checks that the error is not nil. - nilnil # Checks that there is no simultaneous return of `nil` error and an invalid value. 
- nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity - noctx # noctx finds sending http request without context.Context - predeclared # find code that shadows one of Go's predeclared identifiers - revive # golint replacement, finds style mistakes - staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks - stylecheck # Stylecheck is a replacement for golint - tagliatelle # Checks the struct tags. - tenv # tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17 - thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers - typecheck # Like the front-end of a Go compiler, parses and type-checks Go code - unconvert # Remove unnecessary type conversions - unparam # Reports unused function parameters - unused # Checks Go code for unused constants, variables, functions and types - varnamelen # checks that the length of a variable's name matches its scope - wastedassign # wastedassign finds wasted assignment statements - whitespace # Tool for detection of leading and trailing whitespace disable: - depguard # Go linter that checks if package imports are in a list of acceptable packages - funlen # Tool for detection of long functions - gochecknoinits # Checks that no init functions are present in Go code - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations. - interfacebloat # A linter that checks length of interface. 
- ireturn # Accept Interfaces, Return Concrete Types - mnd # An analyzer to detect magic numbers - nolintlint # Reports ill-formed or insufficient nolint directives - paralleltest # paralleltest detects missing usage of t.Parallel() method in your Go test - prealloc # Finds slice declarations that could potentially be preallocated - promlinter # Check Prometheus metrics naming via promlint - rowserrcheck # checks whether Err of rows is checked successfully - sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed. - testpackage # linter that makes you use a separate _test package - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes - wrapcheck # Checks that errors returned from external packages are wrapped - wsl # Whitespace Linter - Forces you to use empty lines! issues: exclude-use-default: false exclude-dirs-use-default: false exclude-rules: # Allow complex tests and examples, better to be self contained - path: (examples|main\.go) linters: - gocognit - forbidigo - path: _test\.go linters: - gocognit # Allow forbidden identifiers in CLI commands - path: cmd linters: - forbidigo golang-github-pion-dtls-v3-3.0.7/.goreleaser.yml000066400000000000000000000001711507057460300214440ustar00rootroot00000000000000# SPDX-FileCopyrightText: 2023 The Pion community # SPDX-License-Identifier: MIT builds: - skip: true golang-github-pion-dtls-v3-3.0.7/.reuse/000077500000000000000000000000001507057460300177155ustar00rootroot00000000000000golang-github-pion-dtls-v3-3.0.7/.reuse/dep5000066400000000000000000000011141507057460300204720ustar00rootroot00000000000000Format: https://www.debian.org/doc/packaging-manuals/copyright-format/1.0/ Upstream-Name: Pion Source: https://github.com/pion/ Files: README.md DESIGN.md **/README.md AUTHORS.txt renovate.json go.mod go.sum **/go.mod **/go.sum .eslintrc.json package.json examples.json sfu-ws/flutter/.gitignore sfu-ws/flutter/pubspec.yaml c-data-channels/webrtc.h examples/examples.json 
yarn.lock Copyright: 2023 The Pion community License: MIT Files: testdata/seed/* testdata/fuzz/* **/testdata/fuzz/* api/*.txt Copyright: 2023 The Pion community License: CC0-1.0 golang-github-pion-dtls-v3-3.0.7/LICENSE000066400000000000000000000021051507057460300175170ustar00rootroot00000000000000MIT License Copyright (c) 2023 The Pion community Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. golang-github-pion-dtls-v3-3.0.7/LICENSES/000077500000000000000000000000001507057460300177215ustar00rootroot00000000000000golang-github-pion-dtls-v3-3.0.7/LICENSES/CC0-1.0.txt000066400000000000000000000156101507057460300213260ustar00rootroot00000000000000Creative Commons Legal Code CC0 1.0 Universal CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE LEGAL SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS INFORMATION ON AN "AS-IS" BASIS. 
CREATIVE COMMONS MAKES NO WARRANTIES REGARDING THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED HEREUNDER, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED HEREUNDER. Statement of Purpose The laws of most jurisdictions throughout the world automatically confer exclusive Copyright and Related Rights (defined below) upon the creator and subsequent owner(s) (each and all, an "owner") of an original work of authorship and/or a database (each, a "Work"). Certain owners wish to permanently relinquish those rights to a Work for the purpose of contributing to a commons of creative, cultural and scientific works ("Commons") that the public can reliably and without fear of later claims of infringement build upon, modify, incorporate in other works, reuse and redistribute as freely as possible in any form whatsoever and for any purposes, including without limitation commercial purposes. These owners may contribute to the Commons to promote the ideal of a free culture and the further production of creative, cultural and scientific works, or to gain reputation or greater distribution for their Work in part through the use and efforts of others. For these and/or other purposes and motivations, and without any expectation of additional consideration or compensation, the person associating CC0 with a Work (the "Affirmer"), to the extent that he or she is an owner of Copyright and Related Rights in the Work, voluntarily elects to apply CC0 to the Work and publicly distribute the Work under its terms, with knowledge of his or her Copyright and Related Rights in the Work and the meaning and intended legal effect of CC0 on those rights. 1. Copyright and Related Rights. A Work made available under CC0 may be protected by copyright and related or neighboring rights ("Copyright and Related Rights"). Copyright and Related Rights include, but are not limited to, the following: i. 
the right to reproduce, adapt, distribute, perform, display, communicate, and translate a Work; ii. moral rights retained by the original author(s) and/or performer(s); iii. publicity and privacy rights pertaining to a person's image or likeness depicted in a Work; iv. rights protecting against unfair competition in regards to a Work, subject to the limitations in paragraph 4(a), below; v. rights protecting the extraction, dissemination, use and reuse of data in a Work; vi. database rights (such as those arising under Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, and under any national implementation thereof, including any amended or successor version of such directive); and vii. other similar, equivalent or corresponding rights throughout the world based on applicable law or treaty, and any national implementations thereof. 2. Waiver. To the greatest extent permitted by, but not in contravention of, applicable law, Affirmer hereby overtly, fully, permanently, irrevocably and unconditionally waives, abandons, and surrenders all of Affirmer's Copyright and Related Rights and associated claims and causes of action, whether now known or unknown (including existing as well as future claims and causes of action), in the Work (i) in all territories worldwide, (ii) for the maximum duration provided by applicable law or treaty (including future time extensions), (iii) in any current or future medium and for any number of copies, and (iv) for any purpose whatsoever, including without limitation commercial, advertising or promotional purposes (the "Waiver"). 
Affirmer makes the Waiver for the benefit of each member of the public at large and to the detriment of Affirmer's heirs and successors, fully intending that such Waiver shall not be subject to revocation, rescission, cancellation, termination, or any other legal or equitable action to disrupt the quiet enjoyment of the Work by the public as contemplated by Affirmer's express Statement of Purpose. 3. Public License Fallback. Should any part of the Waiver for any reason be judged legally invalid or ineffective under applicable law, then the Waiver shall be preserved to the maximum extent permitted taking into account Affirmer's express Statement of Purpose. In addition, to the extent the Waiver is so judged Affirmer hereby grants to each affected person a royalty-free, non transferable, non sublicensable, non exclusive, irrevocable and unconditional license to exercise Affirmer's Copyright and Related Rights in the Work (i) in all territories worldwide, (ii) for the maximum duration provided by applicable law or treaty (including future time extensions), (iii) in any current or future medium and for any number of copies, and (iv) for any purpose whatsoever, including without limitation commercial, advertising or promotional purposes (the "License"). The License shall be deemed effective as of the date CC0 was applied by Affirmer to the Work. Should any part of the License for any reason be judged legally invalid or ineffective under applicable law, such partial invalidity or ineffectiveness shall not invalidate the remainder of the License, and in such case Affirmer hereby affirms that he or she will not (i) exercise any of his or her remaining Copyright and Related Rights in the Work or (ii) assert any associated claims and causes of action with respect to the Work, in either case contrary to Affirmer's express Statement of Purpose. 4. Limitations and Disclaimers. a. 
No trademark or patent rights held by Affirmer are waived, abandoned, surrendered, licensed or otherwise affected by this document. b. Affirmer offers the Work as-is and makes no representations or warranties of any kind concerning the Work, express, implied, statutory or otherwise, including without limitation warranties of title, merchantability, fitness for a particular purpose, non infringement, or the absence of latent or other defects, accuracy, or the present or absence of errors, whether or not discoverable, all to the greatest extent permissible under applicable law. c. Affirmer disclaims responsibility for clearing rights of other persons that may apply to the Work or any use thereof, including without limitation any person's Copyright and Related Rights in the Work. Further, Affirmer disclaims responsibility for obtaining any necessary consents, permissions or other rights required for any use of the Work. d. Affirmer understands and acknowledges that Creative Commons is not a party to this document and has no duty or obligation with respect to this CC0 or use of the Work. golang-github-pion-dtls-v3-3.0.7/LICENSES/MIT.txt000066400000000000000000000020661507057460300211170ustar00rootroot00000000000000MIT License Copyright (c) Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. golang-github-pion-dtls-v3-3.0.7/README.md000066400000000000000000000136751507057460300200070ustar00rootroot00000000000000


Pion DTLS

A Go implementation of DTLS

Pion DTLS Sourcegraph Widget join us on Discord Follow us on Bluesky
GitHub Workflow Status Go Reference Coverage Status Go Report Card License: MIT


Native [DTLS 1.2][rfc6347] implementation in the Go programming language. A long term goal is a professional security review, and maybe an inclusion in stdlib. ### RFCs #### Implemented - **RFC 6347**: [Datagram Transport Layer Security Version 1.2][rfc6347] - **RFC 5705**: [Keying Material Exporters for Transport Layer Security (TLS)][rfc5705] - **RFC 7627**: [Transport Layer Security (TLS) - Session Hash and Extended Master Secret Extension][rfc7627] - **RFC 7301**: [Transport Layer Security (TLS) - Application-Layer Protocol Negotiation Extension][rfc7301] [rfc5289]: https://tools.ietf.org/html/rfc5289 [rfc5487]: https://tools.ietf.org/html/rfc5487 [rfc5489]: https://tools.ietf.org/html/rfc5489 [rfc5705]: https://tools.ietf.org/html/rfc5705 [rfc6347]: https://tools.ietf.org/html/rfc6347 [rfc6655]: https://tools.ietf.org/html/rfc6655 [rfc7301]: https://tools.ietf.org/html/rfc7301 [rfc7627]: https://tools.ietf.org/html/rfc7627 [rfc8422]: https://tools.ietf.org/html/rfc8422 ### Goals/Progress This will only be targeting DTLS 1.2, and the most modern/common cipher suites. We would love contributions that fall under the 'Planned Features' and any bug fixes! 
#### Current features * DTLS 1.2 Client/Server * Key Exchange via ECDHE(curve25519, nistp256, nistp384) and PSK * Packet loss and re-ordering is handled during handshaking * Key export ([RFC 5705][rfc5705]) * Serialization and Resumption of sessions * Extended Master Secret extension ([RFC 7627][rfc7627]) * ALPN extension ([RFC 7301][rfc7301]) #### Supported ciphers ##### ECDHE * TLS_ECDHE_ECDSA_WITH_AES_128_CCM ([RFC 6655][rfc6655]) * TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 ([RFC 6655][rfc6655]) * TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 ([RFC 5289][rfc5289]) * TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 ([RFC 5289][rfc5289]) * TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 ([RFC 5289][rfc5289]) * TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 ([RFC 5289][rfc5289]) * TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA ([RFC 8422][rfc8422]) * TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA ([RFC 8422][rfc8422]) ##### PSK * TLS_PSK_WITH_AES_128_CCM ([RFC 6655][rfc6655]) * TLS_PSK_WITH_AES_128_CCM_8 ([RFC 6655][rfc6655]) * TLS_PSK_WITH_AES_256_CCM_8 ([RFC 6655][rfc6655]) * TLS_PSK_WITH_AES_128_GCM_SHA256 ([RFC 5487][rfc5487]) * TLS_PSK_WITH_AES_128_CBC_SHA256 ([RFC 5487][rfc5487]) ##### ECDHE & PSK * TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 ([RFC 5489][rfc5489]) #### Planned Features * Chacha20Poly1305 #### Excluded Features * DTLS 1.0 * Renegotiation * Compression ### Using This library needs at least Go 1.13, and you should have [Go modules enabled](https://github.com/golang/go/wiki/Modules). #### Pion DTLS For a DTLS 1.2 Server that listens on 127.0.0.1:4444 ```sh go run examples/listen/selfsign/main.go ``` For a DTLS 1.2 Client that connects to 127.0.0.1:4444 ```sh go run examples/dial/selfsign/main.go ``` #### OpenSSL Pion DTLS can connect to itself and OpenSSL. 
``` // Generate a certificate openssl ecparam -out key.pem -name prime256v1 -genkey openssl req -new -sha256 -key key.pem -out server.csr openssl x509 -req -sha256 -days 365 -in server.csr -signkey key.pem -out cert.pem // Use with examples/dial/selfsign/main.go openssl s_server -dtls1_2 -cert cert.pem -key key.pem -accept 4444 // Use with examples/listen/selfsign/main.go openssl s_client -dtls1_2 -connect 127.0.0.1:4444 -debug -cert cert.pem -key key.pem ``` ### Using with PSK Pion DTLS also comes with examples that do key exchange via PSK #### Pion DTLS ```sh go run examples/listen/psk/main.go ``` ```sh go run examples/dial/psk/main.go ``` #### OpenSSL ``` // Use with examples/dial/psk/main.go openssl s_server -dtls1_2 -accept 4444 -nocert -psk abc123 -cipher PSK-AES128-CCM8 // Use with examples/listen/psk/main.go openssl s_client -dtls1_2 -connect 127.0.0.1:4444 -psk abc123 -cipher PSK-AES128-CCM8 ``` ### Community Pion has an active community on the [Discord](https://discord.gg/PngbdqpFbt). Follow the [Pion Bluesky](https://bsky.app/profile/pion.ly) or [Pion Twitter](https://twitter.com/_pion) for project updates and important WebRTC news. We are always looking to support **your projects**. Please reach out if you have something to build! 
If you need commercial support or don't want to use public methods you can contact us at [team@pion.ly](mailto:team@pion.ly) ### Contributing Check out the [contributing wiki](https://github.com/pion/webrtc/wiki/Contributing) to join the group of amazing people making this project possible ### License MIT License - see [LICENSE](LICENSE) for full text golang-github-pion-dtls-v3-3.0.7/bench_test.go000066400000000000000000000052231507057460300211630ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "context" "crypto/tls" "fmt" "testing" "time" "github.com/pion/dtls/v3/pkg/crypto/selfsign" dtlsnet "github.com/pion/dtls/v3/pkg/net" "github.com/pion/logging" "github.com/pion/transport/v3/dpipe" "github.com/pion/transport/v3/test" "github.com/stretchr/testify/assert" ) func TestSimpleReadWrite(t *testing.T) { report := test.CheckRoutines(t) defer report() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() ca, cb := dpipe.Pipe() certificate, err := selfsign.GenerateSelfSigned() assert.NoError(t, err) gotHello := make(chan struct{}) go func() { server, sErr := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ Certificates: []tls.Certificate{certificate}, LoggerFactory: logging.NewDefaultLoggerFactory(), }, false) assert.NoError(t, sErr) buf := make([]byte, 1024) _, sErr = server.Read(buf) //nolint:contextcheck assert.NoError(t, sErr) gotHello <- struct{}{} assert.NoError(t, server.Close()) //nolint:contextcheck }() client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{ LoggerFactory: logging.NewDefaultLoggerFactory(), InsecureSkipVerify: true, }, false) assert.NoError(t, err) _, err = client.Write([]byte("hello")) assert.NoError(t, err) select { case <-gotHello: // OK case <-time.After(time.Second * 5): assert.Fail(t, "timeout") } assert.NoError(t, client.Close()) } func benchmarkConn(b *testing.B, 
payloadSize int64) { b.Helper() b.Run(fmt.Sprintf("%d", payloadSize), func(b *testing.B) { ctx := context.Background() ca, cb := dpipe.Pipe() certificate, err := selfsign.GenerateSelfSigned() assert.NoError(b, err) server := make(chan *Conn) go func() { s, sErr := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ Certificates: []tls.Certificate{certificate}, }, false) assert.NoError(b, sErr) server <- s }() hw := make([]byte, payloadSize) b.ReportAllocs() b.SetBytes(int64(len(hw))) go func() { client, cErr := testClient( ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{InsecureSkipVerify: true}, false, ) assert.NoError(b, cErr) for { _, cErr = client.Write(hw) //nolint:contextcheck assert.NoError(b, cErr) } }() s := <-server buf := make([]byte, 2048) for i := 0; i < b.N; i++ { _, err = s.Read(buf) assert.NoError(b, err) } }) } func BenchmarkConnReadWrite(b *testing.B) { for _, n := range []int64{16, 128, 512, 1024, 2048} { benchmarkConn(b, n) } } golang-github-pion-dtls-v3-3.0.7/certificate.go000066400000000000000000000114471507057460300213340ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "bytes" "crypto/tls" "crypto/x509" "fmt" "strings" "github.com/pion/dtls/v3/pkg/protocol/handshake" ) // ClientHelloInfo contains information from a ClientHello message in order to // guide application logic in the GetCertificate. type ClientHelloInfo struct { // ServerName indicates the name of the server requested by the client // in order to support virtual hosting. ServerName is only set if the // client is using SNI (see RFC 4366, Section 3.1). ServerName string // CipherSuites lists the CipherSuites supported by the client (e.g. // TLS_AES_128_GCM_SHA256, TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256). 
CipherSuites []CipherSuiteID // RandomBytes stores the client hello random bytes RandomBytes [handshake.RandomBytesLength]byte } // CertificateRequestInfo contains information from a server's // CertificateRequest message, which is used to demand a certificate and proof // of control from a client. type CertificateRequestInfo struct { // AcceptableCAs contains zero or more, DER-encoded, X.501 // Distinguished Names. These are the names of root or intermediate CAs // that the server wishes the returned certificate to be signed by. An // empty slice indicates that the server has no preference. AcceptableCAs [][]byte } // SupportsCertificate returns nil if the provided certificate is supported by // the server that sent the CertificateRequest. Otherwise, it returns an error // describing the reason for the incompatibility. // NOTE: original src: // https://github.com/golang/go/blob/29b9a328d268d53833d2cc063d1d8b4bf6852675/src/crypto/tls/common.go#L1273 func (cri *CertificateRequestInfo) SupportsCertificate(c *tls.Certificate) error { if len(cri.AcceptableCAs) == 0 { return nil } for j, cert := range c.Certificate { x509Cert := c.Leaf // Parse the certificate if this isn't the leaf node, or if // chain.Leaf was nil. 
if j != 0 || x509Cert == nil { var err error if x509Cert, err = x509.ParseCertificate(cert); err != nil { return fmt.Errorf("failed to parse certificate #%d in the chain: %w", j, err) } } for _, ca := range cri.AcceptableCAs { if bytes.Equal(x509Cert.RawIssuer, ca) { return nil } } } return errNotAcceptableCertificateChain } func (c *handshakeConfig) setNameToCertificateLocked() { nameToCertificate := make(map[string]*tls.Certificate) for i := range c.localCertificates { cert := &c.localCertificates[i] x509Cert := cert.Leaf if x509Cert == nil { var parseErr error x509Cert, parseErr = x509.ParseCertificate(cert.Certificate[0]) if parseErr != nil { continue } } if len(x509Cert.Subject.CommonName) > 0 { nameToCertificate[strings.ToLower(x509Cert.Subject.CommonName)] = cert } for _, san := range x509Cert.DNSNames { nameToCertificate[strings.ToLower(san)] = cert } } c.nameToCertificate = nameToCertificate } //nolint:cyclop func (c *handshakeConfig) getCertificate(clientHelloInfo *ClientHelloInfo) (*tls.Certificate, error) { c.mu.Lock() defer c.mu.Unlock() if c.localGetCertificate != nil && (len(c.localCertificates) == 0 || len(clientHelloInfo.ServerName) > 0) { cert, err := c.localGetCertificate(clientHelloInfo) if cert != nil || err != nil { return cert, err } } if c.nameToCertificate == nil { c.setNameToCertificateLocked() } if len(c.localCertificates) == 0 { return nil, errNoCertificates } if len(c.localCertificates) == 1 { // There's only one choice, so no point doing any work. return &c.localCertificates[0], nil } if len(clientHelloInfo.ServerName) == 0 { return &c.localCertificates[0], nil } name := strings.TrimRight(strings.ToLower(clientHelloInfo.ServerName), ".") if cert, ok := c.nameToCertificate[name]; ok { return cert, nil } // try replacing labels in the name with wildcards until we get a // match. 
labels := strings.Split(name, ".") for i := range labels { labels[i] = "*" candidate := strings.Join(labels, ".") if cert, ok := c.nameToCertificate[candidate]; ok { return cert, nil } } // If nothing matches, return the first certificate. return &c.localCertificates[0], nil } // NOTE: original src: // https://github.com/golang/go/blob/29b9a328d268d53833d2cc063d1d8b4bf6852675/src/crypto/tls/handshake_client.go#L974 func (c *handshakeConfig) getClientCertificate(cri *CertificateRequestInfo) (*tls.Certificate, error) { c.mu.Lock() defer c.mu.Unlock() if c.localGetClientCertificate != nil { return c.localGetClientCertificate(cri) } for i := range c.localCertificates { chain := c.localCertificates[i] if err := cri.SupportsCertificate(&chain); err != nil { continue } return &chain, nil } // No acceptable certificate found. Don't send a certificate. return new(tls.Certificate), nil } golang-github-pion-dtls-v3-3.0.7/certificate_test.go000066400000000000000000000045771507057460300224010ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "crypto/tls" "testing" "github.com/pion/dtls/v3/pkg/crypto/selfsign" "github.com/stretchr/testify/assert" ) func TestGetCertificate(t *testing.T) { certificateWildcard, err := selfsign.GenerateSelfSignedWithDNS("*.test.test") assert.NoError(t, err) certificateTest, err := selfsign.GenerateSelfSignedWithDNS("test.test", "www.test.test", "pop.test.test") assert.NoError(t, err) certificateRandom, err := selfsign.GenerateSelfSigned() assert.NoError(t, err) testCases := []struct { localCertificates []tls.Certificate desc string serverName string expectedCertificate tls.Certificate getCertificate func(info *ClientHelloInfo) (*tls.Certificate, error) }{ { desc: "Simple match in CN", localCertificates: []tls.Certificate{ certificateRandom, certificateTest, certificateWildcard, }, serverName: "test.test", expectedCertificate: certificateTest, }, { desc: "Simple match 
in SANs", localCertificates: []tls.Certificate{ certificateRandom, certificateTest, certificateWildcard, }, serverName: "www.test.test", expectedCertificate: certificateTest, }, { desc: "Wildcard match", localCertificates: []tls.Certificate{ certificateRandom, certificateTest, certificateWildcard, }, serverName: "foo.test.test", expectedCertificate: certificateWildcard, }, { desc: "No match return first", localCertificates: []tls.Certificate{ certificateRandom, certificateTest, certificateWildcard, }, serverName: "foo.bar", expectedCertificate: certificateRandom, }, { desc: "Get certificate from callback", getCertificate: func(*ClientHelloInfo) (*tls.Certificate, error) { return &certificateTest, nil }, expectedCertificate: certificateTest, }, } for _, test := range testCases { test := test t.Run(test.desc, func(t *testing.T) { cfg := &handshakeConfig{ localCertificates: test.localCertificates, localGetCertificate: test.getCertificate, } cert, err := cfg.getCertificate(&ClientHelloInfo{ServerName: test.serverName}) assert.NoError(t, err) assert.Equal(t, test.expectedCertificate.Leaf, cert.Leaf, "Certificate Leaf should match expected") }) } } golang-github-pion-dtls-v3-3.0.7/cipher_suite.go000066400000000000000000000245551507057460300215410ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "crypto" "crypto/ecdsa" "crypto/ed25519" "crypto/rsa" "crypto/tls" "fmt" "hash" "github.com/pion/dtls/v3/internal/ciphersuite" "github.com/pion/dtls/v3/pkg/crypto/clientcertificate" "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) // CipherSuiteID is an ID for our supported CipherSuites. type CipherSuiteID = ciphersuite.ID // Supported Cipher Suites. 
const ( // AES-128-CCM //nolint:revive,stylecheck TLS_ECDHE_ECDSA_WITH_AES_128_CCM CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_128_CCM //nolint:revive,stylecheck TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 // AES-128-GCM-SHA256 //nolint:revive,stylecheck TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 //nolint:revive,stylecheck TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 CipherSuiteID = ciphersuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 //nolint:revive,stylecheck TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 //nolint:revive,stylecheck TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 CipherSuiteID = ciphersuite.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 // AES-256-CBC-SHA //nolint:revive,stylecheck TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA //nolint:revive,stylecheck TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA CipherSuiteID = ciphersuite.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA //nolint:revive,stylecheck TLS_PSK_WITH_AES_128_CCM CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_CCM //nolint:revive,stylecheck TLS_PSK_WITH_AES_128_CCM_8 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_CCM_8 //nolint:revive,stylecheck TLS_PSK_WITH_AES_256_CCM_8 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_256_CCM_8 //nolint:revive,stylecheck TLS_PSK_WITH_AES_128_GCM_SHA256 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_GCM_SHA256 //nolint:revive,stylecheck TLS_PSK_WITH_AES_128_CBC_SHA256 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_CBC_SHA256 //nolint:revive,stylecheck TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 CipherSuiteID = ciphersuite.TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 ) // CipherSuiteAuthenticationType controls what authentication method is using during the handshake for a CipherSuite. 
type CipherSuiteAuthenticationType = ciphersuite.AuthenticationType // AuthenticationType Enums. const ( CipherSuiteAuthenticationTypeCertificate CipherSuiteAuthenticationType = ciphersuite.AuthenticationTypeCertificate CipherSuiteAuthenticationTypePreSharedKey CipherSuiteAuthenticationType = ciphersuite.AuthenticationTypePreSharedKey CipherSuiteAuthenticationTypeAnonymous CipherSuiteAuthenticationType = ciphersuite.AuthenticationTypeAnonymous ) // CipherSuiteKeyExchangeAlgorithm controls what exchange algorithm is using during the handshake for a CipherSuite. type CipherSuiteKeyExchangeAlgorithm = ciphersuite.KeyExchangeAlgorithm // CipherSuiteKeyExchangeAlgorithm Bitmask. const ( CipherSuiteKeyExchangeAlgorithmNone CipherSuiteKeyExchangeAlgorithm = ciphersuite.KeyExchangeAlgorithmNone CipherSuiteKeyExchangeAlgorithmPsk CipherSuiteKeyExchangeAlgorithm = ciphersuite.KeyExchangeAlgorithmPsk CipherSuiteKeyExchangeAlgorithmEcdhe CipherSuiteKeyExchangeAlgorithm = ciphersuite.KeyExchangeAlgorithmEcdhe ) var _ = allCipherSuites() // Necessary until this function isn't only used by Go 1.14 // CipherSuite is an interface that all DTLS CipherSuites must satisfy. type CipherSuite interface { // String of CipherSuite, only used for logging String() string // ID of CipherSuite. ID() CipherSuiteID // What type of Certificate does this CipherSuite use CertificateType() clientcertificate.Type // What Hash function is used during verification HashFunc() func() hash.Hash // AuthenticationType controls what authentication method is using during the handshake AuthenticationType() CipherSuiteAuthenticationType // KeyExchangeAlgorithm controls what exchange algorithm is using during the handshake KeyExchangeAlgorithm() CipherSuiteKeyExchangeAlgorithm // ECC (Elliptic Curve Cryptography) determines whether ECC extesions will be send during handshake. 
// https://datatracker.ietf.org/doc/html/rfc4492#page-10 ECC() bool // Called when keying material has been generated, should initialize the internal cipher Init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error IsInitialized() bool Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) Decrypt(h recordlayer.Header, in []byte) ([]byte, error) } // CipherSuiteName provides the same functionality as tls.CipherSuiteName // that appeared first in Go 1.14. // // Our implementation differs slightly in that it takes in a CiperSuiteID, // like the rest of our library, instead of a uint16 like crypto/tls. func CipherSuiteName(id CipherSuiteID) string { suite := cipherSuiteForID(id, nil) if suite != nil { return suite.String() } return fmt.Sprintf("0x%04X", uint16(id)) } // Taken from https://www.iana.org/assignments/tls-parameters/tls-parameters.xml // A cipherSuite is a specific combination of key agreement, cipher and MAC // function. func cipherSuiteForID(id CipherSuiteID, customCiphers func() []CipherSuite) CipherSuite { //nolint:cyclop switch id { //nolint:exhaustive case TLS_ECDHE_ECDSA_WITH_AES_128_CCM: return ciphersuite.NewTLSEcdheEcdsaWithAes128Ccm() case TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8: return ciphersuite.NewTLSEcdheEcdsaWithAes128Ccm8() case TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: return &ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{} case TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: return &ciphersuite.TLSEcdheRsaWithAes128GcmSha256{} case TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA: return &ciphersuite.TLSEcdheEcdsaWithAes256CbcSha{} case TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA: return &ciphersuite.TLSEcdheRsaWithAes256CbcSha{} case TLS_PSK_WITH_AES_128_CCM: return ciphersuite.NewTLSPskWithAes128Ccm() case TLS_PSK_WITH_AES_128_CCM_8: return ciphersuite.NewTLSPskWithAes128Ccm8() case TLS_PSK_WITH_AES_256_CCM_8: return ciphersuite.NewTLSPskWithAes256Ccm8() case TLS_PSK_WITH_AES_128_GCM_SHA256: return &ciphersuite.TLSPskWithAes128GcmSha256{} case 
TLS_PSK_WITH_AES_128_CBC_SHA256: return &ciphersuite.TLSPskWithAes128CbcSha256{} case TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384: return &ciphersuite.TLSEcdheEcdsaWithAes256GcmSha384{} case TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384: return &ciphersuite.TLSEcdheRsaWithAes256GcmSha384{} case TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256: return ciphersuite.NewTLSEcdhePskWithAes128CbcSha256() } if customCiphers != nil { for _, c := range customCiphers() { if c.ID() == id { return c } } } return nil } // CipherSuites we support in order of preference. func defaultCipherSuites() []CipherSuite { return []CipherSuite{ &ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{}, &ciphersuite.TLSEcdheRsaWithAes128GcmSha256{}, &ciphersuite.TLSEcdheEcdsaWithAes256CbcSha{}, &ciphersuite.TLSEcdheRsaWithAes256CbcSha{}, &ciphersuite.TLSEcdheEcdsaWithAes256GcmSha384{}, &ciphersuite.TLSEcdheRsaWithAes256GcmSha384{}, } } func allCipherSuites() []CipherSuite { return []CipherSuite{ ciphersuite.NewTLSEcdheEcdsaWithAes128Ccm(), ciphersuite.NewTLSEcdheEcdsaWithAes128Ccm8(), &ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{}, &ciphersuite.TLSEcdheRsaWithAes128GcmSha256{}, &ciphersuite.TLSEcdheEcdsaWithAes256CbcSha{}, &ciphersuite.TLSEcdheRsaWithAes256CbcSha{}, ciphersuite.NewTLSPskWithAes128Ccm(), ciphersuite.NewTLSPskWithAes128Ccm8(), ciphersuite.NewTLSPskWithAes256Ccm8(), &ciphersuite.TLSPskWithAes128GcmSha256{}, &ciphersuite.TLSEcdheEcdsaWithAes256GcmSha384{}, &ciphersuite.TLSEcdheRsaWithAes256GcmSha384{}, } } func cipherSuiteIDs(cipherSuites []CipherSuite) []uint16 { rtrn := []uint16{} for _, c := range cipherSuites { rtrn = append(rtrn, uint16(c.ID())) } return rtrn } //nolint:cyclop func parseCipherSuites( userSelectedSuites []CipherSuiteID, customCipherSuites func() []CipherSuite, includeCertificateSuites, includePSKSuites bool, ) ([]CipherSuite, error) { cipherSuitesForIDs := func(ids []CipherSuiteID) ([]CipherSuite, error) { cipherSuites := []CipherSuite{} for _, id := range ids { c := cipherSuiteForID(id, 
nil) if c == nil { return nil, &invalidCipherSuiteError{id} } cipherSuites = append(cipherSuites, c) } return cipherSuites, nil } var ( cipherSuites []CipherSuite err error i int ) if userSelectedSuites != nil { cipherSuites, err = cipherSuitesForIDs(userSelectedSuites) if err != nil { return nil, err } } else { cipherSuites = defaultCipherSuites() } // Put CustomCipherSuites before ID selected suites if customCipherSuites != nil { cipherSuites = append(customCipherSuites(), cipherSuites...) } var foundCertificateSuite, foundPSKSuite, foundAnonymousSuite bool for _, c := range cipherSuites { switch { case includeCertificateSuites && c.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate: foundCertificateSuite = true case includePSKSuites && c.AuthenticationType() == CipherSuiteAuthenticationTypePreSharedKey: foundPSKSuite = true case c.AuthenticationType() == CipherSuiteAuthenticationTypeAnonymous: foundAnonymousSuite = true default: continue } cipherSuites[i] = c i++ } switch { case includeCertificateSuites && !foundCertificateSuite && !foundAnonymousSuite: return nil, errNoAvailableCertificateCipherSuite case includePSKSuites && !foundPSKSuite: return nil, errNoAvailablePSKCipherSuite case i == 0: return nil, errNoAvailableCipherSuites } return cipherSuites[:i], nil } func filterCipherSuitesForCertificate(cert *tls.Certificate, cipherSuites []CipherSuite) []CipherSuite { if cert == nil || cert.PrivateKey == nil { return cipherSuites } signer, ok := cert.PrivateKey.(crypto.Signer) if !ok { return cipherSuites } var certType clientcertificate.Type switch signer.Public().(type) { case ed25519.PublicKey, *ecdsa.PublicKey: certType = clientcertificate.ECDSASign case *rsa.PublicKey: certType = clientcertificate.RSASign } filtered := []CipherSuite{} for _, c := range cipherSuites { if c.AuthenticationType() != CipherSuiteAuthenticationTypeCertificate || certType == c.CertificateType() { filtered = append(filtered, c) } } return filtered } 
golang-github-pion-dtls-v3-3.0.7/cipher_suite_go114.go000066400000000000000000000022501507057460300224400ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT //go:build go1.14 // +build go1.14 package dtls import ( "crypto/tls" ) // VersionDTLS12 is the DTLS version in the same style as // VersionTLSXX from crypto/tls. const VersionDTLS12 = 0xfefd // Convert from our cipherSuite interface to a tls.CipherSuite struct. func toTLSCipherSuite(c CipherSuite) *tls.CipherSuite { return &tls.CipherSuite{ ID: uint16(c.ID()), Name: c.String(), SupportedVersions: []uint16{VersionDTLS12}, Insecure: false, } } // CipherSuites returns a list of cipher suites currently implemented by this // package, excluding those with security issues, which are returned by // InsecureCipherSuites. func CipherSuites() []*tls.CipherSuite { suites := allCipherSuites() res := make([]*tls.CipherSuite, len(suites)) for i, c := range suites { res[i] = toTLSCipherSuite(c) } return res } // InsecureCipherSuites returns a list of cipher suites currently implemented by // this package and which have security issues. 
func InsecureCipherSuites() []*tls.CipherSuite { var res []*tls.CipherSuite return res } golang-github-pion-dtls-v3-3.0.7/cipher_suite_go114_test.go000066400000000000000000000016521507057460300235040ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT //go:build go1.14 // +build go1.14 package dtls import ( "testing" "github.com/stretchr/testify/assert" ) func TestInsecureCipherSuites(t *testing.T) { assert.Empty(t, InsecureCipherSuites(), "Expected no insecure ciphersuites") } func TestCipherSuites(t *testing.T) { ours := allCipherSuites() theirs := CipherSuites() assert.Equal(t, len(ours), len(theirs)) for i, s := range ours { i := i s := s t.Run(s.String(), func(t *testing.T) { cipher := theirs[i] assert.Equal(t, cipher.ID, uint16(s.ID())) assert.Equal(t, cipher.Name, s.String()) assert.Equal(t, 1, len(cipher.SupportedVersions), "Expected SupportedVersion to be 1") assert.Equal(t, uint16(VersionDTLS12), cipher.SupportedVersions[0], "Expected SupportedVersion to match") assert.False(t, cipher.Insecure, "Expected Insecure") }) } } golang-github-pion-dtls-v3-3.0.7/cipher_suite_test.go000066400000000000000000000050741507057460300225730ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "context" "testing" "time" "github.com/pion/dtls/v3/internal/ciphersuite" dtlsnet "github.com/pion/dtls/v3/pkg/net" "github.com/pion/transport/v3/dpipe" "github.com/pion/transport/v3/test" "github.com/stretchr/testify/assert" ) func TestCipherSuiteName(t *testing.T) { testCases := []struct { suite CipherSuiteID expected string }{ {TLS_ECDHE_ECDSA_WITH_AES_128_CCM, "TLS_ECDHE_ECDSA_WITH_AES_128_CCM"}, {CipherSuiteID(0x0000), "0x0000"}, } for _, testCase := range testCases { assert.Equal(t, testCase.expected, CipherSuiteName(testCase.suite)) } } func TestAllCipherSuites(t *testing.T) { assert.NotEmpty(t, allCipherSuites()) } // CustomCipher that is 
just used to assert Custom IDs work. type testCustomCipherSuite struct { ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256 authenticationType CipherSuiteAuthenticationType } func (t *testCustomCipherSuite) ID() CipherSuiteID { return 0xFFFF } func (t *testCustomCipherSuite) AuthenticationType() CipherSuiteAuthenticationType { return t.authenticationType } // Assert that two connections that pass in a CipherSuite with a CustomID works. func TestCustomCipherSuite(t *testing.T) { type result struct { c *Conn err error } // Check for leaking routines report := test.CheckRoutines(t) defer report() runTest := func(cipherFactory func() []CipherSuite) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() ca, cb := dpipe.Pipe() resultCh := make(chan result) go func() { client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{ CipherSuites: []CipherSuiteID{}, CustomCipherSuites: cipherFactory, }, true) resultCh <- result{client, err} }() server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ CipherSuites: []CipherSuiteID{}, CustomCipherSuites: cipherFactory, }, true) clientResult := <-resultCh assert.NoError(t, err) assert.NoError(t, server.Close()) assert.Nil(t, clientResult.err) assert.NoError(t, clientResult.c.Close()) } t.Run("Custom ID", func(*testing.T) { runTest(func() []CipherSuite { return []CipherSuite{&testCustomCipherSuite{authenticationType: CipherSuiteAuthenticationTypeCertificate}} }) }) t.Run("Anonymous Cipher", func(*testing.T) { runTest(func() []CipherSuite { return []CipherSuite{&testCustomCipherSuite{authenticationType: CipherSuiteAuthenticationTypeAnonymous}} }) }) } golang-github-pion-dtls-v3-3.0.7/codecov.yml000066400000000000000000000007151507057460300206640ustar00rootroot00000000000000# # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. 
# # SPDX-FileCopyrightText: 2023 The Pion community # SPDX-License-Identifier: MIT coverage: status: project: default: # Allow decreasing 2% of total coverage to avoid noise. threshold: 2% patch: default: target: 70% only_pulls: true ignore: - "examples/*" - "examples/**/*" golang-github-pion-dtls-v3-3.0.7/compression_method.go000066400000000000000000000004261507057460300227460ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package dtls import "github.com/pion/dtls/v3/pkg/protocol" func defaultCompressionMethods() []*protocol.CompressionMethod { return []*protocol.CompressionMethod{ {}, } } golang-github-pion-dtls-v3-3.0.7/config.go000066400000000000000000000273131507057460300203160ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "crypto" "crypto/ecdsa" "crypto/ed25519" "crypto/rsa" "crypto/tls" "crypto/x509" "io" "net" "time" "github.com/pion/dtls/v3/pkg/crypto/elliptic" "github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/pion/logging" ) const keyLogLabelTLS12 = "CLIENT_RANDOM" // Config is used to configure a DTLS client or server. // After a Config is passed to a DTLS function it must not be modified. type Config struct { // Certificates contains certificate chain to present to the other side of the connection. // Server MUST set this if PSK is non-nil // client SHOULD sets this so CertificateRequests can be handled if PSK is non-nil Certificates []tls.Certificate // CipherSuites is a list of supported cipher suites. // If CipherSuites is nil, a default list is used CipherSuites []CipherSuiteID // CustomCipherSuites is a list of CipherSuites that can be // provided by the user. This allow users to user Ciphers that are reserved // for private usage. CustomCipherSuites func() []CipherSuite // SignatureSchemes contains the signature and hash schemes that the peer requests to verify. 
SignatureSchemes []tls.SignatureScheme // SRTPProtectionProfiles are the supported protection profiles // Clients will send this via use_srtp and assert that the server properly responds // Servers will assert that clients send one of these profiles and will respond as needed SRTPProtectionProfiles []SRTPProtectionProfile // SRTPMasterKeyIdentifier value (if any) is sent via the use_srtp // extension for Clients and Servers SRTPMasterKeyIdentifier []byte // ClientAuth determines the server's policy for // TLS Client Authentication. The default is NoClientCert. ClientAuth ClientAuthType // RequireExtendedMasterSecret determines if the "Extended Master Secret" extension // should be disabled, requested, or required (default requested). ExtendedMasterSecret ExtendedMasterSecretType // FlightInterval controls how often we send outbound handshake messages // defaults to time.Second FlightInterval time.Duration // DisableRetransmitBackoff can be used to the disable the backoff feature // when sending outbound messages as specified in RFC 4347 4.2.4.1 DisableRetransmitBackoff bool // PSK sets the pre-shared key used by this DTLS connection // If PSK is non-nil only PSK CipherSuites will be used PSK PSKCallback PSKIdentityHint []byte // InsecureSkipVerify controls whether a client verifies the // server's certificate chain and host name. // If InsecureSkipVerify is true, TLS accepts any certificate // presented by the server and any host name in that certificate. // In this mode, TLS is susceptible to man-in-the-middle attacks. // This should be used only for testing. InsecureSkipVerify bool // InsecureHashes allows the use of hashing algorithms that are known // to be vulnerable. InsecureHashes bool // VerifyPeerCertificate, if not nil, is called after normal // certificate verification by either a client or server. It // receives the certificate provided by the peer and also a flag // that tells if normal verification has succeedded. 
If it returns a // non-nil error, the handshake is aborted and that error results. // // If normal verification fails then the handshake will abort before // considering this callback. If normal verification is disabled by // setting InsecureSkipVerify, or (for a server) when ClientAuth is // RequestClientCert or RequireAnyClientCert, then this callback will // be considered but the verifiedChains will always be nil. VerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error // VerifyConnection, if not nil, is called after normal certificate // verification/PSK and after VerifyPeerCertificate by either a TLS client // or server. If it returns a non-nil error, the handshake is aborted // and that error results. // // If normal verification fails then the handshake will abort before // considering this callback. This callback will run for all connections // regardless of InsecureSkipVerify or ClientAuth settings. VerifyConnection func(*State) error // RootCAs defines the set of root certificate authorities // that one peer uses when verifying the other peer's certificates. // If RootCAs is nil, TLS uses the host's root CA set. RootCAs *x509.CertPool // ClientCAs defines the set of root certificate authorities // that servers use if required to verify a client certificate // by the policy in ClientAuth. ClientCAs *x509.CertPool // ServerName is used to verify the hostname on the returned // certificates unless InsecureSkipVerify is given. ServerName string LoggerFactory logging.LoggerFactory // MTU is the length at which handshake messages will be fragmented to // fit within the maximum transmission unit (default is 1200 bytes) MTU int // ReplayProtectionWindow is the size of the replay attack protection window. // Duplication of the sequence number is checked in this window size. // Packet with sequence number older than this value compared to the latest // accepted packet will be discarded. 
(default is 64) ReplayProtectionWindow int // KeyLogWriter optionally specifies a destination for TLS master secrets // in NSS key log format that can be used to allow external programs // such as Wireshark to decrypt TLS connections. // See https://developer.mozilla.org/en-US/docs/Mozilla/Projects/NSS/Key_Log_Format. // Use of KeyLogWriter compromises security and should only be // used for debugging. KeyLogWriter io.Writer // SessionStore is the container to store session for resumption. SessionStore SessionStore // List of application protocols the peer supports, for ALPN SupportedProtocols []string // List of Elliptic Curves to use // // If an ECC ciphersuite is configured and EllipticCurves is empty // it will default to X25519, P-256, P-384 in this specific order. EllipticCurves []elliptic.Curve // GetCertificate returns a Certificate based on the given // ClientHelloInfo. It will only be called if the client supplies SNI // information or if Certificates is empty. // // If GetCertificate is nil or returns nil, then the certificate is // retrieved from NameToCertificate. If NameToCertificate is nil, the // best element of Certificates will be used. GetCertificate func(*ClientHelloInfo) (*tls.Certificate, error) // GetClientCertificate, if not nil, is called when a server requests a // certificate from a client. If set, the contents of Certificates will // be ignored. // // If GetClientCertificate returns an error, the handshake will be // aborted and that error will be returned. Otherwise // GetClientCertificate must return a non-nil Certificate. If // Certificate.Certificate is empty then no certificate will be sent to // the server. If this is unacceptable to the server then it may abort // the handshake. GetClientCertificate func(*CertificateRequestInfo) (*tls.Certificate, error) // InsecureSkipVerifyHello, if true and when acting as server, allow client to // skip hello verify phase and receive ServerHello after initial ClientHello. 
// This have implication on DoS attack resistance. InsecureSkipVerifyHello bool // ConnectionIDGenerator generates connection identifiers that should be // sent by the remote party if it supports the DTLS Connection Identifier // extension, as determined during the handshake. Generated connection // identifiers must always have the same length. Returning a zero-length // connection identifier indicates that the local party supports sending // connection identifiers but does not require the remote party to send // them. A nil ConnectionIDGenerator indicates that connection identifiers // are not supported. // https://datatracker.ietf.org/doc/html/rfc9146 ConnectionIDGenerator func() []byte // PaddingLengthGenerator generates the number of padding bytes used to // inflate ciphertext size in order to obscure content size from observers. // The length of the content is passed to the generator such that both // deterministic and random padding schemes can be applied while not // exceeding maximum record size. // If no PaddingLengthGenerator is specified, padding will not be applied. // https://datatracker.ietf.org/doc/html/rfc9146#section-4 PaddingLengthGenerator func(uint) uint // HelloRandomBytesGenerator generates custom client hello random bytes. HelloRandomBytesGenerator func() [handshake.RandomBytesLength]byte // Handshake hooks: hooks can be used for testing invalid messages, // mimicking other implementations or randomizing fields, which is valuable // for applications that need censorship-resistance by making // fingerprinting more difficult. // ClientHelloMessageHook, if not nil, is called when a Client Hello message is sent // from a client. The returned handshake message replaces the original message. ClientHelloMessageHook func(handshake.MessageClientHello) handshake.Message // ServerHelloMessageHook, if not nil, is called when a Server Hello message is sent // from a server. The returned handshake message replaces the original message. 
ServerHelloMessageHook func(handshake.MessageServerHello) handshake.Message // CertificateRequestMessageHook, if not nil, is called when a Certificate Request // message is sent from a server. The returned handshake message replaces the original message. CertificateRequestMessageHook func(handshake.MessageCertificateRequest) handshake.Message // OnConnectionAttempt is fired Whenever a connection attempt is made, // the server or application can call this callback function. // The callback function can then implement logic to handle the connection attempt, such as logging the attempt, // checking against a list of blocked IPs, or counting the attempts to prevent brute force attacks. // If the callback function returns an error, the connection attempt will be aborted. OnConnectionAttempt func(net.Addr) error } func (c *Config) includeCertificateSuites() bool { return c.PSK == nil || len(c.Certificates) > 0 || c.GetCertificate != nil || c.GetClientCertificate != nil } const defaultMTU = 1200 // bytes var defaultCurves = []elliptic.Curve{elliptic.X25519, elliptic.P256, elliptic.P384} //nolint:gochecknoglobals // PSKCallback is called once we have the remote's PSKIdentityHint. // If the remote provided none it will be nil. type PSKCallback func([]byte) ([]byte, error) // ClientAuthType declares the policy the server will follow for // TLS Client Authentication. type ClientAuthType int // ClientAuthType enums. const ( NoClientCert ClientAuthType = iota RequestClientCert RequireAnyClientCert VerifyClientCertIfGiven RequireAndVerifyClientCert ) // ExtendedMasterSecretType declares the policy the client and server // will follow for the Extended Master Secret extension. type ExtendedMasterSecretType int // ExtendedMasterSecretType enums. 
const ( RequestExtendedMasterSecret ExtendedMasterSecretType = iota RequireExtendedMasterSecret DisableExtendedMasterSecret ) func validateConfig(config *Config) error { //nolint:cyclop switch { case config == nil: return errNoConfigProvided case config.PSKIdentityHint != nil && config.PSK == nil: return errIdentityNoPSK } for _, cert := range config.Certificates { if cert.Certificate == nil { return errInvalidCertificate } if cert.PrivateKey != nil { signer, ok := cert.PrivateKey.(crypto.Signer) if !ok { return errInvalidPrivateKey } switch signer.Public().(type) { case ed25519.PublicKey: case *ecdsa.PublicKey: case *rsa.PublicKey: default: return errInvalidPrivateKey } } } _, err := parseCipherSuites( config.CipherSuites, config.CustomCipherSuites, config.includeCertificateSuites(), config.PSK != nil, ) return err } golang-github-pion-dtls-v3-3.0.7/config_test.go000066400000000000000000000104661507057460300213560ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "crypto/dsa" //nolint:staticcheck "crypto/rand" "crypto/rsa" "crypto/tls" "errors" "testing" "github.com/pion/dtls/v3/pkg/crypto/selfsign" "github.com/stretchr/testify/assert" ) func TestValidateConfig(t *testing.T) { cert, err := selfsign.GenerateSelfSigned() if err != nil { assert.NoError(t, err, "TestValidateConfig: Config validation error, self signed certificate not generated") return } dsaPrivateKey := &dsa.PrivateKey{} err = dsa.GenerateParameters(&dsaPrivateKey.Parameters, rand.Reader, dsa.L1024N160) if err != nil { assert.NoError(t, err, "TestValidateConfig: Config validation error, DSA parameters not generated") return } err = dsa.GenerateKey(dsaPrivateKey, rand.Reader) if err != nil { assert.NoError(t, err, "TestValidateConfig: Config validation error, DSA private key not generated") return } rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { assert.NoError(t, err, "TestValidateConfig: 
Config validation error, RSA private key not generated") return } cases := map[string]struct { config *Config wantAnyErr bool expErr error }{ "Empty config": { expErr: errNoConfigProvided, }, "PSK and Certificate, valid cipher suites": { config: &Config{ CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8, TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, PSK: func([]byte) ([]byte, error) { return nil, nil }, Certificates: []tls.Certificate{cert}, }, }, "PSK and Certificate, no PSK cipher suite": { config: &Config{ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, PSK: func([]byte) ([]byte, error) { return nil, nil }, Certificates: []tls.Certificate{cert}, }, expErr: errNoAvailablePSKCipherSuite, }, "PSK and Certificate, no non-PSK cipher suite": { config: &Config{ CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8}, PSK: func([]byte) ([]byte, error) { return nil, nil }, Certificates: []tls.Certificate{cert}, }, expErr: errNoAvailableCertificateCipherSuite, }, "PSK identity hint with not PSK": { config: &Config{ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, PSK: nil, PSKIdentityHint: []byte{}, }, expErr: errIdentityNoPSK, }, "Invalid private key": { config: &Config{ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, Certificates: []tls.Certificate{{Certificate: cert.Certificate, PrivateKey: dsaPrivateKey}}, }, expErr: errInvalidPrivateKey, }, "PrivateKey without Certificate": { config: &Config{ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, Certificates: []tls.Certificate{{PrivateKey: cert.PrivateKey}}, }, expErr: errInvalidCertificate, }, "Invalid cipher suites": { config: &Config{CipherSuites: []CipherSuiteID{0x0000}}, wantAnyErr: true, }, "Valid config": { config: &Config{ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, Certificates: []tls.Certificate{cert, {Certificate: cert.Certificate, PrivateKey: rsaPrivateKey}}, }, }, "Valid config with 
get certificate": { config: &Config{ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, GetCertificate: func(*ClientHelloInfo) (*tls.Certificate, error) { return &tls.Certificate{Certificate: cert.Certificate, PrivateKey: rsaPrivateKey}, nil }, }, }, "Valid config with get client certificate": { config: &Config{ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, GetClientCertificate: func(*CertificateRequestInfo) (*tls.Certificate, error) { return &tls.Certificate{Certificate: cert.Certificate, PrivateKey: rsaPrivateKey}, nil }, }, }, } for name, testCase := range cases { testCase := testCase t.Run(name, func(t *testing.T) { err := validateConfig(testCase.config) if testCase.expErr != nil || testCase.wantAnyErr { if testCase.expErr != nil && !errors.Is(err, testCase.expErr) { assert.ErrorIs(t, err, testCase.expErr, "TestValidateConfig") } assert.Error(t, err, "TestValidateConfig: Config validation expected an error") } }) } } golang-github-pion-dtls-v3-3.0.7/conn.go000066400000000000000000001064411507057460300200060ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "bytes" "context" "errors" "fmt" "io" "net" "sync" "sync/atomic" "time" "github.com/pion/dtls/v3/internal/closer" "github.com/pion/dtls/v3/pkg/crypto/elliptic" "github.com/pion/dtls/v3/pkg/crypto/signaturehash" "github.com/pion/dtls/v3/pkg/protocol" "github.com/pion/dtls/v3/pkg/protocol/alert" "github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/pion/dtls/v3/pkg/protocol/recordlayer" "github.com/pion/logging" "github.com/pion/transport/v3/deadline" "github.com/pion/transport/v3/netctx" "github.com/pion/transport/v3/replaydetector" ) const ( initialTickerInterval = time.Second cookieLength = 20 sessionLength = 32 defaultNamedCurve = elliptic.X25519 inboundBufferSize = 8192 // Default replay protection window is specified by RFC 6347 Section 4.1.2.6. 
defaultReplayProtectionWindow = 64 // maxAppDataPacketQueueSize is the maximum number of app data packets we will. // enqueue before the handshake is completed. maxAppDataPacketQueueSize = 100 ) func invalidKeyingLabels() map[string]bool { return map[string]bool{ "client finished": true, "server finished": true, "master secret": true, "key expansion": true, } } type addrPkt struct { rAddr net.Addr data []byte } type recvHandshakeState struct { done chan struct{} isRetransmit bool } // Conn represents a DTLS connection. type Conn struct { lock sync.RWMutex // Internal lock (must not be public) nextConn netctx.PacketConn // Embedded Conn, typically a udpconn we read/write from fragmentBuffer *fragmentBuffer // out-of-order and missing fragment handling handshakeCache *handshakeCache // caching of handshake messages for verifyData generation decrypted chan any // Decrypted Application Data or error, pull by calling `Read` rAddr net.Addr state State // Internal state maximumTransmissionUnit int paddingLengthGenerator func(uint) uint handshakeCompletedSuccessfully atomic.Bool handshakeMutex sync.Mutex handshakeDone chan struct{} encryptedPackets []addrPkt connectionClosedByUser bool closeLock sync.Mutex closed *closer.Closer readDeadline *deadline.Deadline writeDeadline *deadline.Deadline log logging.LeveledLogger reading chan struct{} handshakeRecv chan recvHandshakeState cancelHandshaker func() cancelHandshakeReader func() fsm *handshakeFSM replayProtectionWindow uint handshakeConfig *handshakeConfig } //nolint:cyclop func createConn( nextConn net.PacketConn, rAddr net.Addr, config *Config, isClient bool, resumeState *State, ) (*Conn, error) { if err := validateConfig(config); err != nil { return nil, err } if nextConn == nil { return nil, errNilNextConn } loggerFactory := config.LoggerFactory if loggerFactory == nil { loggerFactory = logging.NewDefaultLoggerFactory() } logger := loggerFactory.NewLogger("dtls") mtu := config.MTU if mtu <= 0 { mtu = defaultMTU } 
replayProtectionWindow := config.ReplayProtectionWindow if replayProtectionWindow <= 0 { replayProtectionWindow = defaultReplayProtectionWindow } paddingLengthGenerator := config.PaddingLengthGenerator if paddingLengthGenerator == nil { paddingLengthGenerator = func(uint) uint { return 0 } } cipherSuites, err := parseCipherSuites( config.CipherSuites, config.CustomCipherSuites, config.includeCertificateSuites(), config.PSK != nil, ) if err != nil { return nil, err } signatureSchemes, err := signaturehash.ParseSignatureSchemes(config.SignatureSchemes, config.InsecureHashes) if err != nil { return nil, err } workerInterval := initialTickerInterval if config.FlightInterval != 0 { workerInterval = config.FlightInterval } serverName := config.ServerName // Do not allow the use of an IP address literal as an SNI value. // See RFC 6066, Section 3. if net.ParseIP(serverName) != nil { serverName = "" } curves := config.EllipticCurves if len(curves) == 0 { curves = defaultCurves } handshakeConfig := &handshakeConfig{ localPSKCallback: config.PSK, localPSKIdentityHint: config.PSKIdentityHint, localCipherSuites: cipherSuites, localSignatureSchemes: signatureSchemes, extendedMasterSecret: config.ExtendedMasterSecret, localSRTPProtectionProfiles: config.SRTPProtectionProfiles, localSRTPMasterKeyIdentifier: config.SRTPMasterKeyIdentifier, serverName: serverName, supportedProtocols: config.SupportedProtocols, clientAuth: config.ClientAuth, localCertificates: config.Certificates, insecureSkipVerify: config.InsecureSkipVerify, verifyPeerCertificate: config.VerifyPeerCertificate, verifyConnection: config.VerifyConnection, rootCAs: config.RootCAs, clientCAs: config.ClientCAs, customCipherSuites: config.CustomCipherSuites, initialRetransmitInterval: workerInterval, disableRetransmitBackoff: config.DisableRetransmitBackoff, log: logger, initialEpoch: 0, keyLogWriter: config.KeyLogWriter, sessionStore: config.SessionStore, ellipticCurves: curves, localGetCertificate: 
config.GetCertificate, localGetClientCertificate: config.GetClientCertificate, insecureSkipHelloVerify: config.InsecureSkipVerifyHello, connectionIDGenerator: config.ConnectionIDGenerator, helloRandomBytesGenerator: config.HelloRandomBytesGenerator, clientHelloMessageHook: config.ClientHelloMessageHook, serverHelloMessageHook: config.ServerHelloMessageHook, certificateRequestMessageHook: config.CertificateRequestMessageHook, resumeState: resumeState, } conn := &Conn{ rAddr: rAddr, nextConn: netctx.NewPacketConn(nextConn), handshakeConfig: handshakeConfig, fragmentBuffer: newFragmentBuffer(), handshakeCache: newHandshakeCache(), maximumTransmissionUnit: mtu, paddingLengthGenerator: paddingLengthGenerator, decrypted: make(chan any, 1), log: logger, readDeadline: deadline.New(), writeDeadline: deadline.New(), reading: make(chan struct{}, 1), handshakeRecv: make(chan recvHandshakeState), closed: closer.NewCloser(), cancelHandshaker: func() {}, cancelHandshakeReader: func() {}, replayProtectionWindow: uint(replayProtectionWindow), //nolint:gosec // G115 state: State{ isClient: isClient, }, } conn.setRemoteEpoch(0) conn.setLocalEpoch(0) return conn, nil } // Handshake runs the client or server DTLS handshake // protocol if it has not yet been run. // // Most uses of this package need not call Handshake explicitly: the // first [Conn.Read] or [Conn.Write] will call it automatically. // // For control over canceling or setting a timeout on a handshake, use // [Conn.HandshakeContext]. func (c *Conn) Handshake() error { return c.HandshakeContext(context.Background()) } // HandshakeContext runs the client or server DTLS handshake // protocol if it has not yet been run. // // The provided Context must be non-nil. If the context is canceled before // the handshake is complete, the handshake is interrupted and an error is returned. // Once the handshake has completed, cancellation of the context will not affect the // connection. 
// // Most uses of this package need not call HandshakeContext explicitly: the // first [Conn.Read] or [Conn.Write] will call it automatically. func (c *Conn) HandshakeContext(ctx context.Context) error { c.handshakeMutex.Lock() defer c.handshakeMutex.Unlock() if c.isHandshakeCompletedSuccessfully() { return nil } handshakeDone := make(chan struct{}) defer close(handshakeDone) c.closeLock.Lock() c.handshakeDone = handshakeDone c.closeLock.Unlock() // rfc5246#section-7.4.3 // In addition, the hash and signature algorithms MUST be compatible // with the key in the server's end-entity certificate. if !c.state.isClient { cert, err := c.handshakeConfig.getCertificate(&ClientHelloInfo{}) if err != nil && !errors.Is(err, errNoCertificates) { return err } c.handshakeConfig.localCipherSuites = filterCipherSuitesForCertificate(cert, c.handshakeConfig.localCipherSuites) } var initialFlight flightVal var initialFSMState handshakeState if c.handshakeConfig.resumeState != nil { //nolint:nestif if c.state.isClient { initialFlight = flight5 } else { initialFlight = flight6 } initialFSMState = handshakeFinished c.state = *c.handshakeConfig.resumeState } else { if c.state.isClient { initialFlight = flight1 } else { initialFlight = flight0 } initialFSMState = handshakePreparing } // Do handshake if err := c.handshake(ctx, c.handshakeConfig, initialFlight, initialFSMState); err != nil { return err } c.log.Trace("Handshake Completed") return nil } // Dial connects to the given network address and establishes a DTLS connection on top. func Dial(network string, rAddr *net.UDPAddr, config *Config) (*Conn, error) { // net.ListenUDP is used rather than net.DialUDP as the latter prevents the // use of net.PacketConn.WriteTo. 
// https://github.com/golang/go/blob/ce5e37ec21442c6eb13a43e68ca20129102ebac0/src/net/udpsock_posix.go#L115 pConn, err := net.ListenUDP(network, nil) if err != nil { return nil, err } return Client(pConn, rAddr, config) } // Client establishes a DTLS connection over an existing connection. func Client(conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) { switch { case config == nil: return nil, errNoConfigProvided case config.PSK != nil && config.PSKIdentityHint == nil: return nil, errPSKAndIdentityMustBeSetForClient } return createConn(conn, rAddr, config, true, nil) } // Server listens for incoming DTLS connections. func Server(conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) { if config == nil { return nil, errNoConfigProvided } if config.OnConnectionAttempt != nil { if err := config.OnConnectionAttempt(rAddr); err != nil { return nil, err } } return createConn(conn, rAddr, config, false, nil) } // Read reads data from the connection. func (c *Conn) Read(buff []byte) (n int, err error) { //nolint:cyclop if err := c.Handshake(); err != nil { return 0, err } select { case <-c.readDeadline.Done(): return 0, errDeadlineExceeded default: } for { select { case <-c.readDeadline.Done(): return 0, errDeadlineExceeded case out, ok := <-c.decrypted: if !ok { return 0, io.EOF } switch val := out.(type) { case ([]byte): if len(buff) < len(val) { return 0, errBufferTooSmall } copy(buff, val) return len(val), nil case (error): return 0, val } } } } // Write writes len(payload) bytes from payload to the DTLS connection. 
func (c *Conn) Write(payload []byte) (int, error) { if c.isConnectionClosed() { return 0, ErrConnClosed } select { case <-c.writeDeadline.Done(): return 0, errDeadlineExceeded default: } if err := c.Handshake(); err != nil { return 0, err } return len(payload), c.writePackets(c.writeDeadline, []*packet{ { record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Epoch: c.state.getLocalEpoch(), Version: protocol.Version1_2, }, Content: &protocol.ApplicationData{ Data: payload, }, }, shouldWrapCID: len(c.state.remoteConnectionID) > 0, shouldEncrypt: true, }, }) } // Close closes the connection. func (c *Conn) Close() error { err := c.close(true) //nolint:contextcheck c.closeLock.Lock() handshakeDone := c.handshakeDone c.closeLock.Unlock() if handshakeDone != nil { <-handshakeDone } return err } // ConnectionState returns basic DTLS details about the connection. // Note that this replaced the `Export` function of v1. func (c *Conn) ConnectionState() (State, bool) { c.lock.RLock() defer c.lock.RUnlock() stateClone, err := c.state.clone() if err != nil { return State{}, false } return *stateClone, true } // SelectedSRTPProtectionProfile returns the selected SRTPProtectionProfile. func (c *Conn) SelectedSRTPProtectionProfile() (SRTPProtectionProfile, bool) { profile := c.state.getSRTPProtectionProfile() if profile == 0 { return 0, false } return profile, true } // RemoteSRTPMasterKeyIdentifier returns the MasterKeyIdentifier value from the use_srtp. 
func (c *Conn) RemoteSRTPMasterKeyIdentifier() ([]byte, bool) { if profile := c.state.getSRTPProtectionProfile(); profile == 0 { return nil, false } return c.state.remoteSRTPMasterKeyIdentifier, true } func (c *Conn) writePackets(ctx context.Context, pkts []*packet) error { c.lock.Lock() defer c.lock.Unlock() var rawPackets [][]byte for _, pkt := range pkts { if dtlsHandshake, ok := pkt.record.Content.(*handshake.Handshake); ok { handshakeRaw, err := pkt.record.Marshal() if err != nil { return err } c.log.Tracef("[handshake:%v] -> %s (epoch: %d, seq: %d)", srvCliStr(c.state.isClient), dtlsHandshake.Header.Type.String(), pkt.record.Header.Epoch, dtlsHandshake.Header.MessageSequence) c.handshakeCache.push( handshakeRaw[recordlayer.FixedHeaderSize:], pkt.record.Header.Epoch, dtlsHandshake.Header.MessageSequence, dtlsHandshake.Header.Type, c.state.isClient, ) rawHandshakePackets, err := c.processHandshakePacket(pkt, dtlsHandshake) if err != nil { return err } rawPackets = append(rawPackets, rawHandshakePackets...) 
} else { rawPacket, err := c.processPacket(pkt) if err != nil { return err } rawPackets = append(rawPackets, rawPacket) } } if len(rawPackets) == 0 { return nil } compactedRawPackets := c.compactRawPackets(rawPackets) for _, compactedRawPackets := range compactedRawPackets { if _, err := c.nextConn.WriteToContext(ctx, compactedRawPackets, c.rAddr); err != nil { return netError(err) } } return nil } func (c *Conn) compactRawPackets(rawPackets [][]byte) [][]byte { // avoid a useless copy in the common case if len(rawPackets) == 1 { return rawPackets } combinedRawPackets := make([][]byte, 0) currentCombinedRawPacket := make([]byte, 0) for _, rawPacket := range rawPackets { if len(currentCombinedRawPacket) > 0 && len(currentCombinedRawPacket)+len(rawPacket) >= c.maximumTransmissionUnit { combinedRawPackets = append(combinedRawPackets, currentCombinedRawPacket) currentCombinedRawPacket = []byte{} } currentCombinedRawPacket = append(currentCombinedRawPacket, rawPacket...) } combinedRawPackets = append(combinedRawPackets, currentCombinedRawPacket) return combinedRawPackets } func (c *Conn) processPacket(pkt *packet) ([]byte, error) { //nolint:cyclop epoch := pkt.record.Header.Epoch for len(c.state.localSequenceNumber) <= int(epoch) { c.state.localSequenceNumber = append(c.state.localSequenceNumber, uint64(0)) } seq := atomic.AddUint64(&c.state.localSequenceNumber[epoch], 1) - 1 if seq > recordlayer.MaxSequenceNumber { // RFC 6347 Section 4.1.0 // The implementation must either abandon an association or rehandshake // prior to allowing the sequence number to wrap. return nil, errSequenceNumberOverflow } pkt.record.Header.SequenceNumber = seq var rawPacket []byte if pkt.shouldWrapCID { //nolint:nestif // Record must be marshaled to populate fields used in inner plaintext. 
if _, err := pkt.record.Marshal(); err != nil { return nil, err } content, err := pkt.record.Content.Marshal() if err != nil { return nil, err } inner := &recordlayer.InnerPlaintext{ Content: content, RealType: pkt.record.Header.ContentType, } rawInner, err := inner.Marshal() //nolint:govet if err != nil { return nil, err } cidHeader := &recordlayer.Header{ Version: pkt.record.Header.Version, ContentType: protocol.ContentTypeConnectionID, Epoch: pkt.record.Header.Epoch, ContentLen: uint16(len(rawInner)), //nolint:gosec //G115 ConnectionID: c.state.remoteConnectionID, SequenceNumber: pkt.record.Header.SequenceNumber, } rawPacket, err = cidHeader.Marshal() if err != nil { return nil, err } pkt.record.Header = *cidHeader rawPacket = append(rawPacket, rawInner...) } else { var err error rawPacket, err = pkt.record.Marshal() if err != nil { return nil, err } } if pkt.shouldEncrypt { var err error rawPacket, err = c.state.cipherSuite.Encrypt(pkt.record, rawPacket) if err != nil { return nil, err } } return rawPacket, nil } //nolint:cyclop func (c *Conn) processHandshakePacket(pkt *packet, dtlsHandshake *handshake.Handshake) ([][]byte, error) { rawPackets := make([][]byte, 0) handshakeFragments, err := c.fragmentHandshake(dtlsHandshake) if err != nil { return nil, err } epoch := pkt.record.Header.Epoch for len(c.state.localSequenceNumber) <= int(epoch) { c.state.localSequenceNumber = append(c.state.localSequenceNumber, uint64(0)) } for _, handshakeFragment := range handshakeFragments { seq := atomic.AddUint64(&c.state.localSequenceNumber[epoch], 1) - 1 if seq > recordlayer.MaxSequenceNumber { return nil, errSequenceNumberOverflow } var rawPacket []byte if pkt.shouldWrapCID { inner := &recordlayer.InnerPlaintext{ Content: handshakeFragment, RealType: protocol.ContentTypeHandshake, Zeros: c.paddingLengthGenerator(uint(len(handshakeFragment))), } rawInner, err := inner.Marshal() //nolint:govet if err != nil { return nil, err } cidHeader := &recordlayer.Header{ Version: 
pkt.record.Header.Version, ContentType: protocol.ContentTypeConnectionID, Epoch: pkt.record.Header.Epoch, ContentLen: uint16(len(rawInner)), //nolint:gosec //G115 ConnectionID: c.state.remoteConnectionID, SequenceNumber: pkt.record.Header.SequenceNumber, } rawPacket, err = cidHeader.Marshal() if err != nil { return nil, err } pkt.record.Header = *cidHeader rawPacket = append(rawPacket, rawInner...) } else { recordlayerHeader := &recordlayer.Header{ Version: pkt.record.Header.Version, ContentType: pkt.record.Header.ContentType, ContentLen: uint16(len(handshakeFragment)), //nolint:gosec // G115 Epoch: pkt.record.Header.Epoch, SequenceNumber: seq, } rawPacket, err = recordlayerHeader.Marshal() if err != nil { return nil, err } pkt.record.Header = *recordlayerHeader rawPacket = append(rawPacket, handshakeFragment...) } if pkt.shouldEncrypt { var err error rawPacket, err = c.state.cipherSuite.Encrypt(pkt.record, rawPacket) if err != nil { return nil, err } } rawPackets = append(rawPackets, rawPacket) } return rawPackets, nil } func (c *Conn) fragmentHandshake(dtlsHandshake *handshake.Handshake) ([][]byte, error) { content, err := dtlsHandshake.Message.Marshal() if err != nil { return nil, err } fragmentedHandshakes := make([][]byte, 0) contentFragments := splitBytes(content, c.maximumTransmissionUnit) if len(contentFragments) == 0 { contentFragments = [][]byte{ {}, } } offset := 0 for _, contentFragment := range contentFragments { contentFragmentLen := len(contentFragment) headerFragment := &handshake.Header{ Type: dtlsHandshake.Header.Type, Length: dtlsHandshake.Header.Length, MessageSequence: dtlsHandshake.Header.MessageSequence, FragmentOffset: uint32(offset), //nolint:gosec // G115 FragmentLength: uint32(contentFragmentLen), //nolint:gosec // G115 } offset += contentFragmentLen fragmentedHandshake, err := headerFragment.Marshal() if err != nil { return nil, err } fragmentedHandshake = append(fragmentedHandshake, contentFragment...) 
fragmentedHandshakes = append(fragmentedHandshakes, fragmentedHandshake) } return fragmentedHandshakes, nil } var poolReadBuffer = sync.Pool{ //nolint:gochecknoglobals New: func() any { b := make([]byte, inboundBufferSize) return &b }, } func (c *Conn) readAndBuffer(ctx context.Context) error { //nolint:cyclop bufptr, ok := poolReadBuffer.Get().(*[]byte) if !ok { return errFailedToAccessPoolReadBuffer } defer poolReadBuffer.Put(bufptr) b := *bufptr i, rAddr, err := c.nextConn.ReadFromContext(ctx, b) if err != nil { return netError(err) } pkts, err := recordlayer.ContentAwareUnpackDatagram(b[:i], len(c.state.getLocalConnectionID())) if err != nil { return err } var hasHandshake, isRetransmit bool for _, p := range pkts { hs, rtx, alert, err := c.handleIncomingPacket(ctx, p, rAddr, true) if alert != nil { if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil { if err == nil { err = alertErr } } } var e *alertError if errors.As(err, &e) && e.IsFatalOrCloseNotify() { return e } if err != nil { return err } if hs { hasHandshake = true } if rtx { isRetransmit = true } } if hasHandshake { s := recvHandshakeState{ done: make(chan struct{}), isRetransmit: isRetransmit, } select { case c.handshakeRecv <- s: // If the other party may retransmit the flight, // we should respond even if it not a new message. 
<-s.done case <-c.fsm.Done(): } } return nil } func (c *Conn) handleQueuedPackets(ctx context.Context) error { pkts := c.encryptedPackets c.encryptedPackets = nil for _, p := range pkts { _, _, alert, err := c.handleIncomingPacket(ctx, p.data, p.rAddr, false) // don't re-enqueue if alert != nil { if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil { if err == nil { err = alertErr } } } var e *alertError if errors.As(err, &e) && e.IsFatalOrCloseNotify() { return e } if err != nil { return err } } return nil } func (c *Conn) enqueueEncryptedPackets(packet addrPkt) bool { if len(c.encryptedPackets) < maxAppDataPacketQueueSize { c.encryptedPackets = append(c.encryptedPackets, packet) return true } return false } //nolint:gocognit,gocyclo,cyclop,maintidx func (c *Conn) handleIncomingPacket( ctx context.Context, buf []byte, rAddr net.Addr, enqueue bool, ) (bool, bool, *alert.Alert, error) { header := &recordlayer.Header{} // Set connection ID size so that records of content type tls12_cid will // be parsed correctly. 
if len(c.state.getLocalConnectionID()) > 0 { header.ConnectionID = make([]byte, len(c.state.getLocalConnectionID())) } if err := header.Unmarshal(buf); err != nil { // Decode error must be silently discarded // [RFC6347 Section-4.1.2.7] c.log.Debugf("discarded broken packet: %v", err) return false, false, nil, nil } // Validate epoch remoteEpoch := c.state.getRemoteEpoch() if header.Epoch > remoteEpoch { if header.Epoch > remoteEpoch+1 { c.log.Debugf("discarded future packet (epoch: %d, seq: %d)", header.Epoch, header.SequenceNumber, ) return false, false, nil, nil } if enqueue { if ok := c.enqueueEncryptedPackets(addrPkt{rAddr, buf}); ok { c.log.Debug("received packet of next epoch, queuing packet") } } return false, false, nil, nil } // Anti-replay protection for len(c.state.replayDetector) <= int(header.Epoch) { c.state.replayDetector = append(c.state.replayDetector, replaydetector.New(c.replayProtectionWindow, recordlayer.MaxSequenceNumber), ) } markPacketAsValid, ok := c.state.replayDetector[int(header.Epoch)].Check(header.SequenceNumber) if !ok { c.log.Debugf("discarded duplicated packet (epoch: %d, seq: %d)", header.Epoch, header.SequenceNumber, ) return false, false, nil, nil } // originalCID indicates whether the original record had content type // Connection ID. originalCID := false // Decrypt if header.Epoch != 0 { //nolint:nestif if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() { if enqueue { if ok := c.enqueueEncryptedPackets(addrPkt{rAddr, buf}); ok { c.log.Debug("handshake not finished, queuing packet") } } return false, false, nil, nil } // If a connection identifier had been negotiated and encryption is // enabled, the connection identifier MUST be sent. 
if len(c.state.getLocalConnectionID()) > 0 && header.ContentType != protocol.ContentTypeConnectionID { c.log.Debug("discarded packet missing connection ID after value negotiated") return false, false, nil, nil } var err error var hdr recordlayer.Header if header.ContentType == protocol.ContentTypeConnectionID { hdr.ConnectionID = make([]byte, len(c.state.getLocalConnectionID())) } buf, err = c.state.cipherSuite.Decrypt(hdr, buf) if err != nil { c.log.Debugf("%s: decrypt failed: %s", srvCliStr(c.state.isClient), err) return false, false, nil, nil } // If this is a connection ID record, make it look like a normal record for // further processing. if header.ContentType == protocol.ContentTypeConnectionID { originalCID = true ip := &recordlayer.InnerPlaintext{} if err := ip.Unmarshal(buf[header.Size():]); err != nil { //nolint:govet c.log.Debugf("unpacking inner plaintext failed: %s", err) return false, false, nil, nil } unpacked := &recordlayer.Header{ ContentType: ip.RealType, ContentLen: uint16(len(ip.Content)), //nolint:gosec // G115 Version: header.Version, Epoch: header.Epoch, SequenceNumber: header.SequenceNumber, } buf, err = unpacked.Marshal() if err != nil { c.log.Debugf("converting CID record to inner plaintext failed: %s", err) return false, false, nil, nil } buf = append(buf, ip.Content...) } // If connection ID does not match discard the packet. 
if !bytes.Equal(c.state.getLocalConnectionID(), header.ConnectionID) { c.log.Debug("unexpected connection ID") return false, false, nil, nil } } isHandshake, isRetransmit, err := c.fragmentBuffer.push(append([]byte{}, buf...)) if err != nil { // Decode error must be silently discarded // [RFC6347 Section-4.1.2.7] c.log.Debugf("defragment failed: %s", err) return false, false, nil, nil } else if isHandshake { markPacketAsValid() for out, epoch := c.fragmentBuffer.pop(); out != nil; out, epoch = c.fragmentBuffer.pop() { header := &handshake.Header{} if err := header.Unmarshal(out); err != nil { c.log.Debugf("%s: handshake parse failed: %s", srvCliStr(c.state.isClient), err) continue } c.handshakeCache.push(out, epoch, header.MessageSequence, header.Type, !c.state.isClient) } return true, isRetransmit, nil, nil } r := &recordlayer.RecordLayer{} if err := r.Unmarshal(buf); err != nil { return false, false, &alert.Alert{Level: alert.Fatal, Description: alert.DecodeError}, err } isLatestSeqNum := false switch content := r.Content.(type) { case *alert.Alert: c.log.Tracef("%s: <- %s", srvCliStr(c.state.isClient), content.String()) var a *alert.Alert if content.Description == alert.CloseNotify { // Respond with a close_notify [RFC5246 Section 7.2.1] a = &alert.Alert{Level: alert.Warning, Description: alert.CloseNotify} } _ = markPacketAsValid() return false, false, a, &alertError{content} case *protocol.ChangeCipherSpec: if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() { if enqueue { if ok := c.enqueueEncryptedPackets(addrPkt{rAddr, buf}); ok { c.log.Debugf("CipherSuite not initialized, queuing packet") } } return false, false, nil, nil } newRemoteEpoch := header.Epoch + 1 c.log.Tracef("%s: <- ChangeCipherSpec (epoch: %d)", srvCliStr(c.state.isClient), newRemoteEpoch) if c.state.getRemoteEpoch()+1 == newRemoteEpoch { c.setRemoteEpoch(newRemoteEpoch) isLatestSeqNum = markPacketAsValid() } case *protocol.ApplicationData: if header.Epoch == 0 { return 
false, false, &alert.Alert{ Level: alert.Fatal, Description: alert.UnexpectedMessage, }, errApplicationDataEpochZero } isLatestSeqNum = markPacketAsValid() select { case c.decrypted <- content.Data: case <-c.closed.Done(): case <-ctx.Done(): } default: return false, false, &alert.Alert{ Level: alert.Fatal, Description: alert.UnexpectedMessage, }, fmt.Errorf("%w: %d", errUnhandledContextType, content.ContentType()) } // Any valid connection ID record is a candidate for updating the remote // address if it is the latest record received. // https://datatracker.ietf.org/doc/html/rfc9146#peer-address-update if originalCID && isLatestSeqNum { if rAddr != c.RemoteAddr() { c.lock.Lock() c.rAddr = rAddr c.lock.Unlock() } } return false, false, nil, nil } func (c *Conn) recvHandshake() <-chan recvHandshakeState { return c.handshakeRecv } func (c *Conn) notify(ctx context.Context, level alert.Level, desc alert.Description) error { if level == alert.Fatal && len(c.state.SessionID) > 0 { // According to the RFC, we need to delete the stored session. 
// https://datatracker.ietf.org/doc/html/rfc5246#section-7.2 if ss := c.fsm.cfg.sessionStore; ss != nil { c.log.Tracef("clean invalid session: %s", c.state.SessionID) if err := ss.Del(c.sessionKey()); err != nil { return err } } } return c.writePackets(ctx, []*packet{ { record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Epoch: c.state.getLocalEpoch(), Version: protocol.Version1_2, }, Content: &alert.Alert{ Level: level, Description: desc, }, }, shouldWrapCID: len(c.state.remoteConnectionID) > 0, shouldEncrypt: c.isHandshakeCompletedSuccessfully(), }, }) } func (c *Conn) setHandshakeCompletedSuccessfully() bool { return c.handshakeCompletedSuccessfully.CompareAndSwap(false, true) } func (c *Conn) isHandshakeCompletedSuccessfully() bool { return c.handshakeCompletedSuccessfully.Load() } //nolint:cyclop,gocognit,contextcheck func (c *Conn) handshake( ctx context.Context, cfg *handshakeConfig, initialFlight flightVal, initialState handshakeState, ) error { c.fsm = newHandshakeFSM(&c.state, c.handshakeCache, cfg, initialFlight) done := make(chan struct{}) ctxRead, cancelRead := context.WithCancel(context.Background()) cfg.onFlightState = func(_ flightVal, s handshakeState) { if s == handshakeFinished && c.setHandshakeCompletedSuccessfully() { close(done) } } ctxHs, cancel := context.WithCancel(context.Background()) c.closeLock.Lock() c.cancelHandshaker = cancel c.cancelHandshakeReader = cancelRead c.closeLock.Unlock() firstErr := make(chan error, 1) var handshakeLoopsFinished sync.WaitGroup handshakeLoopsFinished.Add(2) // Handshake routine should be live until close. // The other party may request retransmission of the last flight to cope with packet drop. go func() { defer handshakeLoopsFinished.Done() err := c.fsm.Run(ctxHs, c, initialState) if !errors.Is(err, context.Canceled) { select { case firstErr <- err: default: } } }() go func() { defer func() { if c.isHandshakeCompletedSuccessfully() { // Escaping read loop. 
// It's safe to close decrypted channnel now. close(c.decrypted) } // Force stop handshaker when the underlying connection is closed. cancel() }() defer handshakeLoopsFinished.Done() for { if err := c.readAndBuffer(ctxRead); err != nil { //nolint:nestif var alertErr *alertError if errors.As(err, &alertErr) { if !alertErr.IsFatalOrCloseNotify() { if c.isHandshakeCompletedSuccessfully() { // Pass the error to Read() select { case c.decrypted <- err: case <-c.closed.Done(): case <-ctxRead.Done(): } } continue // non-fatal alert must not stop read loop } } else { switch { case errors.Is(err, context.DeadlineExceeded), errors.Is(err, context.Canceled), errors.Is(err, io.EOF), errors.Is(err, net.ErrClosed): case errors.Is(err, recordlayer.ErrInvalidPacketLength): // Decode error must be silently discarded // [RFC6347 Section-4.1.2.7] continue default: if c.isHandshakeCompletedSuccessfully() { // Keep read loop and pass the read error to Read() select { case c.decrypted <- err: case <-c.closed.Done(): case <-ctxRead.Done(): } continue // non-fatal alert must not stop read loop } } } select { case firstErr <- err: default: } if alertErr != nil { if alertErr.IsFatalOrCloseNotify() { _ = c.close(false) //nolint:contextcheck } } if !c.isConnectionClosed() && errors.Is(err, context.Canceled) { c.log.Trace("handshake timeouts - closing underline connection") _ = c.close(false) //nolint:contextcheck } return } } }() select { case err := <-firstErr: cancelRead() cancel() handshakeLoopsFinished.Wait() return c.translateHandshakeCtxError(err) case <-ctx.Done(): cancelRead() cancel() handshakeLoopsFinished.Wait() return c.translateHandshakeCtxError(ctx.Err()) case <-done: return nil } } func (c *Conn) translateHandshakeCtxError(err error) error { if err == nil { return nil } if errors.Is(err, context.Canceled) && c.isHandshakeCompletedSuccessfully() { return nil } return &HandshakeError{Err: err} } func (c *Conn) close(byUser bool) error { c.closeLock.Lock() cancelHandshaker := 
c.cancelHandshaker cancelHandshakeReader := c.cancelHandshakeReader c.closeLock.Unlock() cancelHandshaker() cancelHandshakeReader() if c.isHandshakeCompletedSuccessfully() && byUser { // Discard error from notify() to return non-error on the first user call of Close() // even if the underlying connection is already closed. _ = c.notify(context.Background(), alert.Warning, alert.CloseNotify) } c.closeLock.Lock() // Don't return ErrConnClosed at the first time of the call from user. closedByUser := c.connectionClosedByUser if byUser { c.connectionClosedByUser = true } isClosed := c.isConnectionClosed() c.closed.Close() c.closeLock.Unlock() if closedByUser { return ErrConnClosed } if isClosed { return nil } return c.nextConn.Close() } func (c *Conn) isConnectionClosed() bool { select { case <-c.closed.Done(): return true default: return false } } func (c *Conn) setLocalEpoch(epoch uint16) { c.state.localEpoch.Store(epoch) } func (c *Conn) setRemoteEpoch(epoch uint16) { c.state.remoteEpoch.Store(epoch) } // LocalAddr implements net.Conn.LocalAddr. func (c *Conn) LocalAddr() net.Addr { return c.nextConn.LocalAddr() } // RemoteAddr implements net.Conn.RemoteAddr. func (c *Conn) RemoteAddr() net.Addr { c.lock.RLock() defer c.lock.RUnlock() return c.rAddr } func (c *Conn) sessionKey() []byte { if c.state.isClient { // As ServerName can be like 0.example.com, it's better to add // delimiter character which is not allowed to be in // neither address or domain name. return []byte(c.rAddr.String() + "_" + c.fsm.cfg.serverName) } return c.state.SessionID } // SetDeadline implements net.Conn.SetDeadline. func (c *Conn) SetDeadline(t time.Time) error { c.readDeadline.Set(t) return c.SetWriteDeadline(t) } // SetReadDeadline implements net.Conn.SetReadDeadline. func (c *Conn) SetReadDeadline(t time.Time) error { c.readDeadline.Set(t) // Read deadline is fully managed by this layer. // Don't set read deadline to underlying connection. 
return nil } // SetWriteDeadline implements net.Conn.SetWriteDeadline. func (c *Conn) SetWriteDeadline(t time.Time) error { c.writeDeadline.Set(t) // Write deadline is also fully managed by this layer. return nil } golang-github-pion-dtls-v3-3.0.7/conn_go_test.go000066400000000000000000000066311507057460300215320ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT //go:build !js // +build !js package dtls import ( "context" "crypto/tls" "errors" "net" "testing" "time" "github.com/pion/dtls/v3/pkg/crypto/selfsign" dtlsnet "github.com/pion/dtls/v3/pkg/net" "github.com/pion/transport/v3/dpipe" "github.com/pion/transport/v3/test" "github.com/stretchr/testify/assert" ) func TestContextConfig(t *testing.T) { //nolint:cyclop // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() report := test.CheckRoutines(t) defer report() addrListen, err := net.ResolveUDPAddr("udp", "localhost:0") assert.NoError(t, err) // Dummy listener listen, err := net.ListenUDP("udp", addrListen) assert.NoError(t, err) defer func() { _ = listen.Close() }() addr, ok := listen.LocalAddr().(*net.UDPAddr) assert.True(t, ok) cert, err := selfsign.GenerateSelfSigned() assert.NoError(t, err) config := &Config{ Certificates: []tls.Certificate{cert}, } dials := map[string]struct { f func() (func() (net.Conn, error), func()) order []byte }{ "Dial": { f: func() (func() (net.Conn, error), func()) { ctx, cancel := context.WithTimeout(context.Background(), 40*time.Millisecond) return func() (net.Conn, error) { conn, err := Dial("udp", addr, config) if err != nil { return nil, err } return conn, conn.HandshakeContext(ctx) }, func() { cancel() } }, order: []byte{0, 1, 2}, }, "Client": { f: func() (func() (net.Conn, error), func()) { ca, _ := dpipe.Pipe() ctx, cancel := context.WithTimeout(context.Background(), 40*time.Millisecond) return func() (net.Conn, error) { conn, err := Client(dtlsnet.PacketConnFromConn(ca), 
ca.RemoteAddr(), config) if err != nil { return nil, err } return conn, conn.HandshakeContext(ctx) }, func() { _ = ca.Close() cancel() } }, order: []byte{0, 1, 2}, }, "Server": { f: func() (func() (net.Conn, error), func()) { ca, _ := dpipe.Pipe() ctx, cancel := context.WithTimeout(context.Background(), 40*time.Millisecond) return func() (net.Conn, error) { conn, err := Server(dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), config) if err != nil { return nil, err } return conn, conn.HandshakeContext(ctx) }, func() { _ = ca.Close() cancel() } }, order: []byte{0, 1, 2}, }, } for name, dial := range dials { dial := dial t.Run(name, func(t *testing.T) { done := make(chan struct{}) go func() { d, cancel := dial.f() conn, err := d() defer cancel() var netError net.Error if !errors.As(err, &netError) || !netError.Temporary() { //nolint:staticcheck assert.Fail(t, "Dial failed with unexpected error", "err: %v", err) close(done) return } done <- struct{}{} if err == nil { _ = conn.Close() } }() var order []byte early := time.After(20 * time.Millisecond) late := time.After(60 * time.Millisecond) func() { for len(order) < 3 { select { case <-early: order = append(order, 0) case _, ok := <-done: if !ok { return } order = append(order, 1) case <-late: order = append(order, 2) } } }() assert.Equal(t, dial.order, order, "Invalid cancel timing") }) } } golang-github-pion-dtls-v3-3.0.7/conn_test.go000066400000000000000000002734151507057460300210530ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "bytes" "context" "crypto" "crypto/ecdsa" cryptoElliptic "crypto/elliptic" "crypto/rand" "crypto/rsa" "crypto/tls" "crypto/x509" "encoding/hex" "errors" "fmt" "io" "net" "strings" "sync" "sync/atomic" "testing" "time" "github.com/pion/dtls/v3/internal/ciphersuite" "github.com/pion/dtls/v3/pkg/crypto/elliptic" "github.com/pion/dtls/v3/pkg/crypto/hash" "github.com/pion/dtls/v3/pkg/crypto/selfsign" 
"github.com/pion/dtls/v3/pkg/crypto/signature" "github.com/pion/dtls/v3/pkg/crypto/signaturehash" dtlsnet "github.com/pion/dtls/v3/pkg/net" "github.com/pion/dtls/v3/pkg/protocol" "github.com/pion/dtls/v3/pkg/protocol/alert" "github.com/pion/dtls/v3/pkg/protocol/extension" "github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/pion/dtls/v3/pkg/protocol/recordlayer" "github.com/pion/logging" "github.com/pion/transport/v3/dpipe" "github.com/pion/transport/v3/test" "github.com/stretchr/testify/assert" ) var ( errTestPSKInvalidIdentity = errors.New("TestPSK: Server got invalid identity") errPSKRejected = errors.New("PSK Rejected") errNotExpectedChain = errors.New("not expected chain") errExpecedChain = errors.New("expected chain") errWrongCert = errors.New("wrong cert") ) func TestStressDuplex(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() // Run the test stressDuplex(t) } func stressDuplex(t *testing.T) { t.Helper() ca, cb, err := pipeMemory() assert.NoError(t, err) defer func() { assert.NoError(t, ca.Close()) assert.NoError(t, cb.Close()) }() opt := test.Options{ MsgSize: 2048, MsgCount: 100, } assert.NoError(t, test.StressDuplex(ca, cb, opt)) } func TestRoutineLeakOnClose(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(5 * time.Second) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() ca, cb, err := pipeMemory() assert.NoError(t, err) _, err = ca.Write(make([]byte, 100)) assert.NoError(t, err) assert.NoError(t, cb.Close()) assert.NoError(t, ca.Close()) // Packet is sent, but not read. // inboundLoop routine should not be leaked. 
} func TestReadWriteDeadline(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(5 * time.Second) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() var netErr net.Error ca, cb, err := pipeMemory() assert.NoError(t, err) assert.NoError(t, ca.SetDeadline(time.Unix(0, 1))) _, werr := ca.Write(make([]byte, 100)) assert.ErrorAs(t, werr, &netErr, "Write must return net.Error") assert.True(t, netErr.Timeout(), "Deadline exceeded Write must return Timeout") assert.True(t, netErr.Temporary(), "Deadline exceeded Write must return Temporary") //nolint:staticcheck _, rerr := ca.Read(make([]byte, 100)) assert.ErrorAs(t, rerr, &netErr, "Read must return net.Error") assert.True(t, netErr.Timeout(), "Deadline exceeded Read must return Timeout") assert.True(t, netErr.Temporary(), "Deadline exceeded Read must return Temporary") //nolint:staticcheck assert.NoError(t, ca.SetDeadline(time.Time{})) assert.NoError(t, ca.Close()) assert.NoError(t, cb.Close()) _, err = ca.Write(make([]byte, 100)) assert.ErrorIs(t, err, ErrConnClosed) _, err = ca.Read(make([]byte, 100)) assert.ErrorIs(t, err, io.EOF) } func TestSequenceNumberOverflow(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(5 * time.Second) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() t.Run("ApplicationData", func(t *testing.T) { ca, cb, err := pipeMemory() assert.NoError(t, err) atomic.StoreUint64(&ca.state.localSequenceNumber[1], recordlayer.MaxSequenceNumber) _, werr := ca.Write(make([]byte, 100)) assert.NoError(t, werr, "Write must send message with maximum sequence number") _, werr = ca.Write(make([]byte, 100)) assert.ErrorIs(t, werr, errSequenceNumberOverflow, "Write must abandonsend message with maximum sequence number") assert.NoError(t, ca.Close()) assert.NoError(t, cb.Close()) }) t.Run("Handshake", func(t *testing.T) { ca, cb, err := pipeMemory() assert.NoError(t, err) ctx, cancel := 
context.WithTimeout(context.Background(), time.Second) defer cancel() atomic.StoreUint64(&ca.state.localSequenceNumber[0], recordlayer.MaxSequenceNumber+1) // Try to send handshake packet. werr := ca.writePackets(ctx, []*packet{ { record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, }, Content: &handshake.Handshake{ Message: &handshake.MessageClientHello{ Version: protocol.Version1_2, Cookie: make([]byte, 64), CipherSuiteIDs: cipherSuiteIDs(defaultCipherSuites()), CompressionMethods: defaultCompressionMethods(), }, }, }, }, }) assert.ErrorIs(t, werr, errSequenceNumberOverflow, "Connection must fail when handshake packet reaches maximum sequence num") assert.NoError(t, ca.Close()) assert.NoError(t, cb.Close()) }) } func pipeMemory() (*Conn, *Conn, error) { // In memory pipe ca, cb := dpipe.Pipe() return pipeConn(ca, cb) } func pipeConn(ca, cb net.Conn) (*Conn, *Conn, error) { type result struct { c *Conn err error } resultCh := make(chan result) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() // Setup client go func() { client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{ SRTPProtectionProfiles: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}, }, true) resultCh <- result{client, err} }() // Setup server server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ SRTPProtectionProfiles: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}, }, true) if err != nil { return nil, nil, err } // Receive client res := <-resultCh if res.err != nil { _ = server.Close() return nil, nil, res.err } return res.c, server, nil } func testClient( ctx context.Context, pktConn net.PacketConn, rAddr net.Addr, cfg *Config, generateCertificate bool, ) (*Conn, error) { if generateCertificate { clientCert, err := selfsign.GenerateSelfSigned() if err != nil { return nil, err } cfg.Certificates = []tls.Certificate{clientCert} } 
cfg.InsecureSkipVerify = true conn, err := Client(pktConn, rAddr, cfg) if err != nil { return nil, err } return conn, conn.HandshakeContext(ctx) } func testServer( ctx context.Context, c net.PacketConn, rAddr net.Addr, cfg *Config, generateCertificate bool, ) (*Conn, error) { if generateCertificate { serverCert, err := selfsign.GenerateSelfSigned() if err != nil { return nil, err } cfg.Certificates = []tls.Certificate{serverCert} } conn, err := Server(c, rAddr, cfg) if err != nil { return nil, err } return conn, conn.HandshakeContext(ctx) } func sendClientHello(cookie []byte, ca net.Conn, sequenceNumber uint64, extensions []extension.Extension) error { packet, err := (&recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, SequenceNumber: sequenceNumber, }, Content: &handshake.Handshake{ Header: handshake.Header{ MessageSequence: uint16(sequenceNumber), //nolint:gosec // G115 }, Message: &handshake.MessageClientHello{ Version: protocol.Version1_2, Cookie: cookie, CipherSuiteIDs: cipherSuiteIDs(defaultCipherSuites()), CompressionMethods: defaultCompressionMethods(), Extensions: extensions, }, }, }).Marshal() if err != nil { return err } if _, err = ca.Write(packet); err != nil { return err } return nil } func TestHandshakeWithAlert(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() cases := map[string]struct { configServer, configClient *Config errServer, errClient error }{ "CipherSuiteNoIntersection": { configServer: &Config{ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, }, configClient: &Config{ CipherSuites: []CipherSuiteID{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}, }, errServer: errCipherSuiteNoIntersection, errClient: &alertError{&alert.Alert{Level: alert.Fatal, Description: 
alert.InsufficientSecurity}}, }, "SignatureSchemesNoIntersection": { configServer: &Config{ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, SignatureSchemes: []tls.SignatureScheme{tls.ECDSAWithP256AndSHA256}, }, configClient: &Config{ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, SignatureSchemes: []tls.SignatureScheme{tls.ECDSAWithP521AndSHA512}, }, errServer: &alertError{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}}, errClient: errNoAvailableSignatureSchemes, }, } for name, testCase := range cases { testCase := testCase t.Run(name, func(t *testing.T) { clientErr := make(chan error, 1) ca, cb := dpipe.Pipe() go func() { _, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), testCase.configClient, true) clientErr <- err }() _, errServer := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), testCase.configServer, true) assert.ErrorIs(t, errServer, testCase.errServer) assert.ErrorIs(t, <-clientErr, testCase.errClient) }) } } func TestHandshakeWithInvalidRecord(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() type result struct { c *Conn err error } clientErr := make(chan result, 1) ca, cb := dpipe.Pipe() caWithInvalidRecord := &connWithCallback{Conn: ca} var msgSeq atomic.Int32 // Send invalid record after first message caWithInvalidRecord.onWrite = func([]byte) { if msgSeq.Add(1) == 2 { _, err := ca.Write([]byte{0x01, 0x02}) assert.NoError(t, err) } } go func() { client, err := testClient( ctx, dtlsnet.PacketConnFromConn(caWithInvalidRecord), caWithInvalidRecord.RemoteAddr(), &Config{CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}}, true, ) clientErr <- result{client, err} }() server, errServer := 
testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, }, true) errClient := <-clientErr defer func() { if server != nil { assert.NoError(t, server.Close()) } if errClient.c != nil { assert.NoError(t, errClient.c.Close()) } }() assert.NoError(t, errServer) assert.NoError(t, errClient.err) } func TestExportKeyingMaterial(t *testing.T) { // Check for leaking routines report := test.CheckRoutines(t) defer report() var rand [28]byte exportLabel := "EXTRACTOR-dtls_srtp" expectedServerKey := []byte{0x61, 0x09, 0x9d, 0x7d, 0xcb, 0x08, 0x52, 0x2c, 0xe7, 0x7b} expectedClientKey := []byte{0x87, 0xf0, 0x40, 0x02, 0xf6, 0x1c, 0xf1, 0xfe, 0x8c, 0x77} conn := &Conn{ state: State{ localRandom: handshake.Random{GMTUnixTime: time.Unix(500, 0), RandomBytes: rand}, remoteRandom: handshake.Random{GMTUnixTime: time.Unix(1000, 0), RandomBytes: rand}, localSequenceNumber: []uint64{0, 0}, cipherSuite: &ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{}, }, } conn.setLocalEpoch(0) conn.setRemoteEpoch(0) state, ok := conn.ConnectionState() assert.True(t, ok) _, err := state.ExportKeyingMaterial(exportLabel, nil, 0) assert.ErrorIs(t, err, errHandshakeInProgress, "ExportKeyingMaterial when epoch == 0 error mismatch") conn.setLocalEpoch(1) state, ok = conn.ConnectionState() assert.True(t, ok) _, err = state.ExportKeyingMaterial(exportLabel, []byte{0x00}, 0) assert.ErrorIs(t, err, errContextUnsupported, "ExportKeyingMaterial with context mismatch") for k := range invalidKeyingLabels() { state, ok = conn.ConnectionState() assert.True(t, ok) _, err = state.ExportKeyingMaterial(k, nil, 0) assert.ErrorIs(t, err, errReservedExportKeyingMaterial, "ExportKeyingMaterial reserved label mismatch") } state, ok = conn.ConnectionState() assert.True(t, ok) keyingMaterial, err := state.ExportKeyingMaterial(exportLabel, nil, 10) assert.NoError(t, err, "ExportingKeyingMaterial as server error") assert.Equal(t, 
expectedServerKey, keyingMaterial, "ExportKeyingMaterial client export mismatch") conn.state.isClient = true state, ok = conn.ConnectionState() assert.True(t, ok) keyingMaterial, err = state.ExportKeyingMaterial(exportLabel, nil, 10) assert.NoError(t, err) assert.Equal(t, expectedClientKey, keyingMaterial, "ExportKeyingMaterial client report mismatch") } func TestPSK(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() for _, test := range []struct { Name string ClientIdentity []byte ServerIdentity []byte CipherSuites []CipherSuiteID ClientVerifyConnection func(*State) error ServerVerifyConnection func(*State) error WantFail bool ExpectedServerErr string ExpectedClientErr string }{ { Name: "Server identity specified", ServerIdentity: []byte("Test Identity"), ClientIdentity: []byte("Client Identity"), CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8}, }, { Name: "Server identity specified - Server verify connection fails", ServerIdentity: []byte("Test Identity"), ClientIdentity: []byte("Client Identity"), CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8}, ServerVerifyConnection: func(*State) error { return errExample }, WantFail: true, ExpectedServerErr: errExample.Error(), ExpectedClientErr: alert.BadCertificate.String(), }, { Name: "Server identity specified - Client verify connection fails", ServerIdentity: []byte("Test Identity"), ClientIdentity: []byte("Client Identity"), CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8}, ClientVerifyConnection: func(*State) error { return errExample }, WantFail: true, ExpectedServerErr: alert.BadCertificate.String(), ExpectedClientErr: errExample.Error(), }, { Name: "Server identity nil", ServerIdentity: nil, ClientIdentity: []byte("Client Identity"), CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8}, }, { Name: "TLS_PSK_WITH_AES_128_CBC_SHA256", ServerIdentity: 
nil, ClientIdentity: []byte("Client Identity"), CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CBC_SHA256}, }, { Name: "TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256", ServerIdentity: nil, ClientIdentity: []byte("Client Identity"), CipherSuites: []CipherSuiteID{TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256}, }, { Name: "Client identity empty", ServerIdentity: nil, ClientIdentity: []byte{}, CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8}, }, } { test := test t.Run(test.Name, func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() type result struct { c *Conn err error } clientRes := make(chan result, 1) ca, cb := dpipe.Pipe() go func() { conf := &Config{ PSK: func(hint []byte) ([]byte, error) { if !bytes.Equal(test.ServerIdentity, hint) { return nil, fmt.Errorf( //nolint:goerr113 "TestPSK: Client got invalid identity expected(% 02x) actual(% 02x)", test.ServerIdentity, hint, ) } return []byte{0xAB, 0xC1, 0x23}, nil }, PSKIdentityHint: test.ClientIdentity, CipherSuites: test.CipherSuites, VerifyConnection: test.ClientVerifyConnection, } c, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), conf, false) clientRes <- result{c, err} }() config := &Config{ PSK: func(hint []byte) ([]byte, error) { t.Log(hint) if !bytes.Equal(test.ClientIdentity, hint) { return nil, fmt.Errorf("%w: expected(% 02x) actual(% 02x)", errTestPSKInvalidIdentity, test.ClientIdentity, hint) } return []byte{0xAB, 0xC1, 0x23}, nil }, PSKIdentityHint: test.ServerIdentity, CipherSuites: test.CipherSuites, VerifyConnection: test.ServerVerifyConnection, } server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), config, false) if test.WantFail { res := <-clientRes assert.Error(t, err) assert.True(t, strings.Contains(err.Error(), test.ExpectedServerErr), "TestPSK: Server expected error mismatch") assert.Error(t, res.err, "TestPSK: Client expected error mismatch") assert.True(t, strings.Contains(res.err.Error(), 
test.ExpectedClientErr), "TestPSK: Client expeected error mismatch") return } assert.NoError(t, err) state, ok := server.ConnectionState() assert.True(t, ok, "TestPSK: Server ConnectionState failed") actualPSKIdentityHint := state.IdentityHint assert.Equal(t, test.ClientIdentity, actualPSKIdentityHint, "TestPSK: Server ClientPSKIdentity Mismatch") defer func() { _ = server.Close() }() res := <-clientRes assert.NoError(t, res.err) assert.NoError(t, res.c.Close()) }) } } func TestPSKHintFail(t *testing.T) { // Check for leaking routines report := test.CheckRoutines(t) defer report() serverAlertError := &alertError{&alert.Alert{Level: alert.Fatal, Description: alert.InternalError}} pskRejected := errPSKRejected // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() clientErr := make(chan error, 1) ca, cb := dpipe.Pipe() go func() { conf := &Config{ PSK: func([]byte) ([]byte, error) { return nil, pskRejected }, PSKIdentityHint: []byte{}, CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8}, } _, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), conf, false) clientErr <- err }() config := &Config{ PSK: func([]byte) ([]byte, error) { return nil, pskRejected }, PSKIdentityHint: []byte{}, CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8}, } _, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), config, false) assert.ErrorIs(t, err, serverAlertError, "TestPSK: Server should fail with alert error") assert.ErrorIs(t, <-clientErr, pskRejected, "TestPSK: Client should fail with pskRejected error") } // Assert that ServerKeyExchange is only sent if Identity is set on server side. 
func TestPSKServerKeyExchange(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() for _, test := range []struct { Name string SetIdentity bool }{ { Name: "Server Identity Set", SetIdentity: true, }, { Name: "Server Not Identity Set", SetIdentity: false, }, } { test := test t.Run(test.Name, func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() gotServerKeyExchange := false clientErr := make(chan error, 1) ca, cb := dpipe.Pipe() cbAnalyzer := &connWithCallback{Conn: cb} cbAnalyzer.onWrite = func(in []byte) { messages, err := recordlayer.UnpackDatagram(in) assert.NoError(t, err) for i := range messages { h := &handshake.Handshake{} _ = h.Unmarshal(messages[i][recordlayer.FixedHeaderSize:]) if h.Header.Type == handshake.TypeServerKeyExchange { gotServerKeyExchange = true } } } go func() { conf := &Config{ PSK: func([]byte) ([]byte, error) { return []byte{0xAB, 0xC1, 0x23}, nil }, PSKIdentityHint: []byte{0xAB, 0xC1, 0x23}, CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8}, } if client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), conf, false); err != nil { clientErr <- err } else { clientErr <- client.Close() //nolint } }() config := &Config{ PSK: func([]byte) ([]byte, error) { return []byte{0xAB, 0xC1, 0x23}, nil }, CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8}, } if test.SetIdentity { config.PSKIdentityHint = []byte{0xAB, 0xC1, 0x23} } server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cbAnalyzer), cbAnalyzer.RemoteAddr(), config, false) assert.NoError(t, err) assert.NoError(t, server.Close()) assert.NoError(t, <-clientErr, "TestPSK: Client erro") assert.Equal(t, test.SetIdentity, gotServerKeyExchange) }) } } func TestClientTimeout(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer 
lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() clientErr := make(chan error, 1) ca, _ := dpipe.Pipe() go func() { conf := &Config{} c, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), conf, true) if err == nil { _ = c.Close() //nolint:contextcheck } clientErr <- err }() // no server! err := <-clientErr var netErr net.Error assert.ErrorAs(t, err, &netErr, "Client error exp(Temporary network error) failed") assert.True(t, netErr.Timeout(), "Client error exp(Timeout) failed") } func TestSRTPConfiguration(t *testing.T) { // Check for leaking routines report := test.CheckRoutines(t) defer report() for _, test := range []struct { Name string ClientSRTP []SRTPProtectionProfile ServerSRTP []SRTPProtectionProfile ClientSRTPMasterKeyIdentifier []byte ServerSRTPMasterKeyIdentifier []byte ExpectedProfile SRTPProtectionProfile WantClientError error WantServerError error }{ { Name: "No SRTP in use", ClientSRTP: nil, ServerSRTP: nil, ExpectedProfile: 0, WantClientError: nil, WantServerError: nil, }, { Name: "SRTP both ends", ClientSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}, ServerSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}, ExpectedProfile: SRTP_AES128_CM_HMAC_SHA1_80, ClientSRTPMasterKeyIdentifier: []byte("ClientSRTPMKI"), ServerSRTPMasterKeyIdentifier: []byte("ServerSRTPMKI"), WantClientError: nil, WantServerError: nil, }, { Name: "SRTP client only", ClientSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}, ServerSRTP: nil, ExpectedProfile: 0, WantClientError: &alertError{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}}, WantServerError: errServerNoMatchingSRTPProfile, }, { Name: "SRTP server only", ClientSRTP: nil, ServerSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}, ExpectedProfile: 0, WantClientError: nil, WantServerError: nil, }, { Name: "Multiple 
Suites", ClientSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80, SRTP_AES128_CM_HMAC_SHA1_32}, ServerSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80, SRTP_AES128_CM_HMAC_SHA1_32}, ExpectedProfile: SRTP_AES128_CM_HMAC_SHA1_80, WantClientError: nil, WantServerError: nil, }, { Name: "Multiple Suites, Client Chooses", ClientSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80, SRTP_AES128_CM_HMAC_SHA1_32}, ServerSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_32, SRTP_AES128_CM_HMAC_SHA1_80}, ExpectedProfile: SRTP_AES128_CM_HMAC_SHA1_80, WantClientError: nil, WantServerError: nil, }, } { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() ca, cb := dpipe.Pipe() type result struct { c *Conn err error } resultCh := make(chan result) go func() { client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{ SRTPProtectionProfiles: test.ClientSRTP, SRTPMasterKeyIdentifier: test.ServerSRTPMasterKeyIdentifier, }, true) resultCh <- result{client, err} }() server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ SRTPProtectionProfiles: test.ServerSRTP, SRTPMasterKeyIdentifier: test.ClientSRTPMasterKeyIdentifier, }, true) assert.ErrorIs(t, err, test.WantServerError, "TestSRTPConfiguration: Server Error Mismatch") if err == nil { defer func() { _ = server.Close() }() } res := <-resultCh if res.err == nil { defer func() { _ = res.c.Close() }() } assert.ErrorIsf(t, res.err, test.WantClientError, "TestSRTPConfiguration: Client Error Mismatch '%s'", test.Name) if res.c == nil { return } actualClientSRTP, _ := res.c.SelectedSRTPProtectionProfile() assert.Equalf(t, test.ExpectedProfile, actualClientSRTP, "TestSRTPConfiguration: Client SRTPProtectionProfile Mismatch '%s'", test.Name) actualServerSRTP, _ := server.SelectedSRTPProtectionProfile() assert.Equalf(t, test.ExpectedProfile, actualServerSRTP, "TestSRTPConfiguration: Server SRTPProtectionProfile Mismatch 
'%s'", test.Name) actualServerMKI, _ := server.RemoteSRTPMasterKeyIdentifier() assert.Truef(t, bytes.Equal(test.ServerSRTPMasterKeyIdentifier, actualServerMKI), "TestSRTPConfiguration: Server SRTPMKI Mismatch '%s'", test.Name) actualClientMKI, _ := res.c.RemoteSRTPMasterKeyIdentifier() assert.Truef(t, bytes.Equal(test.ClientSRTPMasterKeyIdentifier, actualClientMKI), "TestSRTPConfiguration: Client SRTPMKI Mismatch '%s'", test.Name) } } func TestClientCertificate(t *testing.T) { //nolint:gocyclo,cyclop,maintidx // Check for leaking routines report := test.CheckRoutines(t) defer report() srvCert, err := selfsign.GenerateSelfSigned() assert.NoError(t, err) srvCAPool := x509.NewCertPool() srvCertificate, err := x509.ParseCertificate(srvCert.Certificate[0]) assert.NoError(t, err) srvCAPool.AddCert(srvCertificate) cert, err := selfsign.GenerateSelfSigned() assert.NoError(t, err) certificate, err := x509.ParseCertificate(cert.Certificate[0]) assert.NoError(t, err) caPool := x509.NewCertPool() caPool.AddCert(certificate) t.Run("parallel", func(t *testing.T) { // sync routines to check routine leak tests := map[string]struct { clientCfg *Config serverCfg *Config wantErr bool }{ "NoClientCert": { clientCfg: &Config{RootCAs: srvCAPool}, serverCfg: &Config{ Certificates: []tls.Certificate{srvCert}, ClientAuth: NoClientCert, ClientCAs: caPool, }, }, "NoClientCert_ServerVerifyConnectionFails": { clientCfg: &Config{RootCAs: srvCAPool}, serverCfg: &Config{ Certificates: []tls.Certificate{srvCert}, ClientAuth: NoClientCert, ClientCAs: caPool, VerifyConnection: func(*State) error { return errExample }, }, wantErr: true, }, "NoClientCert_ClientVerifyConnectionFails": { clientCfg: &Config{RootCAs: srvCAPool, VerifyConnection: func(*State) error { return errExample }}, serverCfg: &Config{ Certificates: []tls.Certificate{srvCert}, ClientAuth: NoClientCert, ClientCAs: caPool, }, wantErr: true, }, "NoClientCert_cert": { clientCfg: &Config{RootCAs: srvCAPool, Certificates: 
[]tls.Certificate{cert}}, serverCfg: &Config{ Certificates: []tls.Certificate{srvCert}, ClientAuth: RequireAnyClientCert, }, }, "RequestClientCert_cert_sigscheme": { // specify signature algorithm clientCfg: &Config{RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}}, serverCfg: &Config{ SignatureSchemes: []tls.SignatureScheme{tls.ECDSAWithP521AndSHA512}, Certificates: []tls.Certificate{srvCert}, ClientAuth: RequestClientCert, }, }, "RequestClientCert_cert": { clientCfg: &Config{RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}}, serverCfg: &Config{ Certificates: []tls.Certificate{srvCert}, ClientAuth: RequestClientCert, }, }, "RequestClientCert_no_cert": { clientCfg: &Config{RootCAs: srvCAPool}, serverCfg: &Config{ Certificates: []tls.Certificate{srvCert}, ClientAuth: RequestClientCert, ClientCAs: caPool, }, }, "RequireAnyClientCert": { clientCfg: &Config{RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}}, serverCfg: &Config{ Certificates: []tls.Certificate{srvCert}, ClientAuth: RequireAnyClientCert, }, }, "RequireAnyClientCert_error": { clientCfg: &Config{RootCAs: srvCAPool}, serverCfg: &Config{ Certificates: []tls.Certificate{srvCert}, ClientAuth: RequireAnyClientCert, }, wantErr: true, }, "VerifyClientCertIfGiven_no_cert": { clientCfg: &Config{RootCAs: srvCAPool}, serverCfg: &Config{ Certificates: []tls.Certificate{srvCert}, ClientAuth: VerifyClientCertIfGiven, ClientCAs: caPool, }, }, "VerifyClientCertIfGiven_cert": { clientCfg: &Config{RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}}, serverCfg: &Config{ Certificates: []tls.Certificate{srvCert}, ClientAuth: VerifyClientCertIfGiven, ClientCAs: caPool, }, }, "VerifyClientCertIfGiven_error": { clientCfg: &Config{RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}}, serverCfg: &Config{ Certificates: []tls.Certificate{srvCert}, ClientAuth: VerifyClientCertIfGiven, }, wantErr: true, }, "RequireAndVerifyClientCert": { clientCfg: &Config{ RootCAs: srvCAPool, Certificates: 
[]tls.Certificate{cert}, VerifyConnection: func(s *State) error { if ok := bytes.Equal(s.PeerCertificates[0], srvCertificate.Raw); !ok { return errExample } return nil }, }, serverCfg: &Config{ Certificates: []tls.Certificate{srvCert}, ClientAuth: RequireAndVerifyClientCert, ClientCAs: caPool, VerifyConnection: func(s *State) error { if ok := bytes.Equal(s.PeerCertificates[0], certificate.Raw); !ok { return errExample } return nil }, }, }, "RequireAndVerifyClientCert_callbacks": { clientCfg: &Config{ RootCAs: srvCAPool, // Certificates: []tls.Certificate{cert}, GetClientCertificate: func(*CertificateRequestInfo) (*tls.Certificate, error) { return &cert, nil }, }, serverCfg: &Config{ GetCertificate: func(*ClientHelloInfo) (*tls.Certificate, error) { return &srvCert, nil }, // Certificates: []tls.Certificate{srvCert}, ClientAuth: RequireAndVerifyClientCert, ClientCAs: caPool, }, }, } for name, tt := range tests { tt := tt t.Run(name, func(t *testing.T) { ca, cb := dpipe.Pipe() type result struct { c *Conn err, hserr error } c := make(chan result) go func() { client, err := Client(dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), tt.clientCfg) c <- result{client, err, client.Handshake()} }() server, err := Server(dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), tt.serverCfg) hserr := server.Handshake() res := <-c defer func() { if err == nil { _ = server.Close() } if res.err == nil { _ = res.c.Close() } }() if tt.wantErr { assert.True(t, err != nil || hserr != nil, "Error expected") return // Error expected, test succeeded } assert.NoError(t, err) assert.NoError(t, res.err) state, ok := server.ConnectionState() assert.True(t, ok, "Server connection state not available") actualClientCert := state.PeerCertificates //nolint:nestif if tt.serverCfg.ClientAuth == RequireAnyClientCert || tt.serverCfg.ClientAuth == RequireAndVerifyClientCert { assert.NotNil(t, actualClientCert, "Client did not provide a certificate") var cfgCert [][]byte if len(tt.clientCfg.Certificates) > 0 { 
cfgCert = tt.clientCfg.Certificates[0].Certificate } if tt.clientCfg.GetClientCertificate != nil { crt, err := tt.clientCfg.GetClientCertificate(&CertificateRequestInfo{}) assert.NoError(t, err, "Server configuration did not provide a certificate") cfgCert = crt.Certificate } assert.NotEmpty(t, cfgCert, "Client certificate was not communicated correctly") assert.Equal(t, actualClientCert[0], cfgCert[0], "Client certificate was not communicated correctly") } if tt.serverCfg.ClientAuth == NoClientCert { assert.Nil(t, actualClientCert, "Client certificate wasn't expected") } clientState, ok := res.c.ConnectionState() assert.True(t, ok, "Client connection state not available") actualServerCert := clientState.PeerCertificates assert.NotNil(t, actualServerCert, "server did not provide a certificate") var cfgCert [][]byte if len(tt.serverCfg.Certificates) > 0 { cfgCert = tt.serverCfg.Certificates[0].Certificate } if tt.serverCfg.GetCertificate != nil { crt, err := tt.serverCfg.GetCertificate(&ClientHelloInfo{}) assert.NoError(t, err, "Server configuration did not provide a certificate") cfgCert = crt.Certificate } assert.NotEmpty(t, cfgCert, "Server certificate was not communicated correctly") assert.Equal(t, actualServerCert[0], cfgCert[0], "Server certificate was not communicated correctly") }) } }) } func TestConnectionID(t *testing.T) { // Check for leaking routines report := test.CheckRoutines(t) defer report() clientCID := []byte{5, 77, 33, 24, 93, 27, 45, 81} serverCID := []byte{64, 24, 73, 2, 17, 96, 38, 59} cidEcho := func(echo []byte) func() []byte { return func() []byte { return echo } } tests := map[string]struct { clientCfg *Config serverCfg *Config clientConnectionID []byte serverConnectionID []byte }{ "BidirectionalConnectionIDs": { clientCfg: &Config{ ConnectionIDGenerator: cidEcho(clientCID), }, serverCfg: &Config{ ConnectionIDGenerator: cidEcho(serverCID), }, clientConnectionID: clientCID, serverConnectionID: serverCID, }, "BothSupportOnlyClientSends": { 
clientCfg: &Config{ ConnectionIDGenerator: cidEcho(nil), }, serverCfg: &Config{ ConnectionIDGenerator: cidEcho(serverCID), }, serverConnectionID: serverCID, }, "BothSupportOnlyServerSends": { clientCfg: &Config{ ConnectionIDGenerator: cidEcho(clientCID), }, serverCfg: &Config{ ConnectionIDGenerator: cidEcho(nil), }, clientConnectionID: clientCID, }, "ClientDoesNotSupport": { clientCfg: &Config{}, serverCfg: &Config{ ConnectionIDGenerator: cidEcho(serverCID), }, }, "ServerDoesNotSupport": { clientCfg: &Config{ ConnectionIDGenerator: cidEcho(clientCID), }, serverCfg: &Config{}, }, "NeitherSupport": { clientCfg: &Config{}, serverCfg: &Config{}, }, } for name, tt := range tests { tt := tt t.Run(name, func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() ca, cb := dpipe.Pipe() type result struct { c *Conn err error } c := make(chan result) go func() { client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), tt.clientCfg, true) c <- result{client, err} }() server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), tt.serverCfg, true) assert.NoError(t, err) res := <-c assert.NoError(t, res.err) defer func() { if err == nil { _ = server.Close() } if res.err == nil { _ = res.c.Close() } }() assert.True(t, bytes.Equal(tt.clientConnectionID, res.c.state.getLocalConnectionID()), "Unexpected client local connection ID") assert.True(t, bytes.Equal(tt.serverConnectionID, res.c.state.remoteConnectionID), "Unexpected client remote connection ID") assert.True(t, bytes.Equal(tt.serverConnectionID, server.state.getLocalConnectionID()), "Unexpected server local connection ID") assert.True(t, bytes.Equal(tt.clientConnectionID, server.state.remoteConnectionID), "Unexpected server remote connection ID") }) } } func TestExtendedMasterSecret(t *testing.T) { // Check for leaking routines report := test.CheckRoutines(t) defer report() tests := map[string]struct { clientCfg *Config serverCfg 
*Config
		expectedClientErr error
		expectedServerErr error
	}{
		"Request_Request_ExtendedMasterSecret": {
			clientCfg: &Config{
				ExtendedMasterSecret: RequestExtendedMasterSecret,
			},
			serverCfg: &Config{
				ExtendedMasterSecret: RequestExtendedMasterSecret,
			},
			expectedClientErr: nil,
			expectedServerErr: nil,
		},
		"Request_Require_ExtendedMasterSecret": {
			clientCfg: &Config{
				ExtendedMasterSecret: RequestExtendedMasterSecret,
			},
			serverCfg: &Config{
				ExtendedMasterSecret: RequireExtendedMasterSecret,
			},
			expectedClientErr: nil,
			expectedServerErr: nil,
		},
		"Request_Disable_ExtendedMasterSecret": {
			clientCfg: &Config{
				ExtendedMasterSecret: RequestExtendedMasterSecret,
			},
			serverCfg: &Config{
				ExtendedMasterSecret: DisableExtendedMasterSecret,
			},
			expectedClientErr: nil,
			expectedServerErr: nil,
		},
		"Require_Request_ExtendedMasterSecret": {
			clientCfg: &Config{
				ExtendedMasterSecret: RequireExtendedMasterSecret,
			},
			serverCfg: &Config{
				ExtendedMasterSecret: RequestExtendedMasterSecret,
			},
			expectedClientErr: nil,
			expectedServerErr: nil,
		},
		"Require_Require_ExtendedMasterSecret": {
			clientCfg: &Config{
				ExtendedMasterSecret: RequireExtendedMasterSecret,
			},
			serverCfg: &Config{
				ExtendedMasterSecret: RequireExtendedMasterSecret,
			},
			expectedClientErr: nil,
			expectedServerErr: nil,
		},
		"Require_Disable_ExtendedMasterSecret": {
			clientCfg: &Config{
				ExtendedMasterSecret: RequireExtendedMasterSecret,
			},
			serverCfg: &Config{
				ExtendedMasterSecret: DisableExtendedMasterSecret,
			},
			expectedClientErr: errClientRequiredButNoServerEMS,
			expectedServerErr: &alertError{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}},
		},
		"Disable_Request_ExtendedMasterSecret": {
			clientCfg: &Config{
				ExtendedMasterSecret: DisableExtendedMasterSecret,
			},
			serverCfg: &Config{
				ExtendedMasterSecret: RequestExtendedMasterSecret,
			},
			expectedClientErr: nil,
			expectedServerErr: nil,
		},
		"Disable_Require_ExtendedMasterSecret": {
			clientCfg: &Config{
				ExtendedMasterSecret: DisableExtendedMasterSecret,
			},
			serverCfg: &Config{
				ExtendedMasterSecret: RequireExtendedMasterSecret,
			},
			expectedClientErr: &alertError{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}},
			expectedServerErr: errServerRequiredButNoClientEMS,
		},
		"Disable_Disable_ExtendedMasterSecret": {
			clientCfg: &Config{
				ExtendedMasterSecret: DisableExtendedMasterSecret,
			},
			serverCfg: &Config{
				ExtendedMasterSecret: DisableExtendedMasterSecret,
			},
			expectedClientErr: nil,
			expectedServerErr: nil,
		},
	}
	for name, tt := range tests {
		tt := tt
		t.Run(name, func(t *testing.T) {
			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
			defer cancel()

			ca, cb := dpipe.Pipe()
			type result struct {
				c   *Conn
				err error
			}
			c := make(chan result)

			go func() {
				client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), tt.clientCfg, true)
				c <- result{client, err}
			}()

			server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), tt.serverCfg, true)
			res := <-c
			defer func() {
				if err == nil {
					_ = server.Close()
				}
				if res.err == nil {
					_ = res.c.Close()
				}
			}()

			assert.ErrorIs(t, res.err, tt.expectedClientErr)
			assert.ErrorIs(t, err, tt.expectedServerErr)
		})
	}
}

// TestServerCertificate checks client-side verification of the server
// certificate (root CAs, InsecureSkipVerify, ServerName, VerifyPeerCertificate)
// and server-side VerifyPeerCertificate chain handling.
func TestServerCertificate(t *testing.T) { //nolint:cyclop
	// Check for leaking routines
	report := test.CheckRoutines(t)
	defer report()

	cert, err := selfsign.GenerateSelfSigned()
	assert.NoError(t, err)
	certificate, err := x509.ParseCertificate(cert.Certificate[0])
	assert.NoError(t, err)
	caPool := x509.NewCertPool()
	caPool.AddCert(certificate)

	t.Run("parallel", func(t *testing.T) { // sync routines to check routine leak
		tests := map[string]struct {
			clientCfg *Config
			serverCfg *Config
			wantErr   bool
		}{
			"no_ca": {
				clientCfg: &Config{},
				serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: NoClientCert},
				wantErr:   true,
			},
			"good_ca": {
				clientCfg: &Config{RootCAs: caPool},
				serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: NoClientCert},
			},
			"no_ca_skip_verify": {
				clientCfg: &Config{InsecureSkipVerify: true},
				serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: NoClientCert},
			},
			"good_ca_skip_verify_custom_verify_peer": {
				clientCfg: &Config{RootCAs: caPool, Certificates: []tls.Certificate{cert}},
				serverCfg: &Config{
					Certificates: []tls.Certificate{cert},
					ClientAuth:   RequireAnyClientCert,
					VerifyPeerCertificate: func(_ [][]byte, chain [][]*x509.Certificate) error {
						if len(chain) != 0 {
							return errNotExpectedChain
						}

						return nil
					},
				},
			},
			"good_ca_verify_custom_verify_peer": {
				clientCfg: &Config{RootCAs: caPool, Certificates: []tls.Certificate{cert}},
				serverCfg: &Config{
					ClientCAs:    caPool,
					Certificates: []tls.Certificate{cert},
					ClientAuth:   RequireAndVerifyClientCert,
					VerifyPeerCertificate: func(_ [][]byte, chain [][]*x509.Certificate) error {
						if len(chain) == 0 {
							return errExpecedChain
						}

						return nil
					},
				},
			},
			"good_ca_custom_verify_peer": {
				clientCfg: &Config{
					RootCAs: caPool,
					VerifyPeerCertificate: func([][]byte, [][]*x509.Certificate) error {
						return errWrongCert
					},
				},
				serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: NoClientCert},
				wantErr:   true,
			},
			"server_name": {
				clientCfg: &Config{RootCAs: caPool, ServerName: certificate.Subject.CommonName},
				serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: NoClientCert},
			},
			"server_name_error": {
				clientCfg: &Config{RootCAs: caPool, ServerName: "barfoo"},
				serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: NoClientCert},
				wantErr:   true,
			},
		}
		for name, tt := range tests {
			tt := tt
			t.Run(name, func(t *testing.T) {
				ca, cb := dpipe.Pipe()

				type result struct {
					c          *Conn
					err, hserr error
				}
				srvCh := make(chan result)
				go func() {
					s, err := Server(dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), tt.serverCfg)
					srvCh <- result{s, err, s.Handshake()}
				}()

				cli, err := Client(dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), tt.clientCfg)
				hserr := cli.Handshake()
				if err == nil {
					_ = cli.Close()
				}

				if tt.wantErr {
					assert.True(t, err != nil || hserr != nil, "Expected error")
				} else {
					assert.NoError(t, err, "Client connection failed")
					assert.NoError(t, hserr, "Client handshake failed")
				}

				// Drain the server result so its goroutine can exit.
				srv := <-srvCh
				if srv.err == nil {
					_ = srv.c.Close()
				}
			})
		}
	})
}

// TestCipherSuiteConfiguration checks cipher-suite validation and negotiation:
// invalid IDs are rejected on both sides, disjoint lists fail the handshake,
// and otherwise the expected suite ends up selected.
func TestCipherSuiteConfiguration(t *testing.T) {
	// Check for leaking routines
	report := test.CheckRoutines(t)
	defer report()

	for _, test := range []struct {
		Name                    string
		ClientCipherSuites      []CipherSuiteID
		ServerCipherSuites      []CipherSuiteID
		WantClientError         error
		WantServerError         error
		WantSelectedCipherSuite CipherSuiteID
	}{
		{
			Name:               "No CipherSuites specified",
			ClientCipherSuites: nil,
			ServerCipherSuites: nil,
			WantClientError:    nil,
			WantServerError:    nil,
		},
		{
			Name:               "Invalid CipherSuite",
			ClientCipherSuites: []CipherSuiteID{0x00},
			ServerCipherSuites: []CipherSuiteID{0x00},
			WantClientError:    &invalidCipherSuiteError{0x00},
			WantServerError:    &invalidCipherSuiteError{0x00},
		},
		{
			Name:                    "Valid CipherSuites specified",
			ClientCipherSuites:      []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
			ServerCipherSuites:      []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
			WantClientError:         nil,
			WantServerError:         nil,
			WantSelectedCipherSuite: TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
		},
		{
			Name:               "CipherSuites mismatch",
			ClientCipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
			ServerCipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA},
			WantClientError:    &alertError{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}},
			WantServerError:    errCipherSuiteNoIntersection,
		},
		{
			Name:                    "Valid CipherSuites CCM specified",
			ClientCipherSuites:      []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_CCM},
			ServerCipherSuites:      []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_CCM},
			WantClientError:         nil,
			WantServerError:         nil,
			WantSelectedCipherSuite: TLS_ECDHE_ECDSA_WITH_AES_128_CCM,
		},
		{
			Name:                    "Valid CipherSuites CCM-8 specified",
			ClientCipherSuites:      []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8},
			ServerCipherSuites:      []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8},
			WantClientError:         nil,
			WantServerError:         nil,
			WantSelectedCipherSuite: TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8,
		},
		{
			Name: "Server supports subset of client suites",
			ClientCipherSuites: []CipherSuiteID{
				TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
				TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
			},
			ServerCipherSuites:      []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA},
			WantClientError:         nil,
			WantServerError:         nil,
			WantSelectedCipherSuite: TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
		},
	} {
		test := test
		t.Run(test.Name, func(t *testing.T) {
			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
			defer cancel()

			ca, cb := dpipe.Pipe()
			type result struct {
				c   *Conn
				err error
			}
			resultCh := make(chan result)

			go func() {
				client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{
					CipherSuites: test.ClientCipherSuites,
				}, true)
				resultCh <- result{client, err}
			}()

			server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{
				CipherSuites: test.ServerCipherSuites,
			}, true)
			assert.ErrorIsf(t, err, test.WantServerError, "TestCipherSuiteConfiguration: Server Error Mismatch '%s'", test.Name)

			res := <-resultCh
			assert.ErrorIsf(t, res.err, test.WantClientError, "TestCipherSuiteConfiguration: Client Error Mismatch '%s'", test.Name)

			// Close each side exactly once and only if it was created: one side
			// can fail while the other succeeds, and a failed side has no *Conn.
			if err == nil {
				assert.NoError(t, server.Close())
			}
			if res.err == nil {
				assert.NoError(t, res.c.Close())

				// Inspected after Close, and only when the client conn exists.
				if test.WantSelectedCipherSuite != 0x00 {
					assert.Equalf(t, test.WantSelectedCipherSuite, res.c.state.cipherSuite.ID(), "TestCipherSuiteConfiguration: Server Selected Bad Cipher Suite '%s'", test.Name)
				}
			}
		})
	}
}

// TestCertificateAndPSKServer checks that a server configured with both a
// certificate suite and a PSK suite completes a handshake with a PKI client
// as well as with a PSK client.
func TestCertificateAndPSKServer(t *testing.T) {
	// Check for leaking routines
	report := test.CheckRoutines(t)
	defer report()

	for _, test := range []struct {
		Name      string
		ClientPSK bool
	}{
		{
			Name:      "Client uses PKI",
			ClientPSK: false,
		},
		{
			Name:      "Client uses PSK",
			ClientPSK: true,
		},
	} {
		test := test
		t.Run(test.Name, func(t *testing.T) {
			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
			defer cancel()

			ca, cb := dpipe.Pipe()
			type result struct {
				c   *Conn
				err error
			}
			resultCh := make(chan result)

			go func() {
				config := &Config{CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}}
				if test.ClientPSK {
					config.PSK = func([]byte) ([]byte, error) {
						return []byte{0x00, 0x01, 0x02}, nil
					}
					config.PSKIdentityHint = []byte{0x00}
					config.CipherSuites = []CipherSuiteID{TLS_PSK_WITH_AES_128_GCM_SHA256}
				}

				client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), config, false)
				resultCh <- result{client, err}
			}()

			config := &Config{
				CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_PSK_WITH_AES_128_GCM_SHA256},
				PSK: func([]byte) ([]byte, error) {
					return []byte{0x00, 0x01, 0x02}, nil
				},
			}

			server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), config, true)
			assert.NoErrorf(t, err, "TestCertificateAndPSKServer: Server Error Mismatch '%s'", test.Name)

			res := <-resultCh
			assert.NoErrorf(t, res.err, "TestCertificateAndPSKServer: Client Error Mismatch '%s'", test.Name)

			// Close only the sides that were actually created; a failed side
			// has no *Conn to close.
			if err == nil {
				assert.NoError(t, server.Close())
			}
			if res.err == nil {
				assert.NoError(t, res.c.Close())
			}
		})
	}
}

func TestPSKConfiguration(t *testing.T) { //nolint:cyclop
	// Check for leaking routines
	report := test.CheckRoutines(t)
	defer report()

	for _, test := range []struct {
		Name                 string
		ClientHasCertificate bool
		ServerHasCertificate bool
		ClientPSK            PSKCallback
		ServerPSK            PSKCallback
		ClientPSKIdentity    []byte
		ServerPSKIdentity    []byte
		WantClientError      error
		WantServerError      error
	}{
		{
			Name:                 "PSK and no certificate specified",
			ClientHasCertificate: false,
			ServerHasCertificate: false,
			ClientPSK: func([]byte) ([]byte, error) {
				return []byte{0x00, 0x01, 0x02}, nil
			},
			ServerPSK: func([]byte) ([]byte, error) {
				return []byte{0x00, 0x01, 0x02}, nil
			},
			ClientPSKIdentity: []byte{0x00},
			ServerPSKIdentity: []byte{0x00},
			WantClientError:
errNoAvailablePSKCipherSuite, WantServerError: errNoAvailablePSKCipherSuite, }, { Name: "PSK and certificate specified", ClientHasCertificate: true, ServerHasCertificate: true, ClientPSK: func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil }, ServerPSK: func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil }, ClientPSKIdentity: []byte{0x00}, ServerPSKIdentity: []byte{0x00}, WantClientError: errNoAvailablePSKCipherSuite, WantServerError: errNoAvailablePSKCipherSuite, }, { Name: "PSK and no identity specified", ClientHasCertificate: false, ServerHasCertificate: false, ClientPSK: func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil }, ServerPSK: func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil }, ClientPSKIdentity: nil, ServerPSKIdentity: nil, WantClientError: errPSKAndIdentityMustBeSetForClient, WantServerError: errNoAvailablePSKCipherSuite, }, { Name: "No PSK and identity specified", ClientHasCertificate: false, ServerHasCertificate: false, ClientPSK: nil, ServerPSK: nil, ClientPSKIdentity: []byte{0x00}, ServerPSKIdentity: []byte{0x00}, WantClientError: errIdentityNoPSK, WantServerError: errIdentityNoPSK, }, } { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() ca, cb := dpipe.Pipe() type result struct { c *Conn err error } resultCh := make(chan result) go func() { client, err := testClient( ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{PSK: test.ClientPSK, PSKIdentityHint: test.ClientPSKIdentity}, test.ClientHasCertificate, ) resultCh <- result{client, err} }() _, err := testServer( ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{PSK: test.ServerPSK, PSKIdentityHint: test.ServerPSKIdentity}, test.ServerHasCertificate, ) if err != nil || test.WantServerError != nil { if !(err != nil && test.WantServerError != nil && err.Error() == test.WantServerError.Error()) { assert.Failf(t, "TestPSKConfiguration", "Server Error Mismatch '%s'", 
test.Name) } } res := <-resultCh if res.err != nil || test.WantClientError != nil { if !(res.err != nil && test.WantClientError != nil && res.err.Error() == test.WantClientError.Error()) { assert.Failf(t, "TestPSKConfiguration", "Client Error Mismatch '%s'", test.Name) } } } } func TestServerTimeout(t *testing.T) { //nolint:cyclop // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() cookie := make([]byte, 20) _, err := rand.Read(cookie) assert.NoError(t, err) var rand [28]byte random := handshake.Random{GMTUnixTime: time.Unix(500, 0), RandomBytes: rand} cipherSuites := []CipherSuite{ &ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{}, &ciphersuite.TLSEcdheRsaWithAes128GcmSha256{}, } extensions := []extension.Extension{ &extension.SupportedSignatureAlgorithms{ SignatureHashAlgorithms: []signaturehash.Algorithm{ {Hash: hash.SHA256, Signature: signature.ECDSA}, {Hash: hash.SHA384, Signature: signature.ECDSA}, {Hash: hash.SHA512, Signature: signature.ECDSA}, {Hash: hash.SHA256, Signature: signature.RSA}, {Hash: hash.SHA384, Signature: signature.RSA}, {Hash: hash.SHA512, Signature: signature.RSA}, }, }, &extension.SupportedEllipticCurves{ EllipticCurves: []elliptic.Curve{elliptic.X25519, elliptic.P256, elliptic.P384}, }, &extension.SupportedPointFormats{ PointFormats: []elliptic.CurvePointFormat{elliptic.CurvePointFormatUncompressed}, }, } record := &recordlayer.RecordLayer{ Header: recordlayer.Header{ SequenceNumber: 0, Version: protocol.Version1_2, }, Content: &handshake.Handshake{ // sequenceNumber and messageSequence line up, may need to be re-evaluated Header: handshake.Header{ MessageSequence: 0, }, Message: &handshake.MessageClientHello{ Version: protocol.Version1_2, Cookie: cookie, Random: random, CipherSuiteIDs: cipherSuiteIDs(cipherSuites), CompressionMethods: defaultCompressionMethods(), Extensions: extensions, }, }, } packet, err := 
record.Marshal() assert.NoError(t, err) ca, cb := dpipe.Pipe() defer func() { assert.NoError(t, ca.Close()) }() // Client reader caReadChan := make(chan []byte, 1000) go func() { for { data := make([]byte, 8192) n, err := ca.Read(data) if err != nil { return } caReadChan <- data[:n] } }() // Start sending ClientHello packets until server responds with first packet go func() { for { select { case <-time.After(10 * time.Millisecond): _, err := ca.Write(packet) if err != nil { return } case <-caReadChan: // Once we receive the first reply from the server, stop return } } }() ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) defer cancel() config := &Config{ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, FlightInterval: 100 * time.Millisecond, } _, serverErr := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), config, true) var netErr net.Error assert.ErrorAsf(t, serverErr, &netErr, "Client error exp(Temporary network error) failed(%v)", serverErr) assert.Truef(t, netErr.Timeout(), "Client error exp(Temporary network error) failed(%v)", serverErr) // Wait a little longer to ensure no additional messages have been sent by the server time.Sleep(300 * time.Millisecond) select { case msg := <-caReadChan: assert.Fail(t, "Expected no additional messages from server", "got: %+v", msg) default: } } func TestProtocolVersionValidation(t *testing.T) { //nolint:maintidx // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() cookie := make([]byte, 20) _, err := rand.Read(cookie) assert.NoError(t, err) var rand [28]byte random := handshake.Random{GMTUnixTime: time.Unix(500, 0), RandomBytes: rand} config := &Config{ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, FlightInterval: 100 * time.Millisecond, } t.Run("Server", func(t *testing.T) { serverCases := map[string]struct { 
records []*recordlayer.RecordLayer }{ "ClientHelloVersion": { records: []*recordlayer.RecordLayer{ { Header: recordlayer.Header{ Version: protocol.Version1_2, }, Content: &handshake.Handshake{ Message: &handshake.MessageClientHello{ Version: protocol.Version{Major: 0xfe, Minor: 0xff}, // try to downgrade Cookie: cookie, Random: random, CipherSuiteIDs: []uint16{uint16((&ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{}).ID())}, CompressionMethods: defaultCompressionMethods(), }, }, }, }, }, "SecondsClientHelloVersion": { records: []*recordlayer.RecordLayer{ { Header: recordlayer.Header{ Version: protocol.Version1_2, }, Content: &handshake.Handshake{ Message: &handshake.MessageClientHello{ Version: protocol.Version1_2, Cookie: cookie, Random: random, CipherSuiteIDs: []uint16{uint16((&ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{}).ID())}, CompressionMethods: defaultCompressionMethods(), }, }, }, { Header: recordlayer.Header{ Version: protocol.Version1_2, SequenceNumber: 1, }, Content: &handshake.Handshake{ Header: handshake.Header{ MessageSequence: 1, }, Message: &handshake.MessageClientHello{ Version: protocol.Version{Major: 0xfe, Minor: 0xff}, // try to downgrade Cookie: cookie, Random: random, CipherSuiteIDs: []uint16{uint16((&ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{}).ID())}, CompressionMethods: defaultCompressionMethods(), }, }, }, }, }, } for name, serverCase := range serverCases { serverCase := serverCase t.Run(name, func(t *testing.T) { ca, cb := dpipe.Pipe() defer func() { assert.NoError(t, ca.Close()) }() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() var wg sync.WaitGroup wg.Add(1) defer wg.Wait() go func() { defer wg.Done() _, err := testServer( ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), config, true, ) assert.ErrorIs(t, err, errUnsupportedProtocolVersion) }() time.Sleep(50 * time.Millisecond) resp := make([]byte, 1024) for _, record := range serverCase.records { packet, err := record.Marshal() 
assert.NoError(t, err) _, werr := ca.Write(packet) assert.NoError(t, werr) n, rerr := ca.Read(resp[:cap(resp)]) assert.NoError(t, rerr) resp = resp[:n] } h := &recordlayer.Header{} assert.NoError(t, h.Unmarshal(resp)) assert.Equal(t, protocol.ContentTypeAlert, h.ContentType, "Peer must return alert to unsupported protocol version") }) } }) t.Run("Client", func(t *testing.T) { clientCases := map[string]struct { records []*recordlayer.RecordLayer }{ "ServerHelloVersion": { records: []*recordlayer.RecordLayer{ { Header: recordlayer.Header{ Version: protocol.Version1_2, }, Content: &handshake.Handshake{ Message: &handshake.MessageHelloVerifyRequest{ Version: protocol.Version1_2, Cookie: cookie, }, }, }, { Header: recordlayer.Header{ Version: protocol.Version1_2, SequenceNumber: 1, }, Content: &handshake.Handshake{ Header: handshake.Header{ MessageSequence: 1, }, Message: &handshake.MessageServerHello{ Version: protocol.Version{Major: 0xfe, Minor: 0xff}, // try to downgrade Random: random, CipherSuiteID: func() *uint16 { id := uint16(TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256) return &id }(), CompressionMethod: defaultCompressionMethods()[0], }, }, }, }, }, } for name, clientCase := range clientCases { clientCase := clientCase t.Run(name, func(t *testing.T) { ca, cb := dpipe.Pipe() defer func() { assert.NoError(t, ca.Close()) }() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() var wg sync.WaitGroup wg.Add(1) defer wg.Wait() go func() { defer wg.Done() _, err := testClient(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), config, true) assert.ErrorIs(t, err, errUnsupportedProtocolVersion) }() time.Sleep(50 * time.Millisecond) for _, record := range clientCase.records { _, err := ca.Read(make([]byte, 1024)) assert.NoError(t, err) packet, err := record.Marshal() assert.NoError(t, err) _, err = ca.Write(packet) assert.NoError(t, err) } resp := make([]byte, 1024) n, err := ca.Read(resp) assert.NoError(t, err) resp = resp[:n] h := 
&recordlayer.Header{} assert.NoError(t, h.Unmarshal(resp)) assert.Equal(t, protocol.ContentTypeAlert, h.ContentType, "Peer must return alert to unsupported protocol version") }) } }) } func TestMultipleHelloVerifyRequest(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() cookies := [][]byte{ // first clientHello contains an empty cookie {}, } var packets [][]byte for i := 0; i < 2; i++ { cookie := make([]byte, 20) _, err := rand.Read(cookie) assert.NoError(t, err) cookies = append(cookies, cookie) record := &recordlayer.RecordLayer{ Header: recordlayer.Header{ SequenceNumber: uint64(i), //nolint:gosec // G101 Version: protocol.Version1_2, }, Content: &handshake.Handshake{ Header: handshake.Header{ MessageSequence: uint16(i), //nolint:gosec // G115 }, Message: &handshake.MessageHelloVerifyRequest{ Version: protocol.Version1_2, Cookie: cookie, }, }, } packet, err := record.Marshal() assert.NoError(t, err) packets = append(packets, packet) } ca, cb := dpipe.Pipe() defer func() { assert.NoError(t, ca.Close()) }() ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() var wg sync.WaitGroup wg.Add(1) defer wg.Wait() go func() { defer wg.Done() _, _ = testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{}, false) }() for i, cookie := range cookies { // read client hello resp := make([]byte, 1024) n, err := cb.Read(resp) assert.NoError(t, err) record := &recordlayer.RecordLayer{} assert.NoError(t, record.Unmarshal(resp[:n])) clientHello, ok := record.Content.(*handshake.Handshake).Message.(*handshake.MessageClientHello) assert.True(t, ok) assert.Equal(t, cookie, clientHello.Cookie) if len(packets) <= i { break } // write hello verify request _, err = cb.Write(packets[i]) assert.NoError(t, err) } cancel() } // Assert that a DTLS Server only responds with RenegotiationInfo if a ClientHello 
contained that // extension according to RFC5746 section 3.6, RFC5246 section 7.4.1.4 and RFC5746 section 4.2. func TestRenegotationInfo(t *testing.T) { //nolint:cyclop // Limit runtime in case of deadlocks lim := test.TimeOut(10 * time.Second) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() resp := make([]byte, 1024) for _, testCase := range []struct { Name string ExpectRenegotiationInfo bool }{ { "Include RenegotiationInfo", true, }, { "No RenegotiationInfo", false, }, } { test := testCase t.Run(test.Name, func(t *testing.T) { ca, cb := dpipe.Pipe() defer func() { assert.NoError(t, ca.Close()) }() ctx, cancel := context.WithCancel(context.Background()) defer cancel() go func() { _, err := testServer( ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{}, true, ) assert.ErrorIs(t, err, context.Canceled) }() time.Sleep(50 * time.Millisecond) extensions := []extension.Extension{} if test.ExpectRenegotiationInfo { extensions = append(extensions, &extension.RenegotiationInfo{ RenegotiatedConnection: 0, }) } err := sendClientHello([]byte{}, ca, 0, extensions) assert.NoError(t, err) n, err := ca.Read(resp) assert.NoError(t, err) record := &recordlayer.RecordLayer{} assert.NoError(t, record.Unmarshal(resp[:n])) helloVerifyRequest, ok := record.Content.(*handshake.Handshake).Message.(*handshake.MessageHelloVerifyRequest) assert.True(t, ok) err = sendClientHello(helloVerifyRequest.Cookie, ca, 1, extensions) assert.NoError(t, err) n, err = ca.Read(resp) assert.NoError(t, err) messages, err := recordlayer.UnpackDatagram(resp[:n]) assert.NoError(t, err) assert.NoError(t, record.Unmarshal(messages[0])) serverHello, ok := record.Content.(*handshake.Handshake).Message.(*handshake.MessageServerHello) assert.True(t, ok) actualNegotationInfo := false for _, v := range serverHello.Extensions { if _, ok := v.(*extension.RenegotiationInfo); ok { actualNegotationInfo = true } } assert.True(t, test.ExpectRenegotiationInfo == 
actualNegotationInfo, "NegotationInfo state in ServerHello is incorrect: expected(%t) actual(%t)", test.ExpectRenegotiationInfo, actualNegotationInfo) }) } } func TestServerNameIndicationExtension(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() for _, test := range []struct { Name string ServerName string Expected []byte IncludeSNI bool }{ { Name: "Server name is a valid hostname", ServerName: "example.com", Expected: []byte("example.com"), IncludeSNI: true, }, { Name: "Server name is an IP literal", ServerName: "1.2.3.4", Expected: []byte(""), IncludeSNI: false, }, { Name: "Server name is empty", ServerName: "", Expected: []byte(""), IncludeSNI: false, }, } { test := test t.Run(test.Name, func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() ca, cb := dpipe.Pipe() go func() { conf := &Config{ ServerName: test.ServerName, } _, _ = testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), conf, false) }() // Receive ClientHello resp := make([]byte, 1024) n, err := cb.Read(resp) assert.NoError(t, err) r := &recordlayer.RecordLayer{} assert.NoError(t, r.Unmarshal(resp[:n])) clientHello, ok := r.Content.(*handshake.Handshake).Message.(*handshake.MessageClientHello) assert.True(t, ok) gotSNI := false var actualServerName string for _, v := range clientHello.Extensions { if _, ok := v.(*extension.ServerName); ok { gotSNI = true extensionServerName, ok := v.(*extension.ServerName) assert.True(t, ok) actualServerName = extensionServerName.ServerName } } assert.Equalf(t, test.IncludeSNI, gotSNI, "TestSNI: expected SNI inclusion '%s'", test.Name) assert.Equalf(t, test.Expected, []byte(actualServerName), "TestSNI: server name mismatch '%s'", test.Name) }) } } func TestALPNExtension(t *testing.T) { //nolint:maintidx // Limit runtime in case of deadlocks lim := 
test.TimeOut(time.Second * 20) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() for _, test := range []struct { Name string ClientProtocolNameList []string ServerProtocolNameList []string ExpectedProtocol string ExpectAlertFromClient bool ExpectAlertFromServer bool Alert alert.Description }{ { Name: "Negotiate a protocol", ClientProtocolNameList: []string{"http/1.1", "spd/1"}, ServerProtocolNameList: []string{"spd/1"}, ExpectedProtocol: "spd/1", ExpectAlertFromClient: false, ExpectAlertFromServer: false, Alert: 0, }, { Name: "Server doesn't support any", ClientProtocolNameList: []string{"http/1.1", "spd/1"}, ServerProtocolNameList: []string{}, ExpectedProtocol: "", ExpectAlertFromClient: false, ExpectAlertFromServer: false, Alert: 0, }, { Name: "Negotiate with higher server precedence", ClientProtocolNameList: []string{"http/1.1", "spd/1", "http/3"}, ServerProtocolNameList: []string{"ssh/2", "http/3", "spd/1"}, ExpectedProtocol: "http/3", ExpectAlertFromClient: false, ExpectAlertFromServer: false, Alert: 0, }, { Name: "Empty intersection", ClientProtocolNameList: []string{"http/1.1", "http/3"}, ServerProtocolNameList: []string{"ssh/2", "spd/1"}, ExpectedProtocol: "", ExpectAlertFromClient: false, ExpectAlertFromServer: true, Alert: alert.NoApplicationProtocol, }, { Name: "Multiple protocols in ServerHello", ClientProtocolNameList: []string{"http/1.1"}, ServerProtocolNameList: []string{"http/1.1"}, ExpectedProtocol: "http/1.1", ExpectAlertFromClient: true, ExpectAlertFromServer: false, Alert: alert.InternalError, }, } { test := test t.Run(test.Name, func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() ca, cb := dpipe.Pipe() go func() { conf := &Config{ SupportedProtocols: test.ClientProtocolNameList, } _, _ = testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), conf, false) }() // Receive ClientHello resp := make([]byte, 1024) n, err := cb.Read(resp) 
assert.NoError(t, err) ctx2, cancel2 := context.WithTimeout(context.Background(), 10*time.Second) defer cancel2() ca2, cb2 := dpipe.Pipe() go func() { conf := &Config{ SupportedProtocols: test.ServerProtocolNameList, } _, err2 := testServer(ctx2, dtlsnet.PacketConnFromConn(cb2), cb2.RemoteAddr(), conf, true) if test.ExpectAlertFromServer { assert.NotErrorIs(t, err2, context.Canceled) } }() time.Sleep(50 * time.Millisecond) // Forward ClientHello _, err = ca2.Write(resp[:n]) assert.NoError(t, err) // Receive HelloVerify resp2 := make([]byte, 1024) n, err = ca2.Read(resp2) assert.NoError(t, err) // Forward HelloVerify _, err = cb.Write(resp2[:n]) assert.NoError(t, err) // Receive ClientHello resp3 := make([]byte, 1024) n, err = cb.Read(resp3) assert.NoError(t, err) // Forward ClientHello _, err = ca2.Write(resp3[:n]) assert.NoError(t, err) // Receive ServerHello resp4 := make([]byte, 1024) n, err = ca2.Read(resp4) assert.NoError(t, err) messages, err := recordlayer.UnpackDatagram(resp4[:n]) assert.NoError(t, err) record := &recordlayer.RecordLayer{} assert.NoError(t, record.Unmarshal(messages[0])) if test.ExpectAlertFromServer { //nolint:nestif a, ok := record.Content.(*alert.Alert) assert.True(t, ok) assert.Equalf(t, test.Alert, a.Description, "ALPN %v", test.Name) } else { serverHello, ok := record.Content.(*handshake.Handshake).Message.(*handshake.MessageServerHello) assert.True(t, ok) var negotiatedProtocol string for _, v := range serverHello.Extensions { if _, ok := v.(*extension.ALPN); ok { e, ok := v.(*extension.ALPN) assert.True(t, ok) negotiatedProtocol = e.ProtocolNameList[0] // Manipulate ServerHello if test.ExpectAlertFromClient { e.ProtocolNameList = append(e.ProtocolNameList, "oops") } } } assert.Equalf(t, test.ExpectedProtocol, negotiatedProtocol, "ALPN %v", test.Name) s, err := record.Marshal() assert.NoError(t, err) // Forward ServerHello _, err = cb.Write(s) assert.NoError(t, err) if test.ExpectAlertFromClient { resp5 := make([]byte, 1024) n, err = 
cb.Read(resp5) assert.NoError(t, err) r2 := &recordlayer.RecordLayer{} assert.NoError(t, r2.Unmarshal(resp5[:n])) a, ok := r2.Content.(*alert.Alert) assert.True(t, ok) assert.Equalf(t, test.Alert, a.Description, "ALPN %v", test.Name) } } time.Sleep(50 * time.Millisecond) // Give some time for returned errors }) } } // Make sure the supported_groups extension is not included in the ServerHello. func TestSupportedGroupsExtension(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() t.Run("ServerHello Supported Groups", func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() ca, cb := dpipe.Pipe() go func() { _, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{}, true) assert.ErrorIs(t, err, context.Canceled) }() extensions := []extension.Extension{ &extension.SupportedEllipticCurves{ EllipticCurves: []elliptic.Curve{elliptic.X25519, elliptic.P256, elliptic.P384}, }, &extension.SupportedPointFormats{ PointFormats: []elliptic.CurvePointFormat{elliptic.CurvePointFormatUncompressed}, }, } time.Sleep(50 * time.Millisecond) resp := make([]byte, 1024) err := sendClientHello([]byte{}, ca, 0, extensions) assert.NoError(t, err) // Receive ServerHello n, err := ca.Read(resp) assert.NoError(t, err) record := &recordlayer.RecordLayer{} assert.NoError(t, record.Unmarshal(resp[:n])) helloVerifyRequest, ok := record.Content.(*handshake.Handshake).Message.(*handshake.MessageHelloVerifyRequest) assert.True(t, ok, "Failed to cast MessageHelloVerifyRequest") err = sendClientHello(helloVerifyRequest.Cookie, ca, 1, extensions) assert.NoError(t, err) n, err = ca.Read(resp) assert.NoError(t, err) messages, err := recordlayer.UnpackDatagram(resp[:n]) assert.NoError(t, err) assert.NoError(t, record.Unmarshal(messages[0])) serverHello, ok := 
record.Content.(*handshake.Handshake).Message.(*handshake.MessageServerHello) assert.True(t, ok, "TestSupportedGroups: Failed to cast MessageServerHello") gotGroups := false for _, v := range serverHello.Extensions { if _, ok := v.(*extension.SupportedEllipticCurves); ok { gotGroups = true } } assert.False(t, gotGroups, "TestSupportedGroups: supported_groups extension was sent in ServerHello") }) } func TestSessionResume(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() t.Run("resumed", func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() type result struct { c *Conn err error } clientRes := make(chan result, 1) ss := &memSessStore{} id, _ := hex.DecodeString("9b9fc92255634d9fb109febed42166717bb8ded8c738ba71bc7f2a0d9dae0306") secret, _ := hex.DecodeString( "2e942a37aca5241deb2295b5fcedac221c7078d2503d2b62aeb48c880d7da73c001238b708559686b9da6e829c05ead7", ) s := Session{ID: id, Secret: secret} ca, cb := dpipe.Pipe() _ = ss.Set(id, s) _ = ss.Set([]byte(ca.RemoteAddr().String()+"_example.com"), s) go func() { config := &Config{ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, ServerName: "example.com", SessionStore: ss, MTU: 100, } c, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), config, false) clientRes <- result{c, err} }() config := &Config{ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, ServerName: "example.com", SessionStore: ss, MTU: 100, } server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), config, true) assert.NoError(t, err) state, ok := server.ConnectionState() assert.True(t, ok) actualSessionID := state.SessionID actualMasterSecret := state.masterSecret assert.Equal(t, actualSessionID, id, "TestSessionResumetion SessionID mismatch") assert.Equal(t, actualMasterSecret, 
secret, "TestSessionResumetion masterSecret mismatch") defer func() { assert.NoError(t, server.Close()) }() res := <-clientRes assert.NoError(t, res.err) assert.NoError(t, res.c.Close()) }) t.Run("new session", func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() type result struct { c *Conn err error } clientRes := make(chan result, 1) s1 := &memSessStore{} s2 := &memSessStore{} ca, cb := dpipe.Pipe() go func() { config := &Config{ ServerName: "example.com", SessionStore: s1, } c, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), config, false) clientRes <- result{c, err} }() config := &Config{ SessionStore: s2, } server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), config, true) assert.NoError(t, err) state, ok := server.ConnectionState() assert.True(t, ok) actualSessionID := state.SessionID actualMasterSecret := state.masterSecret ss, _ := s2.Get(actualSessionID) assert.Equal(t, actualMasterSecret, ss.Secret, "TestSessionResumetion masterSecret mismatch") defer func() { assert.NoError(t, server.Close()) }() res := <-clientRes assert.NoError(t, res.err) cs, _ := s1.Get([]byte(ca.RemoteAddr().String() + "_example.com")) assert.Equal(t, actualMasterSecret, cs.Secret, "TestSessionResumetion mismatch") assert.NoError(t, res.c.Close()) }) } type memSessStore struct { sync.Map } func (ms *memSessStore) Set(key []byte, s Session) error { k := hex.EncodeToString(key) ms.Store(k, s) return nil } func (ms *memSessStore) Get(key []byte) (Session, error) { k := hex.EncodeToString(key) v, ok := ms.Load(k) if !ok { return Session{}, nil } s, ok := v.(Session) if !ok { return Session{}, nil } return s, nil } func (ms *memSessStore) Del(key []byte) error { k := hex.EncodeToString(key) ms.Delete(k) return nil } // Assert that the server only uses CipherSuites with a hash+signature that matches // the certificate. As specified in rfc5246#section-7.4.3 // . 
func TestCipherSuiteMatchesCertificateType(t *testing.T) { //nolint:cyclop // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() for _, test := range []struct { Name string cipherList []CipherSuiteID expectedCipher CipherSuiteID generateRSA bool }{ { Name: "ECDSA Certificate with RSA CipherSuite first", cipherList: []CipherSuiteID{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, expectedCipher: TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, }, { Name: "RSA Certificate with ECDSA CipherSuite first", cipherList: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}, expectedCipher: TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, generateRSA: true, }, } { test := test t.Run(test.Name, func(t *testing.T) { clientErr := make(chan error, 1) client := make(chan *Conn, 1) ca, cb := dpipe.Pipe() go func() { c, err := testClient(context.TODO(), dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{ CipherSuites: test.cipherList, }, false) clientErr <- err client <- c }() var ( signer crypto.Signer err error ) if test.generateRSA { signer, err = rsa.GenerateKey(rand.Reader, 2048) assert.NoError(t, err) } else { signer, err = ecdsa.GenerateKey(cryptoElliptic.P256(), rand.Reader) assert.NoError(t, err) } serverCert, err := selfsign.SelfSign(signer) assert.NoError(t, err) s, err := testServer(context.TODO(), dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ CipherSuites: test.cipherList, Certificates: []tls.Certificate{serverCert}, }, false) assert.NoError(t, err) assert.NoError(t, s.Close()) c := <-client assert.NoError(t, <-clientErr) assert.NoError(t, c.Close()) state, ok := c.ConnectionState() assert.True(t, ok) assert.Equal(t, test.expectedCipher, state.cipherSuite.ID()) }) } } // Test that we return the proper certificate if we are serving multiple ServerNames on a single Server. 
func TestMultipleServerCertificates(t *testing.T) { fooCert, err := selfsign.GenerateSelfSignedWithDNS("foo") assert.NoError(t, err) barCert, err := selfsign.GenerateSelfSignedWithDNS("bar") assert.NoError(t, err) caPool := x509.NewCertPool() for _, cert := range []tls.Certificate{fooCert, barCert} { certificate, err := x509.ParseCertificate(cert.Certificate[0]) assert.NoError(t, err) caPool.AddCert(certificate) } for _, test := range []struct { RequestServerName string ExpectedDNSName string }{ { "foo", "foo", }, { "bar", "bar", }, { "invalid", "foo", }, } { test := test t.Run(test.RequestServerName, func(t *testing.T) { clientErr := make(chan error, 2) client := make(chan *Conn, 1) ca, cb := dpipe.Pipe() go func() { clientConn, err := testClient(context.TODO(), dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{ RootCAs: caPool, ServerName: test.RequestServerName, VerifyPeerCertificate: func(rawCerts [][]byte, _ [][]*x509.Certificate) error { certificate, err := x509.ParseCertificate(rawCerts[0]) if err != nil { return err } if certificate.DNSNames[0] != test.ExpectedDNSName { return errWrongCert } return nil }, }, false) clientErr <- err client <- clientConn }() s, err := testServer(context.TODO(), dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ Certificates: []tls.Certificate{fooCert, barCert}, }, false) assert.NoError(t, err) assert.NoError(t, s.Close()) assert.NoError(t, <-clientErr) assert.NoError(t, (<-client).Close()) }) } } func TestEllipticCurveConfiguration(t *testing.T) { // Check for leaking routines report := test.CheckRoutines(t) defer report() for _, test := range []struct { Name string ConfigCurves []elliptic.Curve HandshakeCurves []elliptic.Curve }{ { Name: "Curve defaulting", ConfigCurves: nil, HandshakeCurves: defaultCurves, }, { Name: "Single curve", ConfigCurves: []elliptic.Curve{elliptic.X25519}, HandshakeCurves: []elliptic.Curve{elliptic.X25519}, }, { Name: "Multiple curves", ConfigCurves: []elliptic.Curve{elliptic.P384, 
elliptic.X25519}, HandshakeCurves: []elliptic.Curve{elliptic.P384, elliptic.X25519}, }, } { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() ca, cb := dpipe.Pipe() type result struct { c *Conn err error } resultCh := make(chan result) go func() { client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, EllipticCurves: test.ConfigCurves, }, true) resultCh <- result{client, err} }() server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, EllipticCurves: test.ConfigCurves, }, true) assert.NoError(t, err) ok := len(test.ConfigCurves) == 0 || len(test.ConfigCurves) == len(test.HandshakeCurves) assert.True(t, ok, "Failed to default Elliptic curves") if len(test.ConfigCurves) != 0 { assert.Equal(t, len(test.HandshakeCurves), len(server.fsm.cfg.ellipticCurves), "Failed to configure Elliptic curves") for i, c := range test.ConfigCurves { assert.Equal(t, c, server.fsm.cfg.ellipticCurves[i], "Failed to maintain Elliptic curve order") } } res := <-resultCh assert.NoError(t, res.err, "Client error") defer func() { assert.NoError(t, server.Close()) assert.NoError(t, res.c.Close()) }() } } func TestSkipHelloVerify(t *testing.T) { report := test.CheckRoutines(t) defer report() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() ca, cb := dpipe.Pipe() certificate, err := selfsign.GenerateSelfSigned() assert.NoError(t, err) gotHello := make(chan struct{}) go func() { server, sErr := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ Certificates: []tls.Certificate{certificate}, LoggerFactory: logging.NewDefaultLoggerFactory(), InsecureSkipVerifyHello: true, }, false) assert.NoError(t, sErr) buf := make([]byte, 1024) _, sErr = server.Read(buf) //nolint:contextcheck assert.NoError(t, 
sErr) gotHello <- struct{}{} assert.NoError(t, server.Close()) //nolint:contextcheck }() client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{ LoggerFactory: logging.NewDefaultLoggerFactory(), InsecureSkipVerify: true, }, false) assert.NoError(t, err) _, err = client.Write([]byte("hello")) assert.NoError(t, err) select { case <-gotHello: // OK case <-time.After(time.Second * 5): assert.Fail(t, "timeout") } assert.NoError(t, client.Close()) } type connWithCallback struct { net.Conn onWrite func([]byte) } func (c *connWithCallback) Write(b []byte) (int, error) { if c.onWrite != nil { c.onWrite(b) } return c.Conn.Write(b) } func TestApplicationDataQueueLimited(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() ca, cb := dpipe.Pipe() defer func() { assert.NoError(t, ca.Close()) }() defer func() { assert.NoError(t, cb.Close()) }() done := make(chan struct{}) go func() { serverCert, err := selfsign.GenerateSelfSigned() assert.NoError(t, err) cfg := &Config{} cfg.Certificates = []tls.Certificate{serverCert} dconn, err := createConn(dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), cfg, false, nil) assert.NoError(t, err) go func() { for i := 0; i < 5; i++ { dconn.lock.RLock() qlen := len(dconn.encryptedPackets) dconn.lock.RUnlock() assert.GreaterOrEqual(t, maxAppDataPacketQueueSize, qlen, "too many encrypted packets enqueued") time.Sleep(1 * time.Second) } }() assert.Error(t, dconn.HandshakeContext(ctx)) close(done) }() extensions := []extension.Extension{} time.Sleep(50 * time.Millisecond) assert.NoError(t, sendClientHello([]byte{}, ca, 0, extensions)) time.Sleep(50 * time.Millisecond) for i := 0; i < 1000; i++ { // Send an application data packet packet, err := (&recordlayer.RecordLayer{ Header: recordlayer.Header{ 
Version: protocol.Version1_2, SequenceNumber: uint64(3), Epoch: 1, // use an epoch greater than 0 }, Content: &protocol.ApplicationData{ Data: []byte{1, 2, 3, 4}, }, }).Marshal() assert.NoError(t, err) _, err = ca.Write(packet) assert.NoError(t, err) if i%100 == 0 { time.Sleep(10 * time.Millisecond) } } time.Sleep(1 * time.Second) assert.NoError(t, ca.Close()) <-done } func TestHelloRandom(t *testing.T) { report := test.CheckRoutines(t) defer report() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() ca, cb := dpipe.Pipe() certificate, err := selfsign.GenerateSelfSigned() assert.NoError(t, err) gotHello := make(chan struct{}) chRandom := [handshake.RandomBytesLength]byte{} _, err = rand.Read(chRandom[:]) assert.NoError(t, err) go func() { server, sErr := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ GetCertificate: func(chi *ClientHelloInfo) (*tls.Certificate, error) { if len(chi.CipherSuites) == 0 { return &certificate, nil } assert.Equal(t, chRandom[:], chi.RandomBytes[:]) return &certificate, nil }, LoggerFactory: logging.NewDefaultLoggerFactory(), }, false) assert.NoError(t, sErr) buf := make([]byte, 1024) _, sErr = server.Read(buf) //nolint:contextcheck assert.NoError(t, sErr) gotHello <- struct{}{} assert.NoError(t, server.Close()) //nolint:contextcheck }() client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{ LoggerFactory: logging.NewDefaultLoggerFactory(), HelloRandomBytesGenerator: func() [handshake.RandomBytesLength]byte { return chRandom }, InsecureSkipVerify: true, }, false) assert.NoError(t, err) _, err = client.Write([]byte("hello")) assert.NoError(t, err) select { case <-gotHello: // OK case <-time.After(time.Second * 5): assert.Fail(t, "timeout") } assert.NoError(t, client.Close()) } func TestOnConnectionAttempt(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*20) defer cancel() var clientOnConnectionAttempt, 
serverOnConnectionAttempt atomic.Int32 ca, cb := dpipe.Pipe() clientErr := make(chan error, 1) go func() { _, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{ OnConnectionAttempt: func(in net.Addr) error { clientOnConnectionAttempt.Store(1) assert.NotNil(t, in) return nil }, }, true) clientErr <- err }() expectedErr := &FatalError{} _, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ OnConnectionAttempt: func(in net.Addr) error { serverOnConnectionAttempt.Store(1) assert.NotNil(t, in) return expectedErr }, }, true) assert.ErrorIs(t, err, expectedErr) assert.Error(t, <-clientErr) assert.Equal(t, int32(1), serverOnConnectionAttempt.Load(), "OnConnectionAttempt did not fire for server") assert.Equal(t, int32(0), clientOnConnectionAttempt.Load(), "OnConnectionAttempt fired for client") } func TestFragmentBuffer_Retransmission(t *testing.T) { fragmentBuffer := newFragmentBuffer() frag := []byte{ 0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x30, 0x03, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0xfe, 0xff, 0x01, 0x01, } _, isRetransmission, err := fragmentBuffer.push(frag) assert.NoError(t, err) assert.False(t, isRetransmission) v, _ := fragmentBuffer.pop() assert.NotNil(t, v) _, isRetransmission, err = fragmentBuffer.push(frag) assert.NoError(t, err) assert.True(t, isRetransmission) } func TestConnectionState(t *testing.T) { ca, cb := dpipe.Pipe() // Setup client clientCfg := &Config{} clientCert, err := selfsign.GenerateSelfSigned() assert.NoError(t, err) clientCfg.Certificates = []tls.Certificate{clientCert} clientCfg.InsecureSkipVerify = true client, err := Client(dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), clientCfg) assert.NoError(t, err) defer func() { _ = client.Close() }() _, ok := client.ConnectionState() assert.False(t, ok) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() errorChannel := make(chan error) 
go func() { errC := client.HandshakeContext(ctx) errorChannel <- errC }() // Setup server server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{}, true) assert.NoError(t, err) defer func() { _ = server.Close() }() err = <-errorChannel assert.NoError(t, err) _, ok = client.ConnectionState() assert.True(t, ok) } func TestMultiHandshake(t *testing.T) { defer test.CheckRoutines(t)() defer test.TimeOut(time.Second * 10).Stop() ca, cb := dpipe.Pipe() serverCert, err := selfsign.GenerateSelfSigned() assert.NoError(t, err) server, err := Server(dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ Certificates: []tls.Certificate{serverCert}, }) assert.NoError(t, err) go func() { _ = server.Handshake() }() clientCert, err := selfsign.GenerateSelfSigned() assert.NoError(t, err) client, err := Client(dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{ Certificates: []tls.Certificate{clientCert}, }) assert.NoError(t, err) assert.Error(t, client.Handshake()) assert.Error(t, client.Handshake()) assert.NoError(t, server.Close()) assert.NoError(t, client.Close()) } func TestCloseDuringHandshake(t *testing.T) { defer test.CheckRoutines(t)() defer test.TimeOut(time.Second * 10).Stop() serverCert, err := selfsign.GenerateSelfSigned() assert.NoError(t, err) for i := 0; i < 100; i++ { _, cb := dpipe.Pipe() server, err := Server(dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ Certificates: []tls.Certificate{serverCert}, }) assert.NoError(t, err) waitChan := make(chan struct{}) go func() { close(waitChan) _ = server.Handshake() }() <-waitChan assert.NoError(t, server.Close()) } } func TestCloseWithoutHandshake(t *testing.T) { defer test.CheckRoutines(t)() defer test.TimeOut(time.Second * 10).Stop() serverCert, err := selfsign.GenerateSelfSigned() assert.NoError(t, err) _, cb := dpipe.Pipe() server, err := Server(dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ Certificates: []tls.Certificate{serverCert}, }) assert.NoError(t, 
err) assert.NoError(t, server.Close()) } golang-github-pion-dtls-v3-3.0.7/connection_id.go000066400000000000000000000056561507057460300216720ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "crypto/rand" "github.com/pion/dtls/v3/pkg/protocol" "github.com/pion/dtls/v3/pkg/protocol/extension" "github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) // RandomCIDGenerator is a random Connection ID generator where CID is the // specified size. Specifying a size of 0 will indicate to peers that sending a // Connection ID is not necessary. func RandomCIDGenerator(size int) func() []byte { return func() []byte { cid := make([]byte, size) if _, err := rand.Read(cid); err != nil { panic(err) //nolint -- nonrecoverable } return cid } } // OnlySendCIDGenerator enables sending Connection IDs negotiated with a peer, // but indicates to the peer that sending Connection IDs in return is not // necessary. func OnlySendCIDGenerator() func() []byte { return func() []byte { return nil } } // cidDatagramRouter extracts connection IDs from incoming datagram payloads and // uses them to route to the proper connection. // NOTE: properly routing datagrams based on connection IDs requires using // constant size connection IDs. func cidDatagramRouter(size int) func([]byte) (string, bool) { return func(packet []byte) (string, bool) { pkts, err := recordlayer.ContentAwareUnpackDatagram(packet, size) if err != nil || len(pkts) < 1 { return "", false } for _, pkt := range pkts { h := &recordlayer.Header{ ConnectionID: make([]byte, size), } if err := h.Unmarshal(pkt); err != nil { continue } if h.ContentType != protocol.ContentTypeConnectionID { continue } return string(h.ConnectionID), true } return "", false } } // cidConnIdentifier extracts connection IDs from outgoing ServerHello records // and associates them with the associated connection. 
// NOTE: a ServerHello should always be the first record in a datagram if // multiple are present, so we avoid iterating through all packets if the first // is not a ServerHello. func cidConnIdentifier() func([]byte) (string, bool) { //nolint:cyclop return func(packet []byte) (string, bool) { pkts, err := recordlayer.UnpackDatagram(packet) if err != nil || len(pkts) < 1 { return "", false } var h recordlayer.Header if hErr := h.Unmarshal(pkts[0]); hErr != nil { return "", false } if h.ContentType != protocol.ContentTypeHandshake { return "", false } var hh handshake.Header var sh handshake.MessageServerHello for _, pkt := range pkts { if hhErr := hh.Unmarshal(pkt[recordlayer.FixedHeaderSize:]); hhErr != nil { continue } if err = sh.Unmarshal(pkt[recordlayer.FixedHeaderSize+handshake.HeaderLength:]); err == nil { break } } if err != nil { return "", false } for _, ext := range sh.Extensions { if e, ok := ext.(*extension.ConnectionID); ok { return string(e.CID), true } } return "", false } } golang-github-pion-dtls-v3-3.0.7/connection_id_test.go000066400000000000000000000171141507057460300227210ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "testing" "time" "github.com/pion/dtls/v3/pkg/protocol" "github.com/pion/dtls/v3/pkg/protocol/extension" "github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/pion/dtls/v3/pkg/protocol/recordlayer" "github.com/stretchr/testify/assert" ) func TestRandomConnectionIDGenerator(t *testing.T) { cases := map[string]struct { reason string size int }{ "LengthMatch": { reason: "Zero size should match length of generated CID.", size: 0, }, "LengthMatchSome": { reason: "Non-zero size should match length of generated CID with non-zero.", size: 8, }, } for name, tc := range cases { t.Run(name, func(t *testing.T) { assert.Equal(t, tc.size, len(RandomCIDGenerator(tc.size)()), "%s\nRandomCIDGenerator mismatch", tc.reason) }) } } func 
TestOnlySendCIDGenerator(t *testing.T) { cases := map[string]struct { reason string }{ "LengthMatch": { reason: "CID length should always be zero.", }, } for name, tc := range cases { t.Run(name, func(t *testing.T) { assert.Equalf(t, 0, len(OnlySendCIDGenerator()()), "%s\nOnlySendCIDGenerator mismatch", tc.reason) }) } } func TestCIDDatagramRouter(t *testing.T) { cid := []byte("abcd1234") cidLen := 8 appRecord, err := (&recordlayer.RecordLayer{ Header: recordlayer.Header{ Epoch: 1, Version: protocol.Version1_2, }, Content: &protocol.ApplicationData{ Data: []byte("application data"), }, }).Marshal() assert.NoError(t, err) appData, err := (&protocol.ApplicationData{ Data: []byte("some data"), }).Marshal() assert.NoError(t, err) inner, err := (&recordlayer.InnerPlaintext{ Content: appData, RealType: protocol.ContentTypeApplicationData, }).Marshal() assert.NoError(t, err) cidHeader, err := (&recordlayer.Header{ Epoch: 1, Version: protocol.Version1_2, ContentType: protocol.ContentTypeConnectionID, ContentLen: uint16(len(inner)), //nolint:gosec // G115 ConnectionID: cid, SequenceNumber: 1, }).Marshal() assert.NoError(t, err) cases := map[string]struct { reason string size int datagram []byte ok bool want string }{ "EmptyDatagram": { reason: "If datagram is empty, we cannot extract an identifier", size: cidLen, datagram: []byte{}, ok: false, want: "", }, "NotADTLSRecord": { reason: "If datagram is not a DTLS record, we cannot extract an identifier", size: cidLen, datagram: []byte("not a DTLS record"), ok: false, want: "", }, "NotAConnectionIDDatagram": { reason: "If datagram does not contain any Connection ID records, we cannot extract an identifier", size: cidLen, datagram: appRecord, ok: false, want: "", }, "OneRecordConnectionID": { reason: "If datagram contains one Connection ID record, we should be able to extract it.", size: cidLen, datagram: append(cidHeader, inner...), ok: true, want: string(cid), }, "OneRecordConnectionIDAltLength": { //nolint:lll reason: "If 
datagram contains one Connection ID record, but it has the wrong length we should not be able to extract it.", size: cidLen, datagram: func() []byte { altCIDHeader, err := (&recordlayer.Header{ Epoch: 1, Version: protocol.Version1_2, ContentType: protocol.ContentTypeConnectionID, ContentLen: uint16(len(inner)), //nolint:gosec // G115 ConnectionID: []byte("abcd"), SequenceNumber: 1, }).Marshal() assert.NoError(t, err) return append(altCIDHeader, inner...) }(), ok: false, want: "", }, "MultipleRecordOneConnectionID": { //nolint:lll reason: "If datagram contains multiple records and one is a Connection ID record, we should be able to extract it.", size: 8, datagram: append(append(appRecord, cidHeader...), inner...), ok: true, want: string(cid), }, "MultipleRecordMultipleConnectionID": { //nolint:lll reason: "If datagram contains multiple records and multiple are Connection ID records, we should extract the first one.", size: 8, datagram: append(append(append(appRecord, func() []byte { altCIDHeader, err := (&recordlayer.Header{ Epoch: 1, Version: protocol.Version1_2, ContentType: protocol.ContentTypeConnectionID, ContentLen: uint16(len(inner)), //nolint:gosec // G115 ConnectionID: []byte("1234abcd"), SequenceNumber: 1, }).Marshal() assert.NoError(t, err) return append(altCIDHeader, inner...) 
}()...), cidHeader...), inner...), ok: true, want: "1234abcd", }, } for name, tc := range cases { t.Run(name, func(t *testing.T) { cid, ok := cidDatagramRouter(tc.size)(tc.datagram) assert.Equal(t, tc.ok, ok, "%s\ncidDatagramRouter mismatch", tc.reason) assert.Equal(t, tc.want, cid, "%s\ncidDatagramRouter mismatch", tc.reason) }) } } func TestCIDConnIdentifier(t *testing.T) { cid := []byte("abcd1234") cs := uint16(TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256) sh, err := (&recordlayer.RecordLayer{ Header: recordlayer.Header{ Epoch: 0, Version: protocol.Version1_2, }, Content: &handshake.Handshake{ Message: &handshake.MessageServerHello{ Version: protocol.Version1_2, Random: handshake.Random{GMTUnixTime: time.Unix(500, 0), RandomBytes: [28]byte{}}, SessionID: []byte("hello"), CipherSuiteID: &cs, CompressionMethod: defaultCompressionMethods()[0], Extensions: []extension.Extension{ &extension.ConnectionID{ CID: cid, }, }, }, }, }).Marshal() assert.NoError(t, err) appRecord, err := (&recordlayer.RecordLayer{ Header: recordlayer.Header{ Epoch: 1, Version: protocol.Version1_2, }, Content: &protocol.ApplicationData{ Data: []byte("application data"), }, }).Marshal() assert.NoError(t, err) cases := map[string]struct { reason string datagram []byte ok bool want string }{ "EmptyDatagram": { reason: "If datagram is empty, we cannot extract an identifier", datagram: []byte{}, ok: false, want: "", }, "NotADTLSRecord": { reason: "If datagram is not a DTLS record, we cannot extract an identifier", datagram: []byte("not a DTLS record"), ok: false, want: "", }, "NotAServerhelloDatagram": { reason: "If datagram does not contain any ServerHello record, we cannot extract an identifier", datagram: appRecord, ok: false, want: "", }, "OneRecordServerHello": { reason: "If datagram contains one ServerHello record, we should be able to extract an identifier.", datagram: sh, ok: true, want: string(cid), }, "MultipleRecordFirstServerHello": { //nolint:lll reason: "If datagram contains multiple 
records and the first is a ServerHello record, we should be able to extract an identifier.", datagram: append(sh, appRecord...), ok: true, want: string(cid), }, "MultipleRecordNotFirstServerHello": { //nolint:lll reason: "If datagram contains multiple records and the first is not a ServerHello record, we should not be able to extract an identifier.", datagram: append(appRecord, sh...), ok: false, want: "", }, } for name, tc := range cases { t.Run(name, func(t *testing.T) { cid, ok := cidConnIdentifier()(tc.datagram) assert.Equalf(t, tc.ok, ok, "%s\ncidConnIdentifier mismatch", tc.reason) assert.Equalf(t, tc.want, cid, "%s\ncidConnIdentifier mismatch", tc.reason) }) } } golang-github-pion-dtls-v3-3.0.7/crypto.go000066400000000000000000000156411507057460300203720ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "crypto" "crypto/ecdsa" "crypto/ed25519" "crypto/rand" "crypto/rsa" "crypto/x509" "encoding/asn1" "encoding/binary" "math/big" "time" "github.com/pion/dtls/v3/pkg/crypto/elliptic" "github.com/pion/dtls/v3/pkg/crypto/hash" ) type ecdsaSignature struct { R, S *big.Int } func valueKeyMessage(clientRandom, serverRandom, publicKey []byte, namedCurve elliptic.Curve) []byte { serverECDHParams := make([]byte, 4) serverECDHParams[0] = 3 // named curve binary.BigEndian.PutUint16(serverECDHParams[1:], uint16(namedCurve)) serverECDHParams[3] = byte(len(publicKey)) plaintext := []byte{} plaintext = append(plaintext, clientRandom...) plaintext = append(plaintext, serverRandom...) plaintext = append(plaintext, serverECDHParams...) plaintext = append(plaintext, publicKey...) 
return plaintext } // If the client provided a "signature_algorithms" extension, then all // certificates provided by the server MUST be signed by a // hash/signature algorithm pair that appears in that extension // // https://tools.ietf.org/html/rfc5246#section-7.4.2 func generateKeySignature( clientRandom, serverRandom, publicKey []byte, namedCurve elliptic.Curve, signer crypto.Signer, hashAlgorithm hash.Algorithm, ) ([]byte, error) { msg := valueKeyMessage(clientRandom, serverRandom, publicKey, namedCurve) switch signer.Public().(type) { case ed25519.PublicKey: // https://crypto.stackexchange.com/a/55483 return signer.Sign(rand.Reader, msg, crypto.Hash(0)) case *ecdsa.PublicKey: hashed := hashAlgorithm.Digest(msg) return signer.Sign(rand.Reader, hashed, hashAlgorithm.CryptoHash()) case *rsa.PublicKey: hashed := hashAlgorithm.Digest(msg) return signer.Sign(rand.Reader, hashed, hashAlgorithm.CryptoHash()) } return nil, errKeySignatureGenerateUnimplemented } //nolint:dupl,cyclop func verifyKeySignature( message, remoteKeySignature []byte, hashAlgorithm hash.Algorithm, rawCertificates [][]byte, ) error { if len(rawCertificates) == 0 { return errLengthMismatch } certificate, err := x509.ParseCertificate(rawCertificates[0]) if err != nil { return err } switch pubKey := certificate.PublicKey.(type) { case ed25519.PublicKey: if ok := ed25519.Verify(pubKey, message, remoteKeySignature); !ok { return errKeySignatureMismatch } return nil case *ecdsa.PublicKey: ecdsaSig := &ecdsaSignature{} if _, err := asn1.Unmarshal(remoteKeySignature, ecdsaSig); err != nil { return err } if ecdsaSig.R.Sign() <= 0 || ecdsaSig.S.Sign() <= 0 { return errInvalidECDSASignature } hashed := hashAlgorithm.Digest(message) if !ecdsa.Verify(pubKey, hashed, ecdsaSig.R, ecdsaSig.S) { return errKeySignatureMismatch } return nil case *rsa.PublicKey: hashed := hashAlgorithm.Digest(message) if rsa.VerifyPKCS1v15(pubKey, hashAlgorithm.CryptoHash(), hashed, remoteKeySignature) != nil { return 
errKeySignatureMismatch } return nil } return errKeySignatureVerifyUnimplemented } // If the server has sent a CertificateRequest message, the client MUST send the Certificate // message. The ClientKeyExchange message is now sent, and the content // of that message will depend on the public key algorithm selected // between the ClientHello and the ServerHello. If the client has sent // a certificate with signing ability, a digitally-signed // CertificateVerify message is sent to explicitly verify possession of // the private key in the certificate. // https://tools.ietf.org/html/rfc5246#section-7.3 func generateCertificateVerify( handshakeBodies []byte, signer crypto.Signer, hashAlgorithm hash.Algorithm, ) ([]byte, error) { if _, ok := signer.Public().(ed25519.PublicKey); ok { // https://pkg.go.dev/crypto/ed25519#PrivateKey.Sign // Sign signs the given message with priv. Ed25519 performs two passes over // messages to be signed and therefore cannot handle pre-hashed messages. return signer.Sign(rand.Reader, handshakeBodies, crypto.Hash(0)) } hashed := hashAlgorithm.Digest(handshakeBodies) switch signer.Public().(type) { case *ecdsa.PublicKey: return signer.Sign(rand.Reader, hashed, hashAlgorithm.CryptoHash()) case *rsa.PublicKey: return signer.Sign(rand.Reader, hashed, hashAlgorithm.CryptoHash()) } return nil, errInvalidSignatureAlgorithm } //nolint:dupl,cyclop func verifyCertificateVerify( handshakeBodies []byte, hashAlgorithm hash.Algorithm, remoteKeySignature []byte, rawCertificates [][]byte, ) error { if len(rawCertificates) == 0 { return errLengthMismatch } certificate, err := x509.ParseCertificate(rawCertificates[0]) if err != nil { return err } switch pubKey := certificate.PublicKey.(type) { case ed25519.PublicKey: if ok := ed25519.Verify(pubKey, handshakeBodies, remoteKeySignature); !ok { return errKeySignatureMismatch } return nil case *ecdsa.PublicKey: ecdsaSig := &ecdsaSignature{} if _, err := asn1.Unmarshal(remoteKeySignature, ecdsaSig); err != nil { 
return err } if ecdsaSig.R.Sign() <= 0 || ecdsaSig.S.Sign() <= 0 { return errInvalidECDSASignature } hash := hashAlgorithm.Digest(handshakeBodies) if !ecdsa.Verify(pubKey, hash, ecdsaSig.R, ecdsaSig.S) { return errKeySignatureMismatch } return nil case *rsa.PublicKey: hash := hashAlgorithm.Digest(handshakeBodies) if rsa.VerifyPKCS1v15(pubKey, hashAlgorithm.CryptoHash(), hash, remoteKeySignature) != nil { return errKeySignatureMismatch } return nil } return errKeySignatureVerifyUnimplemented } func loadCerts(rawCertificates [][]byte) ([]*x509.Certificate, error) { if len(rawCertificates) == 0 { return nil, errLengthMismatch } certs := make([]*x509.Certificate, 0, len(rawCertificates)) for _, rawCert := range rawCertificates { cert, err := x509.ParseCertificate(rawCert) if err != nil { return nil, err } certs = append(certs, cert) } return certs, nil } func verifyClientCert(rawCertificates [][]byte, roots *x509.CertPool) (chains [][]*x509.Certificate, err error) { certificate, err := loadCerts(rawCertificates) if err != nil { return nil, err } intermediateCAPool := x509.NewCertPool() for _, cert := range certificate[1:] { intermediateCAPool.AddCert(cert) } opts := x509.VerifyOptions{ Roots: roots, CurrentTime: time.Now(), Intermediates: intermediateCAPool, KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, } return certificate[0].Verify(opts) } func verifyServerCert( rawCertificates [][]byte, roots *x509.CertPool, serverName string, ) (chains [][]*x509.Certificate, err error) { certificate, err := loadCerts(rawCertificates) if err != nil { return nil, err } intermediateCAPool := x509.NewCertPool() for _, cert := range certificate[1:] { intermediateCAPool.AddCert(cert) } opts := x509.VerifyOptions{ Roots: roots, CurrentTime: time.Now(), DNSName: serverName, Intermediates: intermediateCAPool, } return certificate[0].Verify(opts) } golang-github-pion-dtls-v3-3.0.7/crypto_test.go000066400000000000000000000110521507057460300214210ustar00rootroot00000000000000// 
SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "crypto/x509" "encoding/pem" "testing" "github.com/pion/dtls/v3/pkg/crypto/elliptic" "github.com/pion/dtls/v3/pkg/crypto/hash" "github.com/stretchr/testify/assert" ) // nolint: gosec const rawPrivateKey = ` -----BEGIN RSA PRIVATE KEY----- MIIEowIBAAKCAQEAxIA2BrrnR2sIlATsp7aRBD/3krwZ7vt9dNeoDQAee0s6SuYP 6MBx/HPnAkwNvPS90R05a7pwRkoT6Ur4PfPhCVlUe8lV+0Eto3ZSEeHz3HdsqlM3 bso67L7Dqrc7MdVstlKcgJi8yeAoGOIL9/igOv0XBFCeznm9nznx6mnsR5cugw+1 ypXelaHmBCLV7r5SeVSh57+KhvZGbQ2fFpUaTPegRpJZXBNS8lSeWvtOv9d6N5UB ROTAJodMZT5AfX0jB0QB9IT/0I96H6BSENH08NXOeXApMuLKvnAf361rS7cRAfRL rWZqERMP4u6Cnk0Cnckc3WcW27kGGIbtwbqUIQIDAQABAoIBAGF7OVIdZp8Hejn0 N3L8HvT8xtUEe9kS6ioM0lGgvX5s035Uo4/T6LhUx0VcdXRH9eLHnLTUyN4V4cra ZkxVsE3zAvZl60G6E+oDyLMWZOP6Wu4kWlub9597A5atT7BpMIVCdmFVZFLB4SJ3 AXkC3nplFAYP+Lh1rJxRIrIn2g+pEeBboWbYA++oDNuMQffDZaokTkJ8Bn1JZYh0 xEXKY8Bi2Egd5NMeZa1UFO6y8tUbZfwgVs6Enq5uOgtfayq79vZwyjj1kd29MBUD 8g8byV053ZKxbUOiOuUts97eb+fN3DIDRTcT2c+lXt/4C54M1FclJAbtYRK/qwsl pYWKQAECgYEA4ZUbqQnTo1ICvj81ifGrz+H4LKQqe92Hbf/W51D/Umk2kP702W22 HP4CvrJRtALThJIG9m2TwUjl/WAuZIBrhSAbIvc3Fcoa2HjdRp+sO5U1ueDq7d/S Z+PxRI8cbLbRpEdIaoR46qr/2uWZ943PHMv9h4VHPYn1w8b94hwD6vkCgYEA3v87 mFLzyM9ercnEv9zHMRlMZFQhlcUGQZvfb8BuJYl/WogyT6vRrUuM0QXULNEPlrin mBQTqc1nCYbgkFFsD2VVt1qIyiAJsB9MD1LNV6YuvE7T2KOSadmsA4fa9PUqbr71 hf3lTTq+LeR09LebO7WgSGYY+5YKVOEGpYMR1GkCgYEAxPVQmk3HKHEhjgRYdaG5 lp9A9ZE8uruYVJWtiHgzBTxx9TV2iST+fd/We7PsHFTfY3+wbpcMDBXfIVRKDVwH BMwchXH9+Ztlxx34bYJaegd0SmA0Hw9ugWEHNgoSEmWpM1s9wir5/ELjc7dGsFtz uzvsl9fpdLSxDYgAAdzeGtkCgYBAzKIgrVox7DBzB8KojhtD5ToRnXD0+H/M6OKQ srZPKhlb0V/tTtxrIx0UUEFLlKSXA6mPw6XDHfDnD86JoV9pSeUSlrhRI+Ysy6tq eIE7CwthpPZiaYXORHZ7wCqcK/HcpJjsCs9rFbrV0yE5S3FMdIbTAvgXg44VBB7O UbwIoQKBgDuY8gSrA5/A747wjjmsdRWK4DMTMEV4eCW1BEP7Tg7Cxd5n3xPJiYhr nhLGN+mMnVIcv2zEMS0/eNZr1j/0BtEdx+3IC6Eq+ONY0anZ4Irt57/5QeKgKn/L JPhfPySIPG4UmwE4gW8t79vfOKxnUu2fDD1ZXUYopan6EckACNH/ -----END RSA PRIVATE KEY----- ` func 
TestGenerateKeySignature(t *testing.T) { block, _ := pem.Decode([]byte(rawPrivateKey)) key, err := x509.ParsePKCS1PrivateKey(block.Bytes) assert.NoError(t, err) clientRandom := []byte{ 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, } serverRandom := []byte{ 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f, 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f, } publicKey := []byte{ 0x20, 0x9f, 0xd7, 0xad, 0x6d, 0xcf, 0xf4, 0x29, 0x8d, 0xd3, 0xf9, 0x6d, 0x5b, 0x1b, 0x2a, 0xf9, 0x10, 0xa0, 0x53, 0x5b, 0x14, 0x88, 0xd7, 0xf8, 0xfa, 0xbb, 0x34, 0x9a, 0x98, 0x28, 0x80, 0xb6, 0x15, } expectedSignature := []byte{ 0x6f, 0x47, 0x97, 0x85, 0xcc, 0x76, 0x50, 0x93, 0xbd, 0xe2, 0x6a, 0x69, 0x0b, 0xc3, 0x03, 0xd1, 0xb7, 0xe4, 0xab, 0x88, 0x7b, 0xa6, 0x52, 0x80, 0xdf, 0xaa, 0x25, 0x7a, 0xdb, 0x29, 0x32, 0xe4, 0xd8, 0x28, 0x28, 0xb3, 0xe8, 0x04, 0x3c, 0x38, 0x16, 0xfc, 0x78, 0xe9, 0x15, 0x7b, 0xc5, 0xbd, 0x7d, 0xfc, 0xcd, 0x83, 0x00, 0x57, 0x4a, 0x3c, 0x23, 0x85, 0x75, 0x6b, 0x37, 0xd5, 0x89, 0x72, 0x73, 0xf0, 0x44, 0x8c, 0x00, 0x70, 0x1f, 0x6e, 0xa2, 0x81, 0xd0, 0x09, 0xc5, 0x20, 0x36, 0xab, 0x23, 0x09, 0x40, 0x1f, 0x4d, 0x45, 0x96, 0x62, 0xbb, 0x81, 0xb0, 0x30, 0x72, 0xad, 0x3a, 0x0a, 0xac, 0x31, 0x63, 0x40, 0x52, 0x0a, 0x27, 0xf3, 0x34, 0xde, 0x27, 0x7d, 0xb7, 0x54, 0xff, 0x0f, 0x9f, 0x5a, 0xfe, 0x07, 0x0f, 0x4e, 0x9f, 0x53, 0x04, 0x34, 0x62, 0xf4, 0x30, 0x74, 0x83, 0x35, 0xfc, 0xe4, 0x7e, 0xbf, 0x5a, 0xc4, 0x52, 0xd0, 0xea, 0xf9, 0x61, 0x4e, 0xf5, 0x1c, 0x0e, 0x58, 0x02, 0x71, 0xfb, 0x1f, 0x34, 0x55, 0xe8, 0x36, 0x70, 0x3c, 0xc1, 0xcb, 0xc9, 0xb7, 0xbb, 0xb5, 0x1c, 0x44, 0x9a, 0x6d, 0x88, 0x78, 0x98, 0xd4, 0x91, 0x2e, 0xeb, 0x98, 0x81, 0x23, 0x30, 0x73, 0x39, 0x43, 0xd5, 0xbb, 0x70, 0x39, 0xba, 0x1f, 0xdb, 0x70, 0x9f, 0x91, 0x83, 0x56, 0xc2, 0xde, 
0xed, 0x17, 0x6d, 0x2c, 0x3e, 0x21, 0xea, 0x36, 0xb4, 0x91, 0xd8, 0x31, 0x05, 0x60, 0x90, 0xfd, 0xc6, 0x74, 0xa9, 0x7b, 0x18, 0xfc, 0x1c, 0x6a, 0x1c, 0x6e, 0xec, 0xd3, 0xc1, 0xc0, 0x0d, 0x11, 0x25, 0x48, 0x37, 0x3d, 0x45, 0x11, 0xa2, 0x31, 0x14, 0x0a, 0x66, 0x9f, 0xd8, 0xac, 0x74, 0xa2, 0xcd, 0xc8, 0x79, 0xb3, 0x9e, 0xc6, 0x66, 0x25, 0xcf, 0x2c, 0x87, 0x5e, 0x5c, 0x36, 0x75, 0x86, } signature, err := generateKeySignature(clientRandom, serverRandom, publicKey, elliptic.X25519, key, hash.SHA256) assert.NoError(t, err) assert.Equal(t, expectedSignature, signature) } golang-github-pion-dtls-v3-3.0.7/dtls.go000066400000000000000000000002731507057460300200130ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT // Package dtls implements Datagram Transport Layer Security (DTLS) 1.2 package dtls golang-github-pion-dtls-v3-3.0.7/e2e/000077500000000000000000000000001507057460300171675ustar00rootroot00000000000000golang-github-pion-dtls-v3-3.0.7/e2e/Dockerfile000066400000000000000000000004161507057460300211620ustar00rootroot00000000000000# SPDX-FileCopyrightText: 2023 The Pion community # SPDX-License-Identifier: MIT FROM docker.io/library/golang:1.24-bullseye COPY . 
/go/src/github.com/pion/dtls WORKDIR /go/src/github.com/pion/dtls/e2e CMD ["go", "test", "-tags=openssl", "-v", "."] golang-github-pion-dtls-v3-3.0.7/e2e/e2e.go000066400000000000000000000002511507057460300201670ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT // Package e2e contains end to end tests for pion/dtls package e2e golang-github-pion-dtls-v3-3.0.7/e2e/e2e_lossy_test.go000066400000000000000000000131661507057460300224700ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package e2e import ( "crypto/tls" "fmt" "math/rand" "testing" "time" "github.com/pion/dtls/v3" "github.com/pion/dtls/v3/pkg/crypto/selfsign" dtlsnet "github.com/pion/dtls/v3/pkg/net" transportTest "github.com/pion/transport/v3/test" "github.com/stretchr/testify/assert" ) const ( flightInterval = time.Millisecond * 100 lossyTestTimeout = 30 * time.Second ) // DTLS Client/Server over a lossy transport, just asserts it can handle at increasing increments func TestPionE2ELossy(t *testing.T) { //nolint:cyclop // Check for leaking routines report := transportTest.CheckRoutines(t) defer report() type runResult struct { dtlsConn *dtls.Conn err error } serverCert, err := selfsign.GenerateSelfSigned() assert.NoError(t, err) clientCert, err := selfsign.GenerateSelfSigned() assert.NoError(t, err) for _, test := range []struct { LossChanceRange int DoClientAuth bool CipherSuites []dtls.CipherSuiteID MTU int DisableServerFlightInterval bool }{ { LossChanceRange: 0, }, { LossChanceRange: 10, }, { LossChanceRange: 20, }, { LossChanceRange: 50, }, { LossChanceRange: 0, DoClientAuth: true, }, { LossChanceRange: 10, DoClientAuth: true, }, { LossChanceRange: 20, DoClientAuth: true, }, { LossChanceRange: 50, DoClientAuth: true, }, { LossChanceRange: 0, CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA}, }, { LossChanceRange: 10, CipherSuites: 
[]dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA}, }, { LossChanceRange: 20, CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA}, }, { LossChanceRange: 50, CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA}, }, { LossChanceRange: 10, MTU: 100, DoClientAuth: true, }, { LossChanceRange: 20, MTU: 100, DoClientAuth: true, }, { LossChanceRange: 50, MTU: 100, DoClientAuth: true, }, // Incoming retransmitted handshakes should cause us to retransmit. Disabling the FlightInterval on one side // means that a incoming re-transmissions causes the retransmission to be fired { LossChanceRange: 10, DisableServerFlightInterval: true, }, { LossChanceRange: 20, DisableServerFlightInterval: true, }, { LossChanceRange: 50, DisableServerFlightInterval: true, }, } { name := fmt.Sprintf("Loss%d_MTU%d", test.LossChanceRange, test.MTU) if test.DoClientAuth { name += "_WithCliAuth" } for _, ciph := range test.CipherSuites { name += "_With" + ciph.String() } if test.DisableServerFlightInterval { name += "_WithNoServerFlightInterval" } test := test t.Run(name, func(t *testing.T) { // Limit runtime in case of deadlocks lim := transportTest.TimeOut(lossyTestTimeout + time.Second) defer lim.Stop() chosenLoss := rand.Intn(9) + test.LossChanceRange //nolint:gosec serverDone := make(chan runResult) clientDone := make(chan runResult) br := transportTest.NewBridge() assert.NoError(t, br.SetLossChance(chosenLoss)) go func() { cfg := &dtls.Config{ FlightInterval: flightInterval, CipherSuites: test.CipherSuites, InsecureSkipVerify: true, MTU: test.MTU, DisableRetransmitBackoff: true, } if test.DoClientAuth { cfg.Certificates = []tls.Certificate{clientCert} } client, startupErr := dtls.Client(dtlsnet.PacketConnFromConn(br.GetConn0()), br.GetConn0().RemoteAddr(), cfg) clientDone <- runResult{client, startupErr} }() go func() { cfg := &dtls.Config{ Certificates: []tls.Certificate{serverCert}, FlightInterval: flightInterval, MTU: test.MTU, 
DisableRetransmitBackoff: true, } if test.DoClientAuth { cfg.ClientAuth = dtls.RequireAnyClientCert } if test.DisableServerFlightInterval { cfg.FlightInterval = time.Hour } server, startupErr := dtls.Server(dtlsnet.PacketConnFromConn(br.GetConn1()), br.GetConn1().RemoteAddr(), cfg) serverDone <- runResult{server, startupErr} }() testTimer := time.NewTimer(lossyTestTimeout) var serverConn, clientConn *dtls.Conn defer func() { if serverConn != nil { assert.NoError(t, serverConn.Close()) } if clientConn != nil { assert.NoError(t, clientConn.Close()) } }() for { if serverConn != nil && clientConn != nil { break } br.Tick() select { case serverResult := <-serverDone: if serverResult.err != nil { assert.Failf(t, "Fail, serverError", "clientComplete(%t) serverComplete(%t) LossChance(%d) error(%v)", clientConn != nil, serverConn != nil, chosenLoss, serverResult.err) return } serverConn = serverResult.dtlsConn case clientResult := <-clientDone: if clientResult.err != nil { assert.Failf(t, "Fail, clientError", "clientComplete(%t) serverComplete(%t) LossChance(%d) error(%v)", clientConn != nil, serverConn != nil, chosenLoss, clientResult.err) return } clientConn = clientResult.dtlsConn case <-testTimer.C: assert.Failf(t, "Test expired", "clientComplete(%t) serverComplete(%t) LossChance(%d)", clientConn != nil, serverConn != nil, chosenLoss) return case <-time.After(10 * time.Millisecond): } } }) } } golang-github-pion-dtls-v3-3.0.7/e2e/e2e_openssl_test.go000066400000000000000000000206771507057460300230070ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT //go:build openssl && !js // +build openssl,!js package e2e import ( "crypto/x509" "encoding/pem" "errors" "fmt" "io/ioutil" "net" "os" "os/exec" "regexp" "strings" "testing" "time" "github.com/pion/dtls/v3" ) func serverOpenSSL(c *comm) { go func() { c.serverMutex.Lock() defer c.serverMutex.Unlock() cfg := c.serverConfig // create openssl arguments args := 
[]string{ "s_server", "-dtls1_2", "-quiet", "-verify_quiet", "-verify_return_error", fmt.Sprintf("-accept=%d", c.serverPort), } ciphers := ciphersOpenSSL(cfg) if ciphers != "" { args = append(args, fmt.Sprintf("-cipher=%s", ciphers)) } // psk arguments if cfg.PSK != nil { psk, err := cfg.PSK(nil) if err != nil { c.errChan <- err return } args = append(args, fmt.Sprintf("-psk=%X", psk)) if len(cfg.PSKIdentityHint) > 0 { args = append(args, fmt.Sprintf("-psk_hint=%s", cfg.PSKIdentityHint)) } } // certs arguments if len(cfg.Certificates) > 0 { // create temporary cert files certPEM, keyPEM, err := writeTempPEM(cfg) if err != nil { c.errChan <- err return } args = append(args, fmt.Sprintf("-cert=%s", certPEM), fmt.Sprintf("-key=%s", keyPEM)) defer func() { _ = os.Remove(certPEM) _ = os.Remove(keyPEM) }() } else { args = append(args, "-nocert") } // launch command // #nosec G204 cmd := exec.Command("openssl", args...) var inner net.Conn inner, c.serverConn = net.Pipe() cmd.Stdin = inner cmd.Stdout = inner cmd.Stderr = os.Stderr if err := cmd.Start(); err != nil { c.errChan <- err _ = inner.Close() return } // Ensure that server has started time.Sleep(500 * time.Millisecond) c.serverReady <- struct{}{} simpleReadWrite(c.errChan, c.serverChan, c.serverConn, c.messageRecvCount) c.serverDone <- cmd.Process.Kill() close(c.serverDone) }() } func clientOpenSSL(c *comm) { select { case <-c.serverReady: // OK case <-time.After(time.Second): c.errChan <- errors.New("waiting on serverReady err: timeout") } c.clientMutex.Lock() defer c.clientMutex.Unlock() cfg := c.clientConfig // create openssl arguments args := []string{ "s_client", "-dtls1_2", "-quiet", "-verify_quiet", "-servername=localhost", fmt.Sprintf("-connect=127.0.0.1:%d", c.serverPort), } ciphers := ciphersOpenSSL(cfg) if ciphers != "" { args = append(args, fmt.Sprintf("-cipher=%s", ciphers)) } // psk arguments if cfg.PSK != nil { psk, err := cfg.PSK(nil) if err != nil { c.errChan <- err return } args = append(args, 
fmt.Sprintf("-psk=%X", psk)) } // certificate arguments if len(cfg.Certificates) > 0 { // create temporary cert files certPEM, keyPEM, err := writeTempPEM(cfg) if err != nil { c.errChan <- err return } args = append(args, fmt.Sprintf("-CAfile=%s", certPEM), fmt.Sprintf("-cert=%s", certPEM), fmt.Sprintf("-key=%s", keyPEM)) defer func() { _ = os.Remove(certPEM) _ = os.Remove(keyPEM) }() } if !cfg.InsecureSkipVerify { args = append(args, "-verify_return_error") } // launch command // #nosec G204 cmd := exec.Command("openssl", args...) var inner net.Conn inner, c.clientConn = net.Pipe() cmd.Stdin = inner cmd.Stdout = inner cmd.Stderr = os.Stderr if err := cmd.Start(); err != nil { c.errChan <- err _ = inner.Close() return } simpleReadWrite(c.errChan, c.clientChan, c.clientConn, c.messageRecvCount) c.clientDone <- cmd.Process.Kill() close(c.clientDone) } func ciphersOpenSSL(cfg *dtls.Config) string { // See https://tls.mbed.org/supported-ssl-ciphersuites translate := map[dtls.CipherSuiteID]string{ dtls.TLS_ECDHE_ECDSA_WITH_AES_128_CCM: "ECDHE-ECDSA-AES128-CCM", dtls.TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8: "ECDHE-ECDSA-AES128-CCM8", dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: "ECDHE-ECDSA-AES128-GCM-SHA256", dtls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384: "ECDHE-ECDSA-AES256-GCM-SHA384", dtls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: "ECDHE-RSA-AES128-GCM-SHA256", dtls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384: "ECDHE-RSA-AES256-GCM-SHA384", dtls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA: "ECDHE-ECDSA-AES256-SHA", dtls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA: "ECDHE-RSA-AES256-SHA", dtls.TLS_PSK_WITH_AES_128_CCM: "PSK-AES128-CCM", dtls.TLS_PSK_WITH_AES_128_CCM_8: "PSK-AES128-CCM8", dtls.TLS_PSK_WITH_AES_256_CCM_8: "PSK-AES256-CCM8", dtls.TLS_PSK_WITH_AES_128_GCM_SHA256: "PSK-AES128-GCM-SHA256", dtls.TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256: "ECDHE-PSK-AES128-CBC-SHA256", } var ciphers []string for _, c := range cfg.CipherSuites { if text, ok := translate[c]; ok { ciphers = append(ciphers, text) 
} } return strings.Join(ciphers, ";") } func writeTempPEM(cfg *dtls.Config) (string, string, error) { certOut, err := ioutil.TempFile("", "cert.pem") if err != nil { return "", "", fmt.Errorf("failed to create temporary file: %w", err) } keyOut, err := ioutil.TempFile("", "key.pem") if err != nil { return "", "", fmt.Errorf("failed to create temporary file: %w", err) } cert := cfg.Certificates[0] derBytes := cert.Certificate[0] if err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil { return "", "", fmt.Errorf("failed to write data to cert.pem: %w", err) } if err = certOut.Close(); err != nil { return "", "", fmt.Errorf("error closing cert.pem: %w", err) } priv := cert.PrivateKey var privBytes []byte privBytes, err = x509.MarshalPKCS8PrivateKey(priv) if err != nil { return "", "", fmt.Errorf("unable to marshal private key: %w", err) } if err = pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}); err != nil { return "", "", fmt.Errorf("failed to write data to key.pem: %w", err) } if err = keyOut.Close(); err != nil { return "", "", fmt.Errorf("error closing key.pem: %w", err) } return certOut.Name(), keyOut.Name(), nil } func minimumOpenSSLVersion(t *testing.T) bool { t.Helper() cmd := exec.Command("openssl", "version") allOut, err := cmd.CombinedOutput() if err != nil { t.Log("Cannot determine OpenSSL version: ", err) return false } verMatch := regexp.MustCompile(`(?i)^OpenSSL\s(?P(\d+\.)?(\d+\.)?(\*|\d+)(\w)?).+$`) match := verMatch.FindStringSubmatch(strings.TrimSpace(string(allOut))) params := map[string]string{} for i, name := range verMatch.SubexpNames() { if i > 0 && i <= len(match) { params[name] = match[i] } } var ver string if val, ok := params["version"]; !ok { t.Log("Could not extract OpenSSL version") return false } else { ver = val } cmp := strings.Compare(ver, "3.0.0") if cmp == -1 { return false } return true } func TestPionOpenSSLE2ESimple(t *testing.T) { t.Run("OpenSSLServer", func(t *testing.T) 
{ testPionE2ESimple(t, serverOpenSSL, clientPion) }) t.Run("OpenSSLClient", func(t *testing.T) { testPionE2ESimple(t, serverPion, clientOpenSSL) }) } func TestPionOpenSSLE2ESimplePSK(t *testing.T) { t.Run("OpenSSLServer", func(t *testing.T) { testPionE2ESimplePSK(t, serverOpenSSL, clientPion) }) t.Run("OpenSSLClient", func(t *testing.T) { testPionE2ESimplePSK(t, serverPion, clientOpenSSL) }) } func TestPionOpenSSLE2EMTUs(t *testing.T) { t.Run("OpenSSLServer", func(t *testing.T) { testPionE2EMTUs(t, serverOpenSSL, clientPion) }) t.Run("OpenSSLClient", func(t *testing.T) { testPionE2EMTUs(t, serverPion, clientOpenSSL) }) } func TestPionOpenSSLE2ESimpleED25519(t *testing.T) { t.Run("OpenSSLServer", func(t *testing.T) { if !minimumOpenSSLVersion(t) { t.Skip("Cannot use OpenSSL < 3.0 as a DTLS server with ED25519 keys") } testPionE2ESimpleED25519(t, serverOpenSSL, clientPion) }) t.Run("OpenSSLClient", func(t *testing.T) { testPionE2ESimpleED25519(t, serverPion, clientOpenSSL) }) } func TestPionOpenSSLE2ESimpleED25519ClientCert(t *testing.T) { t.Run("OpenSSLServer", func(t *testing.T) { if !minimumOpenSSLVersion(t) { t.Skip("Cannot use OpenSSL < 3.0 as a DTLS server with ED25519 keys") } testPionE2ESimpleED25519ClientCert(t, serverOpenSSL, clientPion) }) t.Run("OpenSSLClient", func(t *testing.T) { testPionE2ESimpleED25519ClientCert(t, serverPion, clientOpenSSL) }) } func TestPionOpenSSLE2ESimpleECDSAClientCert(t *testing.T) { t.Run("OpenSSLServer", func(t *testing.T) { testPionE2ESimpleECDSAClientCert(t, serverOpenSSL, clientPion) }) t.Run("OpenSSLClient", func(t *testing.T) { testPionE2ESimpleECDSAClientCert(t, serverPion, clientOpenSSL) }) } golang-github-pion-dtls-v3-3.0.7/e2e/e2e_test.go000066400000000000000000000450101507057460300212300ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT //go:build !js // +build !js package e2e import ( "context" "crypto/ed25519" "crypto/rand" "crypto/rsa" "crypto/tls" 
"crypto/x509" "errors" "fmt" "io" "net" "sync" "sync/atomic" "testing" "time" "github.com/pion/dtls/v3" "github.com/pion/dtls/v3/pkg/crypto/selfsign" "github.com/pion/dtls/v3/pkg/protocol/extension" "github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/pion/transport/v3/test" "github.com/stretchr/testify/assert" ) const ( testMessage = "Hello World" testTimeLimit = 5 * time.Second messageRetry = 200 * time.Millisecond ) var ( errServerTimeout = errors.New("waiting on serverReady err: timeout") errHookCiphersFailed = errors.New("hook failed to modify cipherlist") errHookAPLNFailed = errors.New("hook failed to modify APLN extension") ) func randomPort(tb testing.TB) int { tb.Helper() conn, err := net.ListenPacket("udp4", "127.0.0.1:0") assert.NoError(tb, err, "failed to pick port") defer func() { _ = conn.Close() }() switch addr := conn.LocalAddr().(type) { case *net.UDPAddr: return addr.Port default: assert.Fail(tb, "failed to acquire port", "unknown addr type %T", addr) return 0 } } func simpleReadWrite(errChan chan error, outChan chan string, conn io.ReadWriter, messageRecvCount *uint64) { go func() { buffer := make([]byte, 8192) n, err := conn.Read(buffer) if err != nil { errChan <- err return } outChan <- string(buffer[:n]) atomic.AddUint64(messageRecvCount, 1) }() for { if atomic.LoadUint64(messageRecvCount) == 2 { break } else if _, err := conn.Write([]byte(testMessage)); err != nil { errChan <- err break } time.Sleep(messageRetry) } } type comm struct { ctx context.Context //nolint:containedctx clientConfig, serverConfig *dtls.Config serverPort int messageRecvCount *uint64 // Counter to make sure both sides got a message clientMutex *sync.Mutex clientConn net.Conn clientDone chan error serverMutex *sync.Mutex serverConn net.Conn serverListener net.Listener serverReady chan struct{} serverDone chan error errChan chan error clientChan chan string serverChan chan string client func(*comm) server func(*comm) } func newComm( ctx context.Context, 
clientConfig, serverConfig *dtls.Config, serverPort int, server, client func(*comm), ) *comm { messageRecvCount := uint64(0) com := &comm{ ctx: ctx, clientConfig: clientConfig, serverConfig: serverConfig, serverPort: serverPort, messageRecvCount: &messageRecvCount, clientMutex: &sync.Mutex{}, serverMutex: &sync.Mutex{}, serverReady: make(chan struct{}), serverDone: make(chan error), clientDone: make(chan error), errChan: make(chan error), clientChan: make(chan string), serverChan: make(chan string), server: server, client: client, } return com } func (c *comm) assert(t *testing.T) { //nolint:cyclop t.Helper() // DTLS Client go c.client(c) // DTLS Server go c.server(c) defer func() { if c.clientConn != nil { assert.NoError(t, c.clientConn.Close()) } if c.serverConn != nil { assert.NoError(t, c.serverConn.Close()) } if c.serverListener != nil { assert.NoError(t, c.serverListener.Close()) } }() func() { seenClient, seenServer := false, false for { select { case err := <-c.errChan: assert.NoError(t, err) case <-time.After(testTimeLimit): assert.Failf(t, "Test timeout", "seenClient %t seenServer %t", seenClient, seenServer) case clientMsg := <-c.clientChan: assert.Equal(t, testMessage, clientMsg) seenClient = true if seenClient && seenServer { return } case serverMsg := <-c.serverChan: assert.Equal(t, testMessage, serverMsg) seenServer = true if seenClient && seenServer { return } } } }() } func (c *comm) cleanup(t *testing.T) { t.Helper() clientDone, serverDone := false, false for { select { case err := <-c.clientDone: assert.NoError(t, err) clientDone = true if clientDone && serverDone { return } case err := <-c.serverDone: assert.NoError(t, err) serverDone = true if clientDone && serverDone { return } case <-time.After(testTimeLimit): assert.Fail(t, "Test timeout waiting for server shutdown") } } } func clientPion(c *comm) { //nolint:varnamelen select { case <-c.serverReady: // OK case <-time.After(time.Second): c.errChan <- errServerTimeout } c.clientMutex.Lock() 
defer c.clientMutex.Unlock() conn, err := dtls.Dial("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: c.serverPort}, c.clientConfig, ) if err != nil { c.errChan <- err return } if err := conn.HandshakeContext(c.ctx); err != nil { c.errChan <- err return } c.clientConn = conn simpleReadWrite(c.errChan, c.clientChan, c.clientConn, c.messageRecvCount) c.clientDone <- nil close(c.clientDone) } func serverPion(c *comm) { //nolint:varnamelen c.serverMutex.Lock() defer c.serverMutex.Unlock() var err error c.serverListener, err = dtls.Listen("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: c.serverPort}, c.serverConfig, ) if err != nil { c.errChan <- err return } c.serverReady <- struct{}{} c.serverConn, err = c.serverListener.Accept() if err != nil { c.errChan <- err return } dtlsConn, ok := c.serverConn.(*dtls.Conn) if ok { if err := dtlsConn.HandshakeContext(c.ctx); err != nil { c.errChan <- err return } } simpleReadWrite(c.errChan, c.serverChan, c.serverConn, c.messageRecvCount) c.serverDone <- nil close(c.serverDone) } type dtlsConfOpts func(*dtls.Config) func withConnectionIDGenerator(g func() []byte) dtlsConfOpts { return func(c *dtls.Config) { c.ConnectionIDGenerator = g } } // Simple DTLS Client/Server can communicate // - Assert that you can send messages both ways // - Assert that Close() on both ends work // - Assert that no Goroutines are leaked func testPionE2ESimple(t *testing.T, server, client func(*comm), opts ...dtlsConfOpts) { t.Helper() lim := test.TimeOut(time.Second * 30) defer lim.Stop() report := test.CheckRoutines(t) defer report() for _, cipherSuite := range []dtls.CipherSuiteID{ dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, dtls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, dtls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, } { cipherSuite := cipherSuite t.Run(cipherSuite.String(), func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() cert, err := 
selfsign.GenerateSelfSignedWithDNS("localhost") assert.NoError(t, err) cfg := &dtls.Config{ Certificates: []tls.Certificate{cert}, CipherSuites: []dtls.CipherSuiteID{cipherSuite}, InsecureSkipVerify: true, } for _, o := range opts { o(cfg) } serverPort := randomPort(t) comm := newComm(ctx, cfg, cfg, serverPort, server, client) defer comm.cleanup(t) comm.assert(t) }) } } func testPionE2ESimplePSK(t *testing.T, server, client func(*comm), opts ...dtlsConfOpts) { t.Helper() lim := test.TimeOut(time.Second * 30) defer lim.Stop() report := test.CheckRoutines(t) defer report() for _, cipherSuite := range []dtls.CipherSuiteID{ dtls.TLS_PSK_WITH_AES_128_CCM, dtls.TLS_PSK_WITH_AES_128_CCM_8, dtls.TLS_PSK_WITH_AES_256_CCM_8, dtls.TLS_PSK_WITH_AES_128_GCM_SHA256, dtls.TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256, } { cipherSuite := cipherSuite t.Run(cipherSuite.String(), func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() cfg := &dtls.Config{ PSK: func([]byte) ([]byte, error) { return []byte{0xAB, 0xC1, 0x23}, nil }, PSKIdentityHint: []byte{0x01, 0x02, 0x03, 0x04, 0x05}, CipherSuites: []dtls.CipherSuiteID{cipherSuite}, } for _, o := range opts { o(cfg) } serverPort := randomPort(t) comm := newComm(ctx, cfg, cfg, serverPort, server, client) defer comm.cleanup(t) comm.assert(t) }) } } func testPionE2EMTUs(t *testing.T, server, client func(*comm), opts ...dtlsConfOpts) { t.Helper() lim := test.TimeOut(time.Second * 30) defer lim.Stop() report := test.CheckRoutines(t) defer report() for _, mtu := range []int{ 10000, 1000, 100, } { mtu := mtu t.Run(fmt.Sprintf("MTU%d", mtu), func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() cert, err := selfsign.GenerateSelfSignedWithDNS("localhost") assert.NoError(t, err) cfg := &dtls.Config{ Certificates: []tls.Certificate{cert}, CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, InsecureSkipVerify: true, 
MTU: mtu, } for _, o := range opts { o(cfg) } serverPort := randomPort(t) comm := newComm(ctx, cfg, cfg, serverPort, server, client) defer comm.cleanup(t) comm.assert(t) }) } } func testPionE2ESimpleED25519(t *testing.T, server, client func(*comm), opts ...dtlsConfOpts) { t.Helper() lim := test.TimeOut(time.Second * 30) defer lim.Stop() report := test.CheckRoutines(t) defer report() for _, cipherSuite := range []dtls.CipherSuiteID{ dtls.TLS_ECDHE_ECDSA_WITH_AES_128_CCM, dtls.TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8, dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, dtls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, dtls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, } { cipherSuite := cipherSuite t.Run(cipherSuite.String(), func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() _, key, err := ed25519.GenerateKey(rand.Reader) assert.NoError(t, err) cert, err := selfsign.SelfSign(key) assert.NoError(t, err) cfg := &dtls.Config{ Certificates: []tls.Certificate{cert}, CipherSuites: []dtls.CipherSuiteID{cipherSuite}, InsecureSkipVerify: true, } for _, o := range opts { o(cfg) } serverPort := randomPort(t) comm := newComm(ctx, cfg, cfg, serverPort, server, client) defer comm.cleanup(t) comm.assert(t) }) } } func testPionE2ESimpleED25519ClientCert(t *testing.T, server, client func(*comm), opts ...dtlsConfOpts) { t.Helper() lim := test.TimeOut(time.Second * 30) defer lim.Stop() report := test.CheckRoutines(t) defer report() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() _, skey, err := ed25519.GenerateKey(rand.Reader) assert.NoError(t, err) scert, err := selfsign.SelfSign(skey) assert.NoError(t, err) _, ckey, err := ed25519.GenerateKey(rand.Reader) assert.NoError(t, err) ccert, err := selfsign.SelfSign(ckey) assert.NoError(t, err) scfg := &dtls.Config{ Certificates: []tls.Certificate{scert}, CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, ClientAuth: 
dtls.RequireAnyClientCert, } ccfg := &dtls.Config{ Certificates: []tls.Certificate{ccert}, CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, InsecureSkipVerify: true, } for _, o := range opts { o(scfg) o(ccfg) } serverPort := randomPort(t) comm := newComm(ctx, ccfg, scfg, serverPort, server, client) defer comm.cleanup(t) comm.assert(t) } func testPionE2ESimpleECDSAClientCert(t *testing.T, server, client func(*comm), opts ...dtlsConfOpts) { t.Helper() lim := test.TimeOut(time.Second * 30) defer lim.Stop() report := test.CheckRoutines(t) defer report() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() scert, err := selfsign.GenerateSelfSigned() assert.NoError(t, err) ccert, err := selfsign.GenerateSelfSigned() assert.NoError(t, err) clientCAs := x509.NewCertPool() caCert, err := x509.ParseCertificate(ccert.Certificate[0]) assert.NoError(t, err) clientCAs.AddCert(caCert) scfg := &dtls.Config{ ClientCAs: clientCAs, Certificates: []tls.Certificate{scert}, CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, ClientAuth: dtls.RequireAnyClientCert, } ccfg := &dtls.Config{ Certificates: []tls.Certificate{ccert}, CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, InsecureSkipVerify: true, } for _, o := range opts { o(scfg) o(ccfg) } serverPort := randomPort(t) comm := newComm(ctx, ccfg, scfg, serverPort, server, client) defer comm.cleanup(t) comm.assert(t) } func testPionE2ESimpleRSAClientCert(t *testing.T, server, client func(*comm), opts ...dtlsConfOpts) { t.Helper() lim := test.TimeOut(time.Second * 30) defer lim.Stop() report := test.CheckRoutines(t) defer report() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() spriv, err := rsa.GenerateKey(rand.Reader, 2048) assert.NoError(t, err) scert, err := selfsign.SelfSign(spriv) assert.NoError(t, err) cpriv, err := rsa.GenerateKey(rand.Reader, 2048) 
assert.NoError(t, err) ccert, err := selfsign.SelfSign(cpriv) assert.NoError(t, err) scfg := &dtls.Config{ Certificates: []tls.Certificate{scert}, CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}, ClientAuth: dtls.RequireAnyClientCert, } ccfg := &dtls.Config{ Certificates: []tls.Certificate{ccert}, CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}, InsecureSkipVerify: true, } for _, o := range opts { o(scfg) o(ccfg) } serverPort := randomPort(t) comm := newComm(ctx, ccfg, scfg, serverPort, server, client) defer comm.cleanup(t) comm.assert(t) } func testPionE2ESimpleClientHelloHook(t *testing.T, server, client func(*comm), opts ...dtlsConfOpts) { t.Helper() lim := test.TimeOut(time.Second * 30) defer lim.Stop() report := test.CheckRoutines(t) defer report() t.Run("ClientHello hook", func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() cert, err := selfsign.GenerateSelfSignedWithDNS("localhost") assert.NoError(t, err) modifiedCipher := dtls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA supportedList := []dtls.CipherSuiteID{ dtls.TLS_ECDHE_ECDSA_WITH_AES_128_CCM, modifiedCipher, } ccfg := &dtls.Config{ Certificates: []tls.Certificate{cert}, VerifyConnection: func(s *dtls.State) error { if s.CipherSuiteID != modifiedCipher { return errHookCiphersFailed } return nil }, CipherSuites: supportedList, ClientHelloMessageHook: func(ch handshake.MessageClientHello) handshake.Message { ch.CipherSuiteIDs = []uint16{uint16(modifiedCipher)} return &ch }, InsecureSkipVerify: true, } scfg := &dtls.Config{ Certificates: []tls.Certificate{cert}, CipherSuites: supportedList, InsecureSkipVerify: true, } for _, o := range opts { o(ccfg) o(scfg) } serverPort := randomPort(t) comm := newComm(ctx, ccfg, scfg, serverPort, server, client) defer comm.cleanup(t) comm.assert(t) }) } func testPionE2ESimpleServerHelloHook(t *testing.T, server, client func(*comm), opts ...dtlsConfOpts) { 
t.Helper() lim := test.TimeOut(time.Second * 30) defer lim.Stop() report := test.CheckRoutines(t) defer report() t.Run("ServerHello hook", func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() cert, err := selfsign.GenerateSelfSignedWithDNS("localhost") assert.NoError(t, err) supportedList := []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_CCM} apln := "APLN" ccfg := &dtls.Config{ Certificates: []tls.Certificate{cert}, VerifyConnection: func(s *dtls.State) error { if s.NegotiatedProtocol != apln { return errHookAPLNFailed } return nil }, CipherSuites: supportedList, InsecureSkipVerify: true, } scfg := &dtls.Config{ Certificates: []tls.Certificate{cert}, CipherSuites: supportedList, ServerHelloMessageHook: func(sh handshake.MessageServerHello) handshake.Message { sh.Extensions = append(sh.Extensions, &extension.ALPN{ ProtocolNameList: []string{apln}, }) return &sh }, InsecureSkipVerify: true, } for _, o := range opts { o(ccfg) o(scfg) } serverPort := randomPort(t) comm := newComm(ctx, ccfg, scfg, serverPort, server, client) defer comm.cleanup(t) comm.assert(t) }) } func TestPionE2ESimple(t *testing.T) { testPionE2ESimple(t, serverPion, clientPion) } func TestPionE2ESimplePSK(t *testing.T) { testPionE2ESimplePSK(t, serverPion, clientPion) } func TestPionE2EMTUs(t *testing.T) { testPionE2EMTUs(t, serverPion, clientPion) } func TestPionE2ESimpleED25519(t *testing.T) { testPionE2ESimpleED25519(t, serverPion, clientPion) } func TestPionE2ESimpleED25519ClientCert(t *testing.T) { testPionE2ESimpleED25519ClientCert(t, serverPion, clientPion) } func TestPionE2ESimpleECDSAClientCert(t *testing.T) { testPionE2ESimpleECDSAClientCert(t, serverPion, clientPion) } func TestPionE2ESimpleRSAClientCert(t *testing.T) { testPionE2ESimpleRSAClientCert(t, serverPion, clientPion) } func TestPionE2ESimpleCID(t *testing.T) { testPionE2ESimple(t, serverPion, clientPion, withConnectionIDGenerator(dtls.RandomCIDGenerator(8))) } func 
TestPionE2ESimplePSKCID(t *testing.T) { testPionE2ESimplePSK(t, serverPion, clientPion, withConnectionIDGenerator(dtls.RandomCIDGenerator(8))) } func TestPionE2EMTUsCID(t *testing.T) { testPionE2EMTUs(t, serverPion, clientPion, withConnectionIDGenerator(dtls.RandomCIDGenerator(8))) } func TestPionE2ESimpleED25519CID(t *testing.T) { testPionE2ESimpleED25519(t, serverPion, clientPion, withConnectionIDGenerator(dtls.RandomCIDGenerator(8))) } func TestPionE2ESimpleED25519ClientCertCID(t *testing.T) { testPionE2ESimpleED25519ClientCert(t, serverPion, clientPion, withConnectionIDGenerator(dtls.RandomCIDGenerator(8))) } func TestPionE2ESimpleECDSAClientCertCID(t *testing.T) { testPionE2ESimpleECDSAClientCert(t, serverPion, clientPion, withConnectionIDGenerator(dtls.RandomCIDGenerator(8))) } func TestPionE2ESimpleRSAClientCertCID(t *testing.T) { testPionE2ESimpleRSAClientCert(t, serverPion, clientPion, withConnectionIDGenerator(dtls.RandomCIDGenerator(8))) } func TestPionE2ESimpleClientHelloHook(t *testing.T) { testPionE2ESimpleClientHelloHook(t, serverPion, clientPion) } func TestPionE2ESimpleServerHelloHook(t *testing.T) { testPionE2ESimpleServerHelloHook(t, serverPion, clientPion) } golang-github-pion-dtls-v3-3.0.7/errors.go000066400000000000000000000205471507057460300203670ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "context" "errors" "fmt" "io" "net" "os" "github.com/pion/dtls/v3/pkg/protocol" "github.com/pion/dtls/v3/pkg/protocol/alert" ) // Typed errors. 
var ( ErrConnClosed = &FatalError{Err: errors.New("conn is closed")} //nolint:goerr113 errDeadlineExceeded = &TimeoutError{Err: fmt.Errorf("read/write timeout: %w", context.DeadlineExceeded)} errInvalidContentType = &TemporaryError{Err: errors.New("invalid content type")} //nolint:goerr113 //nolint:goerr113 errBufferTooSmall = &TemporaryError{Err: errors.New("buffer is too small")} //nolint:goerr113 errContextUnsupported = &TemporaryError{Err: errors.New("context is not supported for ExportKeyingMaterial")} //nolint:goerr113 errHandshakeInProgress = &TemporaryError{Err: errors.New("handshake is in progress")} //nolint:goerr113 errReservedExportKeyingMaterial = &TemporaryError{ Err: errors.New("ExportKeyingMaterial can not be used with a reserved label"), } //nolint:goerr113 errApplicationDataEpochZero = &TemporaryError{Err: errors.New("ApplicationData with epoch of 0")} //nolint:goerr113 errUnhandledContextType = &TemporaryError{Err: errors.New("unhandled contentType")} //nolint:goerr113 errCertificateVerifyNoCertificate = &FatalError{ Err: errors.New("client sent certificate verify but we have no certificate to verify"), } //nolint:goerr113 errCipherSuiteNoIntersection = &FatalError{Err: errors.New("client+server do not support any shared cipher suites")} //nolint:goerr113 errClientCertificateNotVerified = &FatalError{Err: errors.New("client sent certificate but did not verify it")} //nolint:goerr113 errClientCertificateRequired = &FatalError{Err: errors.New("server required client verification, but got none")} //nolint:goerr113 errClientNoMatchingSRTPProfile = &FatalError{Err: errors.New("server responded with SRTP Profile we do not support")} //nolint:goerr113 errClientRequiredButNoServerEMS = &FatalError{ Err: errors.New("client required Extended Master Secret extension, but server does not support it"), } //nolint:goerr113 errCookieMismatch = &FatalError{Err: errors.New("client+server cookie does not match")} //nolint:goerr113 errIdentityNoPSK = 
&FatalError{Err: errors.New("PSK Identity Hint provided but PSK is nil")} //nolint:goerr113 errInvalidCertificate = &FatalError{Err: errors.New("no certificate provided")} //nolint:goerr113 errInvalidCipherSuite = &FatalError{Err: errors.New("invalid or unknown cipher suite")} //nolint:goerr113 errInvalidECDSASignature = &FatalError{Err: errors.New("ECDSA signature contained zero or negative values")} //nolint:goerr113 errInvalidPrivateKey = &FatalError{Err: errors.New("invalid private key type")} //nolint:goerr113 errInvalidSignatureAlgorithm = &FatalError{Err: errors.New("invalid signature algorithm")} //nolint:goerr113 errKeySignatureMismatch = &FatalError{Err: errors.New("expected and actual key signature do not match")} //nolint:goerr113 errNilNextConn = &FatalError{Err: errors.New("Conn can not be created with a nil nextConn")} //nolint:goerr113 errNoAvailableCipherSuites = &FatalError{ Err: errors.New("connection can not be created, no CipherSuites satisfy this Config"), } //nolint:goerr113 errNoAvailablePSKCipherSuite = &FatalError{ Err: errors.New("connection can not be created, pre-shared key present but no compatible CipherSuite"), } //nolint:goerr113 errNoAvailableCertificateCipherSuite = &FatalError{ Err: errors.New("connection can not be created, certificate present but no compatible CipherSuite"), } //nolint:goerr113 errNoAvailableSignatureSchemes = &FatalError{ Err: errors.New("connection can not be created, no SignatureScheme satisfy this Config"), } //nolint:goerr113 errNoCertificates = &FatalError{Err: errors.New("no certificates configured")} //nolint:goerr113 errNoConfigProvided = &FatalError{Err: errors.New("no config provided")} //nolint:goerr113 errNoSupportedEllipticCurves = &FatalError{ Err: errors.New("client requested zero or more elliptic curves that are not supported by the server"), } //nolint:goerr113 errUnsupportedProtocolVersion = &FatalError{Err: errors.New("unsupported protocol version")} //nolint:goerr113 
errPSKAndIdentityMustBeSetForClient = &FatalError{ Err: errors.New("PSK and PSK Identity Hint must both be set for client"), } //nolint:goerr113 errRequestedButNoSRTPExtension = &FatalError{ Err: errors.New("SRTP support was requested but server did not respond with use_srtp extension"), } //nolint:goerr113 errServerNoMatchingSRTPProfile = &FatalError{Err: errors.New("client requested SRTP but we have no matching profiles")} //nolint:goerr113 errServerRequiredButNoClientEMS = &FatalError{ Err: errors.New("server requires the Extended Master Secret extension, but the client does not support it"), } //nolint:goerr113 errVerifyDataMismatch = &FatalError{Err: errors.New("expected and actual verify data does not match")} //nolint:goerr113 errNotAcceptableCertificateChain = &FatalError{Err: errors.New("certificate chain is not signed by an acceptable CA")} //nolint:goerr113 errInvalidFlight = &InternalError{Err: errors.New("invalid flight number")} //nolint:goerr113 errKeySignatureGenerateUnimplemented = &InternalError{ Err: errors.New("unable to generate key signature, unimplemented"), } //nolint:goerr113 errKeySignatureVerifyUnimplemented = &InternalError{Err: errors.New("unable to verify key signature, unimplemented")} //nolint:goerr113 errLengthMismatch = &InternalError{Err: errors.New("data length and declared length do not match")} //nolint:goerr113 errSequenceNumberOverflow = &InternalError{Err: errors.New("sequence number overflow")} //nolint:goerr113 errInvalidFSMTransition = &InternalError{Err: errors.New("invalid state machine transition")} //nolint:goerr113 errFailedToAccessPoolReadBuffer = &InternalError{Err: errors.New("failed to access pool read buffer")} //nolint:goerr113 errFragmentBufferOverflow = &InternalError{Err: errors.New("fragment buffer overflow")} ) // FatalError indicates that the DTLS connection is no longer available. // It is mainly caused by wrong configuration of server or client. 
type FatalError = protocol.FatalError // InternalError indicates and internal error caused by the implementation, // and the DTLS connection is no longer available. // It is mainly caused by bugs or tried to use unimplemented features. type InternalError = protocol.InternalError // TemporaryError indicates that the DTLS connection is still available, but the request was failed temporary. type TemporaryError = protocol.TemporaryError // TimeoutError indicates that the request was timed out. type TimeoutError = protocol.TimeoutError // HandshakeError indicates that the handshake failed. type HandshakeError = protocol.HandshakeError // errInvalidCipherSuite indicates an attempt at using an unsupported cipher suite. type invalidCipherSuiteError struct { id CipherSuiteID } func (e *invalidCipherSuiteError) Error() string { return fmt.Sprintf("CipherSuite with id(%d) is not valid", e.id) } func (e *invalidCipherSuiteError) Is(err error) bool { var other *invalidCipherSuiteError if errors.As(err, &other) { return e.id == other.id } return false } // errAlert wraps DTLS alert notification as an error. type alertError struct { *alert.Alert } func (e *alertError) Error() string { return fmt.Sprintf("alert: %s", e.Alert.String()) } func (e *alertError) IsFatalOrCloseNotify() bool { return e.Level == alert.Fatal || e.Description == alert.CloseNotify } func (e *alertError) Is(err error) bool { var other *alertError if errors.As(err, &other) { return e.Level == other.Level && e.Description == other.Description } return false } // netError translates an error from underlying Conn to corresponding net.Error. func netError(err error) error { switch { case errors.Is(err, io.EOF), errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded): // Return io.EOF and context errors as is. 
return err } var ( ne net.Error opError *net.OpError se *os.SyscallError ) if errors.As(err, &opError) { //nolint:nestif if errors.As(opError, &se) { if se.Timeout() { return &TimeoutError{Err: err} } if isOpErrorTemporary(se) { return &TemporaryError{Err: err} } } } if errors.As(err, &ne) { return err } return &FatalError{Err: err} } golang-github-pion-dtls-v3-3.0.7/errors_errno.go000066400000000000000000000012321507057460300215620ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT //go:build aix || darwin || dragonfly || freebsd || linux || nacl || nacljs || netbsd || openbsd || solaris || windows // +build aix darwin dragonfly freebsd linux nacl nacljs netbsd openbsd solaris windows // For systems having syscall.Errno. // Update build targets by following command: // $ grep -R ECONN $(go env GOROOT)/src/syscall/zerrors_*.go \ // | tr "." "_" | cut -d"_" -f"2" | sort | uniq package dtls import ( "errors" "os" "syscall" ) func isOpErrorTemporary(err *os.SyscallError) bool { return errors.Is(err.Err, syscall.ECONNREFUSED) } golang-github-pion-dtls-v3-3.0.7/errors_errno_test.go000066400000000000000000000023551507057460300226300ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT //go:build aix || darwin || dragonfly || freebsd || linux || nacl || nacljs || netbsd || openbsd || solaris || windows // +build aix darwin dragonfly freebsd linux nacl nacljs netbsd openbsd solaris windows // For systems having syscall.Errno. // The build target must be same as errors_errno.go. package dtls import ( "net" "testing" "github.com/stretchr/testify/assert" ) func TestErrorsTemporary(t *testing.T) { // Allocate a UDP port no one is listening on. 
addrListen, err := net.ResolveUDPAddr("udp", "localhost:0") assert.NoError(t, err) listener, err := net.ListenUDP("udp", addrListen) assert.NoError(t, err) raddr, ok := listener.LocalAddr().(*net.UDPAddr) assert.True(t, ok) assert.NoError(t, listener.Close()) // Server is not listening. conn, errDial := net.DialUDP("udp", nil, raddr) assert.NoError(t, errDial) _, _ = conn.Write([]byte{0x00}) // trigger _, err = conn.Read(make([]byte, 10)) _ = conn.Close() if err == nil { t.Skip("ECONNREFUSED is not set by system") } var ne net.Error assert.ErrorAs(t, netError(err), &ne) assert.False(t, ne.Timeout()) assert.True(t, ne.Temporary()) //nolint:staticcheck } golang-github-pion-dtls-v3-3.0.7/errors_noerrno.go000066400000000000000000000010141507057460300221150ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT //go:build !aix && !darwin && !dragonfly && !freebsd && !linux && !nacl && !nacljs && !netbsd && !openbsd && !solaris && !windows // +build !aix,!darwin,!dragonfly,!freebsd,!linux,!nacl,!nacljs,!netbsd,!openbsd,!solaris,!windows // For systems without syscall.Errno. 
// Build targets must be inverse of errors_errno.go package dtls import ( "os" ) func isOpErrorTemporary(err *os.SyscallError) bool { return false } golang-github-pion-dtls-v3-3.0.7/errors_test.go000066400000000000000000000036731507057460300214270ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "errors" "fmt" "net" "testing" "github.com/stretchr/testify/assert" ) var errExample = errors.New("an example error") func TestErrorUnwrap(t *testing.T) { cases := []struct { err error errUnwrapped []error }{ { &FatalError{Err: errExample}, []error{errExample}, }, { &TemporaryError{Err: errExample}, []error{errExample}, }, { &InternalError{Err: errExample}, []error{errExample}, }, { &TimeoutError{Err: errExample}, []error{errExample}, }, { &HandshakeError{Err: errExample}, []error{errExample}, }, } for _, c := range cases { c := c t.Run(fmt.Sprintf("%T", c.err), func(t *testing.T) { err := c.err for _, unwrapped := range c.errUnwrapped { assert.ErrorIs(t, errors.Unwrap(err), unwrapped) } }) } } func TestErrorNetError(t *testing.T) { cases := []struct { err error str string timeout, temporary bool }{ {&FatalError{Err: errExample}, "dtls fatal: an example error", false, false}, {&TemporaryError{Err: errExample}, "dtls temporary: an example error", false, true}, {&InternalError{Err: errExample}, "dtls internal: an example error", false, false}, {&TimeoutError{Err: errExample}, "dtls timeout: an example error", true, true}, {&HandshakeError{Err: errExample}, "handshake error: an example error", false, false}, {&HandshakeError{Err: &TimeoutError{Err: errExample}}, "handshake error: dtls timeout: an example error", true, true}, } for _, testCase := range cases { testCase := testCase t.Run(fmt.Sprintf("%T", testCase.err), func(t *testing.T) { var ne net.Error assert.ErrorAs(t, testCase.err, &ne) assert.Equal(t, testCase.timeout, ne.Timeout()) assert.Equal(t, testCase.temporary, ne.Temporary()) 
//nolint:staticcheck assert.Equal(t, testCase.str, ne.Error()) }) } } golang-github-pion-dtls-v3-3.0.7/examples/000077500000000000000000000000001507057460300203325ustar00rootroot00000000000000golang-github-pion-dtls-v3-3.0.7/examples/certificates/000077500000000000000000000000001507057460300227775ustar00rootroot00000000000000golang-github-pion-dtls-v3-3.0.7/examples/certificates/README.md000066400000000000000000000024321507057460300242570ustar00rootroot00000000000000# Certificates The certificates in for the examples are generated using the commands shown below. Note that this was run on OpenSSL 1.1.1d, of which the arguments can be found in the [OpenSSL Manpages](https://www.openssl.org/docs/man1.1.1/man1), and is not guaranteed to work on different OpenSSL versions. ```shell # Extensions required for certificate validation. $ EXTFILE='extfile.conf' $ echo 'subjectAltName = IP:127.0.0.1\nbasicConstraints = critical,CA:true' > "${EXTFILE}" # Server. $ SERVER_NAME='server' $ openssl ecparam -name prime256v1 -genkey -noout -out "${SERVER_NAME}.pem" $ openssl req -key "${SERVER_NAME}.pem" -new -sha256 -subj '/C=NL' -out "${SERVER_NAME}.csr" $ openssl x509 -req -in "${SERVER_NAME}.csr" -extfile "${EXTFILE}" -days 365 -signkey "${SERVER_NAME}.pem" -sha256 -out "${SERVER_NAME}.pub.pem" # Client. $ CLIENT_NAME='client' $ openssl ecparam -name prime256v1 -genkey -noout -out "${CLIENT_NAME}.pem" $ openssl req -key "${CLIENT_NAME}.pem" -new -sha256 -subj '/C=NL' -out "${CLIENT_NAME}.csr" $ openssl x509 -req -in "${CLIENT_NAME}.csr" -extfile "${EXTFILE}" -days 365 -CA "${SERVER_NAME}.pub.pem" -CAkey "${SERVER_NAME}.pem" -set_serial '0xabcd' -sha256 -out "${CLIENT_NAME}.pub.pem" # Cleanup. 
$ rm "${EXTFILE}" "${SERVER_NAME}.csr" "${CLIENT_NAME}.csr" ``` golang-github-pion-dtls-v3-3.0.7/examples/certificates/client.pem000066400000000000000000000005061507057460300247610ustar00rootroot00000000000000SPDX-FileCopyrightText: 2023 The Pion community SPDX-License-Identifier: CC0-1.0 -----BEGIN EC PRIVATE KEY----- MHcCAQEEIGOO78dEAcepxdUIeDzC28jMcFrJr2q7x+UdhgtJ/RS3oAoGCCqGSM49 AwEHoUQDQgAEGLSNxlkJ9mETKI2Hogq3Cyh06pJKA1YMgcKqYKS6yQQlvvk5rU88 +RojFPgXJukymhfIJmw4eGxxEMSjuEZY7w== -----END EC PRIVATE KEY----- golang-github-pion-dtls-v3-3.0.7/examples/certificates/client.pub.pem000066400000000000000000000010701507057460300255430ustar00rootroot00000000000000SPDX-FileCopyrightText: 2023 The Pion community SPDX-License-Identifier: CC0-1.0 -----BEGIN CERTIFICATE----- MIIBLTCB1aADAgECAgMAq80wCgYIKoZIzj0EAwIwDTELMAkGA1UEBhMCTkwwHhcN MjAwMzIwMDk0NjQ0WhcNMjEwMzIwMDk0NjQ0WjANMQswCQYDVQQGEwJOTDBZMBMG ByqGSM49AgEGCCqGSM49AwEHA0IABBi0jcZZCfZhEyiNh6IKtwsodOqSSgNWDIHC qmCkuskEJb75Oa1PPPkaIxT4FybpMpoXyCZsOHhscRDEo7hGWO+jJDAiMA8GA1Ud EQQIMAaHBH8AAAEwDwYDVR0TAQH/BAUwAwEB/zAKBggqhkjOPQQDAgNHADBEAiBx sIkcADN9E60veZOFOeANaRWAiQaLWZfUxqkOmfHztQIgI2CfHMjDQwJZFh35HvFs NOPJj8wxFhqR5pqMF23cgOY= -----END CERTIFICATE----- golang-github-pion-dtls-v3-3.0.7/examples/certificates/server.pem000066400000000000000000000005061507057460300250110ustar00rootroot00000000000000SPDX-FileCopyrightText: 2023 The Pion community SPDX-License-Identifier: CC0-1.0 -----BEGIN EC PRIVATE KEY----- MHcCAQEEIDT8Xyx5RpPP+98ulYZKsvKIVdBUJug/L9H2M8JThv+GoAoGCCqGSM49 AwEHoUQDQgAE6Wf0qQqIb5G7g51P83Dh1Yst52kyntGYz1Bt6S7crpmQFs9ZRZMy bJ6MGIwGcVBMgoL3pfxDKdZ3mnzmoibU0w== -----END EC PRIVATE KEY----- golang-github-pion-dtls-v3-3.0.7/examples/certificates/server.pub.pem000066400000000000000000000011201507057460300255670ustar00rootroot00000000000000SPDX-FileCopyrightText: 2023 The Pion community SPDX-License-Identifier: CC0-1.0 -----BEGIN CERTIFICATE----- MIIBPzCB5qADAgECAhRtzyVTL+9D0KHfbcKYeKckpLVRmTAKBggqhkjOPQQDAjAN 
MQswCQYDVQQGEwJOTDAeFw0yMDAzMjAwOTQ2NDRaFw0yMTAzMjAwOTQ2NDRaMA0x CzAJBgNVBAYTAk5MMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE6Wf0qQqIb5G7 g51P83Dh1Yst52kyntGYz1Bt6S7crpmQFs9ZRZMybJ6MGIwGcVBMgoL3pfxDKdZ3 mnzmoibU06MkMCIwDwYDVR0RBAgwBocEfwAAATAPBgNVHRMBAf8EBTADAQH/MAoG CCqGSM49BAMCA0gAMEUCIQD000SU+klkNLGvHZcMYNVkCFsImnGKIqPMy3LELSiF 0gIgSGIFkNEIAyNxn44CXZJu3piyz1ouK2fLefDJMYfcXgM= -----END CERTIFICATE----- golang-github-pion-dtls-v3-3.0.7/examples/dial/000077500000000000000000000000001507057460300212435ustar00rootroot00000000000000golang-github-pion-dtls-v3-3.0.7/examples/dial/cid/000077500000000000000000000000001507057460300220025ustar00rootroot00000000000000golang-github-pion-dtls-v3-3.0.7/examples/dial/cid/main.go000066400000000000000000000026361507057460300232640ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT // Package main implements an example DTLS client using a pre-shared key. package main import ( "context" "fmt" "net" "time" "github.com/pion/dtls/v3" "github.com/pion/dtls/v3/examples/util" ) func main() { // Prepare the IP to connect to addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 4444} // // Everything below is the pion-DTLS API! Thanks for using it ❤️. 
// // Prepare the configuration of the DTLS connection config := &dtls.Config{ PSK: func(hint []byte) ([]byte, error) { fmt.Printf("Server's hint: %s \n", hint) return []byte{0xAB, 0xC1, 0x23}, nil }, PSKIdentityHint: []byte("Pion DTLS Client"), CipherSuites: []dtls.CipherSuiteID{dtls.TLS_PSK_WITH_AES_128_CCM_8}, ExtendedMasterSecret: dtls.RequireExtendedMasterSecret, ConnectionIDGenerator: dtls.OnlySendCIDGenerator(), } // Connect to a DTLS server ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() dtlsConn, err := dtls.Dial("udp", addr, config) util.Check(err) defer func() { util.Check(dtlsConn.Close()) }() if err := dtlsConn.HandshakeContext(ctx); err != nil { fmt.Printf("Failed to handshake with server: %v\n", err) return } fmt.Println("Connected; type 'exit' to shutdown gracefully") // Simulate a chat session util.Chat(dtlsConn) } golang-github-pion-dtls-v3-3.0.7/examples/dial/psk/000077500000000000000000000000001507057460300220405ustar00rootroot00000000000000golang-github-pion-dtls-v3-3.0.7/examples/dial/psk/main.go000066400000000000000000000025231507057460300233150ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT // Package main implements an example DTLS client using a pre-shared key. package main import ( "context" "fmt" "net" "time" "github.com/pion/dtls/v3" "github.com/pion/dtls/v3/examples/util" ) func main() { // Prepare the IP to connect to addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 4444} // // Everything below is the pion-DTLS API! Thanks for using it ❤️. 
// // Prepare the configuration of the DTLS connection config := &dtls.Config{ PSK: func(hint []byte) ([]byte, error) { fmt.Printf("Server's hint: %s \n", hint) return []byte{0xAB, 0xC1, 0x23}, nil }, PSKIdentityHint: []byte{}, CipherSuites: []dtls.CipherSuiteID{dtls.TLS_PSK_WITH_AES_128_CCM_8}, ExtendedMasterSecret: dtls.RequireExtendedMasterSecret, } // Connect to a DTLS server ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() dtlsConn, err := dtls.Dial("udp", addr, config) util.Check(err) defer func() { util.Check(dtlsConn.Close()) }() if err := dtlsConn.HandshakeContext(ctx); err != nil { fmt.Printf("Failed to handshake with server: %v\n", err) return } fmt.Println("Connected; type 'exit' to shutdown gracefully") // Simulate a chat session util.Chat(dtlsConn) } golang-github-pion-dtls-v3-3.0.7/examples/dial/selfsign/000077500000000000000000000000001507057460300230555ustar00rootroot00000000000000golang-github-pion-dtls-v3-3.0.7/examples/dial/selfsign/main.go000066400000000000000000000025761507057460300243420ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT // Package main implements a DTLS client using self-signed certificates. package main import ( "context" "crypto/tls" "fmt" "net" "time" "github.com/pion/dtls/v3" "github.com/pion/dtls/v3/examples/util" "github.com/pion/dtls/v3/pkg/crypto/selfsign" ) func main() { // Prepare the IP to connect to addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 4444} // Generate a certificate and private key to secure the connection certificate, genErr := selfsign.GenerateSelfSigned() util.Check(genErr) // // Everything below is the pion-DTLS API! Thanks for using it ❤️. 
// // Prepare the configuration of the DTLS connection config := &dtls.Config{ Certificates: []tls.Certificate{certificate}, InsecureSkipVerify: true, ExtendedMasterSecret: dtls.RequireExtendedMasterSecret, } // Connect to a DTLS server ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() dtlsConn, err := dtls.Dial("udp", addr, config) util.Check(err) defer func() { util.Check(dtlsConn.Close()) }() if err := dtlsConn.HandshakeContext(ctx); err != nil { fmt.Printf("Failed to handshake with server: %v\n", err) return } fmt.Println("Connected; type 'exit' to shutdown gracefully") // Simulate a chat session util.Chat(dtlsConn) } golang-github-pion-dtls-v3-3.0.7/examples/dial/verify/000077500000000000000000000000001507057460300225475ustar00rootroot00000000000000golang-github-pion-dtls-v3-3.0.7/examples/dial/verify/main.go000066400000000000000000000031241507057460300240220ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT // Package main implements a DTLS client using a client certificate. package main import ( "context" "crypto/tls" "crypto/x509" "fmt" "net" "time" "github.com/pion/dtls/v3" "github.com/pion/dtls/v3/examples/util" ) func main() { // Prepare the IP to connect to addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 4444} // // Everything below is the pion-DTLS API! Thanks for using it ❤️. 
// certificate, err := util.LoadKeyAndCertificate("examples/certificates/client.pem", "examples/certificates/client.pub.pem") util.Check(err) rootCertificate, err := util.LoadCertificate("examples/certificates/server.pub.pem") util.Check(err) certPool := x509.NewCertPool() cert, err := x509.ParseCertificate(rootCertificate.Certificate[0]) util.Check(err) certPool.AddCert(cert) // Prepare the configuration of the DTLS connection config := &dtls.Config{ Certificates: []tls.Certificate{certificate}, ExtendedMasterSecret: dtls.RequireExtendedMasterSecret, RootCAs: certPool, } // Connect to a DTLS server ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() dtlsConn, err := dtls.Dial("udp", addr, config) util.Check(err) defer func() { util.Check(dtlsConn.Close()) }() if err := dtlsConn.HandshakeContext(ctx); err != nil { fmt.Printf("Failed to handshake with server: %v\n", err) return } fmt.Println("Connected; type 'exit' to shutdown gracefully") // Simulate a chat session util.Chat(dtlsConn) } golang-github-pion-dtls-v3-3.0.7/examples/listen/000077500000000000000000000000001507057460300216305ustar00rootroot00000000000000golang-github-pion-dtls-v3-3.0.7/examples/listen/cid/000077500000000000000000000000001507057460300223675ustar00rootroot00000000000000golang-github-pion-dtls-v3-3.0.7/examples/listen/cid/main.go000066400000000000000000000035421507057460300236460ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT // Package main implements a DTLS server using a pre-shared key. package main import ( "context" "fmt" "net" "time" "github.com/pion/dtls/v3" "github.com/pion/dtls/v3/examples/util" ) func main() { // Prepare the IP to connect to addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 4444} // // Everything below is the pion-DTLS API! Thanks for using it ❤️. 
// // Prepare the configuration of the DTLS connection config := &dtls.Config{ PSK: func(hint []byte) ([]byte, error) { fmt.Printf("Client's hint: %s \n", hint) return []byte{0xAB, 0xC1, 0x23}, nil }, PSKIdentityHint: []byte("Pion DTLS Server"), CipherSuites: []dtls.CipherSuiteID{dtls.TLS_PSK_WITH_AES_128_CCM_8}, ExtendedMasterSecret: dtls.RequireExtendedMasterSecret, ConnectionIDGenerator: dtls.RandomCIDGenerator(8), } // Connect to a DTLS server listener, err := dtls.Listen("udp", addr, config) util.Check(err) defer func() { util.Check(listener.Close()) }() fmt.Println("Listening") // Simulate a chat session hub := util.NewHub() go func() { for { // Wait for a connection. conn, err := listener.Accept() util.Check(err) // defer conn.Close() // TODO: graceful shutdown // `conn` is of type `net.Conn` but may be casted to `dtls.Conn` // using `dtlsConn := conn.(*dtls.Conn)` in order to to expose // functions like `ConnectionState` etc. // Perform the handshake with a 30-second timeout ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) dtlsConn, ok := conn.(*dtls.Conn) if ok { util.Check(dtlsConn.HandshakeContext(ctx)) } cancel() // Register the connection with the chat hub if err == nil { hub.Register(conn) } } }() // Start chatting hub.Chat() } golang-github-pion-dtls-v3-3.0.7/examples/listen/psk/000077500000000000000000000000001507057460300224255ustar00rootroot00000000000000golang-github-pion-dtls-v3-3.0.7/examples/listen/psk/main.go000066400000000000000000000034521507057460300237040ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT // Package main implements a DTLS server using a pre-shared key. package main import ( "context" "fmt" "net" "time" "github.com/pion/dtls/v3" "github.com/pion/dtls/v3/examples/util" ) func main() { // Prepare the IP to connect to addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 4444} // // Everything below is the pion-DTLS API! 
Thanks for using it ❤️. // // Prepare the configuration of the DTLS connection config := &dtls.Config{ PSK: func(hint []byte) ([]byte, error) { fmt.Printf("Client's hint: %s \n", hint) return []byte{0xAB, 0xC1, 0x23}, nil }, PSKIdentityHint: []byte("Pion DTLS Server"), CipherSuites: []dtls.CipherSuiteID{dtls.TLS_PSK_WITH_AES_128_CCM_8}, ExtendedMasterSecret: dtls.RequireExtendedMasterSecret, } // Connect to a DTLS server listener, err := dtls.Listen("udp", addr, config) util.Check(err) defer func() { util.Check(listener.Close()) }() fmt.Println("Listening") // Simulate a chat session hub := util.NewHub() go func() { for { // Wait for a connection. conn, err := listener.Accept() util.Check(err) // defer conn.Close() // TODO: graceful shutdown // `conn` is of type `net.Conn` but may be casted to `dtls.Conn` // using `dtlsConn := conn.(*dtls.Conn)` in order to to expose // functions like `ConnectionState` etc. // Perform the handshake with a 30-second timeout ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) dtlsConn, ok := conn.(*dtls.Conn) if ok { util.Check(dtlsConn.HandshakeContext(ctx)) } cancel() // Register the connection with the chat hub if err == nil { hub.Register(conn) } } }() // Start chatting hub.Chat() } golang-github-pion-dtls-v3-3.0.7/examples/listen/selfsign/000077500000000000000000000000001507057460300234425ustar00rootroot00000000000000golang-github-pion-dtls-v3-3.0.7/examples/listen/selfsign/main.go000066400000000000000000000034671507057460300247270ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT // Package main implements an example DTLS server using self-signed certificates. 
package main import ( "context" "crypto/tls" "fmt" "net" "time" "github.com/pion/dtls/v3" "github.com/pion/dtls/v3/examples/util" "github.com/pion/dtls/v3/pkg/crypto/selfsign" ) func main() { // Prepare the IP to connect to addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 4444} // Generate a certificate and private key to secure the connection certificate, genErr := selfsign.GenerateSelfSigned() util.Check(genErr) // // Everything below is the pion-DTLS API! Thanks for using it ❤️. // // Prepare the configuration of the DTLS connection config := &dtls.Config{ Certificates: []tls.Certificate{certificate}, ExtendedMasterSecret: dtls.RequireExtendedMasterSecret, } // Connect to a DTLS server listener, err := dtls.Listen("udp", addr, config) util.Check(err) defer func() { util.Check(listener.Close()) }() fmt.Println("Listening") // Simulate a chat session hub := util.NewHub() go func() { for { // Wait for a connection. conn, err := listener.Accept() util.Check(err) // defer conn.Close() // TODO: graceful shutdown // `conn` is of type `net.Conn` but may be casted to `dtls.Conn` // using `dtlsConn := conn.(*dtls.Conn)` in order to to expose // functions like `ConnectionState` etc. 
// Perform the handshake with a 30-second timeout ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) dtlsConn, ok := conn.(*dtls.Conn) if ok { util.Check(dtlsConn.HandshakeContext(ctx)) } cancel() // Register the connection with the chat hub if err == nil { hub.Register(conn) } } }() // Start chatting hub.Chat() } golang-github-pion-dtls-v3-3.0.7/examples/listen/verify-brute-force-protection/000077500000000000000000000000001507057460300275335ustar00rootroot00000000000000golang-github-pion-dtls-v3-3.0.7/examples/listen/verify-brute-force-protection/main.go000066400000000000000000000103511507057460300310060ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT // Package main implements an example DTLS server which verifies client certificates. // It also implements a basic Brute Force Attack protection. package main import ( "context" "crypto/tls" "crypto/x509" "fmt" "net" "sync" "time" "github.com/pion/dtls/v3" "github.com/pion/dtls/v3/examples/util" ) func main() { // Prepare the IP to connect to addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 4444} // // Everything below is the pion-DTLS API! Thanks for using it ❤️. // // ************ Variables used to implement a basic Brute Force Attack protection ************* var ( attempts = make(map[string]int) // Map of attempts for each IP address. attemptsMutex sync.Mutex // Mutex for the map of attempts. attemptsCleaner = time.Now() // Time to be able to clean the map of attempts every X minutes. 
) certificate, err := util.LoadKeyAndCertificate("examples/certificates/server.pem", "examples/certificates/server.pub.pem") util.Check(err) rootCertificate, err := util.LoadCertificate("examples/certificates/server.pub.pem") util.Check(err) certPool := x509.NewCertPool() cert, err := x509.ParseCertificate(rootCertificate.Certificate[0]) util.Check(err) certPool.AddCert(cert) // Prepare the configuration of the DTLS connection config := &dtls.Config{ Certificates: []tls.Certificate{certificate}, ExtendedMasterSecret: dtls.RequireExtendedMasterSecret, ClientAuth: dtls.RequireAndVerifyClientCert, ClientCAs: certPool, // This function will be called on each connection attempt. OnConnectionAttempt: func(addr net.Addr) error { // *************** Brute Force Attack protection *************** // Check if the IP address is in the map, and if the IP address has exceeded the limit attemptsMutex.Lock() defer attemptsMutex.Unlock() // Here I implement a time cleaner for the map of attempts, every 5 minutes I will // decrement by 1 the number of attempts for each IP address. 
if time.Now().After(attemptsCleaner.Add(time.Minute * 5)) { attemptsCleaner = time.Now() for k, v := range attempts { if v > 0 { attempts[k]-- } if attempts[k] == 0 { delete(attempts, k) } } } // Check if the IP address is in the map, and the IP address has exceeded the limit (Brute Force Attack protection) attemptIP := addr.(*net.UDPAddr).IP.String() //nolint if attempts[attemptIP] > 10 { return fmt.Errorf("too many attempts from this IP address") //nolint } // Here I increment the number of attempts for this IP address (Brute Force Attack protection) attempts[attemptIP]++ // *************** END Brute Force Attack protection END *************** return nil }, } // Connect to a DTLS server listener, err := dtls.Listen("udp", addr, config) util.Check(err) defer func() { util.Check(listener.Close()) }() fmt.Println("Listening") // Simulate a chat session hub := util.NewHub() go func() { for { // Wait for a connection. conn, err := listener.Accept() util.Check(err) // defer conn.Close() // TODO: graceful shutdown // `conn` is of type `net.Conn` but may be casted to `dtls.Conn` // using `dtlsConn := conn.(*dtls.Conn)` in order to to expose // functions like `ConnectionState` etc. 
// *************** Brute Force Attack protection *************** // Here I decrease the number of attempts for this IP address attemptsMutex.Lock() attemptIP := conn.(*dtls.Conn).RemoteAddr().(*net.UDPAddr).IP.String() //nolint attempts[attemptIP]-- // If the number of attempts for this IP address is 0, I delete the IP address from the map if attempts[attemptIP] == 0 { delete(attempts, attemptIP) } attemptsMutex.Unlock() // *************** END Brute Force Attack protection END *************** // Perform the handshake with a 30-second timeout ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) dtlsConn, ok := conn.(*dtls.Conn) if ok { util.Check(dtlsConn.HandshakeContext(ctx)) } cancel() // Register the connection with the chat hub hub.Register(conn) } }() // Start chatting hub.Chat() } golang-github-pion-dtls-v3-3.0.7/examples/listen/verify/000077500000000000000000000000001507057460300231345ustar00rootroot00000000000000golang-github-pion-dtls-v3-3.0.7/examples/listen/verify/main.go000066400000000000000000000041231507057460300244070ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT // Package main implements an example DTLS server which verifies client certificates. package main import ( "context" "crypto/tls" "crypto/x509" "fmt" "net" "time" "github.com/pion/dtls/v3" "github.com/pion/dtls/v3/examples/util" ) func main() { // Prepare the IP to connect to addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 4444} // // Everything below is the pion-DTLS API! Thanks for using it ❤️. 
// certificate, err := util.LoadKeyAndCertificate("examples/certificates/server.pem", "examples/certificates/server.pub.pem") util.Check(err) rootCertificate, err := util.LoadCertificate("examples/certificates/server.pub.pem") util.Check(err) certPool := x509.NewCertPool() cert, err := x509.ParseCertificate(rootCertificate.Certificate[0]) util.Check(err) certPool.AddCert(cert) // Prepare the configuration of the DTLS connection config := &dtls.Config{ Certificates: []tls.Certificate{certificate}, ExtendedMasterSecret: dtls.RequireExtendedMasterSecret, ClientAuth: dtls.RequireAndVerifyClientCert, ClientCAs: certPool, } // Connect to a DTLS server listener, err := dtls.Listen("udp", addr, config) util.Check(err) defer func() { util.Check(listener.Close()) }() fmt.Println("Listening") // Simulate a chat session hub := util.NewHub() go func() { for { // Wait for a connection. conn, err := listener.Accept() util.Check(err) // defer conn.Close() // TODO: graceful shutdown // `conn` is of type `net.Conn` but may be casted to `dtls.Conn` // using `dtlsConn := conn.(*dtls.Conn)` in order to to expose // functions like `ConnectionState` etc. // Perform the handshake with a 30-second timeout ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) dtlsConn, ok := conn.(*dtls.Conn) if ok { util.Check(dtlsConn.HandshakeContext(ctx)) } cancel() // Register the connection with the chat hub hub.Register(conn) } }() // Start chatting hub.Chat() } golang-github-pion-dtls-v3-3.0.7/examples/util/000077500000000000000000000000001507057460300213075ustar00rootroot00000000000000golang-github-pion-dtls-v3-3.0.7/examples/util/hub.go000066400000000000000000000031551507057460300224200ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package util import ( "bufio" "fmt" "net" "os" "strings" "sync" ) // Hub is a helper to handle one to many chat. 
type Hub struct { conns map[string]net.Conn lock sync.RWMutex } // NewHub builds a new hub. func NewHub() *Hub { return &Hub{conns: make(map[string]net.Conn)} } // Register adds a new conn to the Hub. func (h *Hub) Register(conn net.Conn) { fmt.Printf("Connected to %s\n", conn.RemoteAddr()) h.lock.Lock() defer h.lock.Unlock() h.conns[conn.RemoteAddr().String()] = conn go h.readLoop(conn) } func (h *Hub) readLoop(conn net.Conn) { b := make([]byte, bufSize) for { n, err := conn.Read(b) if err != nil { h.unregister(conn) return } fmt.Printf("Got message: %s\n", string(b[:n])) } } func (h *Hub) unregister(conn net.Conn) { h.lock.Lock() defer h.lock.Unlock() delete(h.conns, conn.RemoteAddr().String()) err := conn.Close() if err != nil { fmt.Println("Failed to disconnect", conn.RemoteAddr(), err) } else { fmt.Println("Disconnected ", conn.RemoteAddr()) } } func (h *Hub) broadcast(msg []byte) { h.lock.RLock() defer h.lock.RUnlock() for _, conn := range h.conns { _, err := conn.Write(msg) if err != nil { fmt.Printf("Failed to write message to %s: %v\n", conn.RemoteAddr(), err) } } } // Chat starts the stdin readloop to dispatch messages to the hub. 
func (h *Hub) Chat() { reader := bufio.NewReader(os.Stdin) for { msg, err := reader.ReadString('\n') Check(err) if strings.TrimSpace(msg) == "exit" { return } h.broadcast([]byte(msg)) } } golang-github-pion-dtls-v3-3.0.7/examples/util/util.go000066400000000000000000000040341507057460300226140ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT // Package util provides auxiliary utilities used in examples package util import ( "bufio" "crypto/tls" "encoding/pem" "errors" "fmt" "io" "net" "os" "path/filepath" "strings" ) const bufSize = 8192 var ( errBlockIsNotCertificate = errors.New("block is not a certificate, unable to load certificates") errNoCertificateFound = errors.New("no certificate found, unable to load certificates") ) // Chat simulates a simple text chat session over the connection. func Chat(conn io.ReadWriter) { go func() { b := make([]byte, bufSize) for { n, err := conn.Read(b) Check(err) fmt.Printf("Got message: %s\n", string(b[:n])) } }() reader := bufio.NewReader(os.Stdin) for { text, err := reader.ReadString('\n') Check(err) if strings.TrimSpace(text) == "exit" { return } _, err = conn.Write([]byte(text)) Check(err) } } // Check is a helper to throw errors in the examples. func Check(err error) { var netError net.Error if errors.As(err, &netError) && netError.Temporary() { //nolint:staticcheck fmt.Printf("Warning: %v\n", err) } else if err != nil { fmt.Printf("error: %v\n", err) panic(err) } } // LoadKeyAndCertificate reads certificates or key from file. func LoadKeyAndCertificate(keyPath string, certificatePath string) (tls.Certificate, error) { return tls.LoadX509KeyPair(certificatePath, keyPath) } // LoadCertificate Load/read certificate(s) from file. 
func LoadCertificate(path string) (*tls.Certificate, error) { rawData, err := os.ReadFile(filepath.Clean(path)) if err != nil { return nil, err } var certificate tls.Certificate for { block, rest := pem.Decode(rawData) if block == nil { break } if block.Type != "CERTIFICATE" { return nil, errBlockIsNotCertificate } certificate.Certificate = append(certificate.Certificate, block.Bytes) rawData = rest } if len(certificate.Certificate) == 0 { return nil, errNoCertificateFound } return &certificate, nil } golang-github-pion-dtls-v3-3.0.7/flight.go000066400000000000000000000061541507057460300203260ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package dtls /* DTLS messages are grouped into a series of message flights, according to the diagrams below. Although each flight of messages may consist of a number of messages, they should be viewed as monolithic for the purpose of timeout and retransmission. https://tools.ietf.org/html/rfc4347#section-4.2.4 Message flights for full handshake: Client Server ------ ------ Waiting Flight 0 ClientHello --------> Flight 1 <------- HelloVerifyRequest Flight 2 ClientHello --------> Flight 3 ServerHello \ Certificate* \ ServerKeyExchange* Flight 4 CertificateRequest* / <-------- ServerHelloDone / Certificate* \ ClientKeyExchange \ CertificateVerify* Flight 5 [ChangeCipherSpec] / Finished --------> / [ChangeCipherSpec] \ Flight 6 <-------- Finished / Message flights for session-resuming handshake (no cookie exchange): Client Server ------ ------ Waiting Flight 0 ClientHello --------> Flight 1 ServerHello \ [ChangeCipherSpec] Flight 4b <-------- Finished / [ChangeCipherSpec] \ Flight 5b Finished --------> / [ChangeCipherSpec] \ Flight 6 <-------- Finished / */ type flightVal uint8 const ( flight0 flightVal = iota + 1 flight1 flight2 flight3 flight4 flight4b flight5 flight5b flight6 ) func (f flightVal) String() string { //nolint:cyclop switch f { case flight0: return 
"Flight 0" case flight1: return "Flight 1" case flight2: return "Flight 2" case flight3: return "Flight 3" case flight4: return "Flight 4" case flight4b: return "Flight 4b" case flight5: return "Flight 5" case flight5b: return "Flight 5b" case flight6: return "Flight 6" default: return "Invalid Flight" } } func (f flightVal) isLastSendFlight() bool { return f == flight6 || f == flight5b } func (f flightVal) isLastRecvFlight() bool { return f == flight5 || f == flight4b } golang-github-pion-dtls-v3-3.0.7/flight0handler.go000066400000000000000000000124101507057460300217340ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "context" "crypto/rand" "github.com/pion/dtls/v3/pkg/crypto/elliptic" "github.com/pion/dtls/v3/pkg/protocol" "github.com/pion/dtls/v3/pkg/protocol/alert" "github.com/pion/dtls/v3/pkg/protocol/extension" "github.com/pion/dtls/v3/pkg/protocol/handshake" ) //nolint:cyclop func flight0Parse( _ context.Context, _ flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig, ) (flightVal, *alert.Alert, error) { seq, msgs, ok := cache.fullPullMap(0, state.cipherSuite, handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false}, ) if !ok { // No valid message received. Keep reading return 0, nil, nil } // Connection Identifiers must be negotiated afresh on session resumption. 
// https://datatracker.ietf.org/doc/html/rfc9146#name-the-connection_id-extension state.setLocalConnectionID(nil) state.remoteConnectionID = nil state.handshakeRecvSequence = seq var clientHello *handshake.MessageClientHello // Validate type if clientHello, ok = msgs[handshake.TypeClientHello].(*handshake.MessageClientHello); !ok { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil } if !clientHello.Version.Equal(protocol.Version1_2) { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion } state.remoteRandom = clientHello.Random cipherSuites := []CipherSuite{} for _, id := range clientHello.CipherSuiteIDs { if c := cipherSuiteForID(CipherSuiteID(id), cfg.customCipherSuites); c != nil { cipherSuites = append(cipherSuites, c) } } if state.cipherSuite, ok = findMatchingCipherSuite(cipherSuites, cfg.localCipherSuites); !ok { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errCipherSuiteNoIntersection } for _, val := range clientHello.Extensions { switch ext := val.(type) { case *extension.SupportedEllipticCurves: if len(ext.EllipticCurves) == 0 { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errNoSupportedEllipticCurves } state.namedCurve = ext.EllipticCurves[0] case *extension.UseSRTP: profile, ok := findMatchingSRTPProfile(ext.ProtectionProfiles, cfg.localSRTPProtectionProfiles) if !ok { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errServerNoMatchingSRTPProfile } state.setSRTPProtectionProfile(profile) state.remoteSRTPMasterKeyIdentifier = ext.MasterKeyIdentifier case *extension.UseExtendedMasterSecret: if cfg.extendedMasterSecret != DisableExtendedMasterSecret { state.extendedMasterSecret = true } case *extension.ServerName: state.serverName = ext.ServerName // remote server name case *extension.RenegotiationInfo: state.remoteSupportsRenegotiation = true case 
*extension.ALPN: state.peerSupportedProtocols = ext.ProtocolNameList case *extension.ConnectionID: // Only set connection ID to be sent if server supports connection // IDs. if cfg.connectionIDGenerator != nil { state.remoteConnectionID = ext.CID } } } // If the client doesn't support connection IDs, the server should not // expect one to be sent. if state.remoteConnectionID == nil { state.setLocalConnectionID(nil) } if cfg.extendedMasterSecret == RequireExtendedMasterSecret && !state.extendedMasterSecret { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errServerRequiredButNoClientEMS } if state.localKeypair == nil { var err error state.localKeypair, err = elliptic.GenerateKeypair(state.namedCurve) if err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, err } } nextFlight := flight2 if cfg.insecureSkipHelloVerify { nextFlight = flight4 } return handleHelloResume(clientHello.SessionID, state, cfg, nextFlight) } func handleHelloResume( sessionID []byte, state *State, cfg *handshakeConfig, next flightVal, ) (flightVal, *alert.Alert, error) { if len(sessionID) > 0 && cfg.sessionStore != nil { if s, err := cfg.sessionStore.Get(sessionID); err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } else if s.ID != nil { cfg.log.Tracef("[handshake] resume session: %x", sessionID) state.SessionID = sessionID state.masterSecret = s.Secret if err := state.initCipherSuite(); err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } clientRandom := state.localRandom.MarshalFixed() cfg.writeKeyLog(keyLogLabelTLS12, clientRandom[:], state.masterSecret) return flight4b, nil, nil } } return next, nil, nil } func flight0Generate( _ flightConn, state *State, _ *handshakeCache, cfg *handshakeConfig, ) ([]*packet, *alert.Alert, error) { // Initialize if !cfg.insecureSkipHelloVerify { state.cookie = make([]byte, cookieLength) if _, 
err := rand.Read(state.cookie); err != nil { return nil, nil, err } } var zeroEpoch uint16 state.localEpoch.Store(zeroEpoch) state.remoteEpoch.Store(zeroEpoch) state.namedCurve = defaultNamedCurve if err := state.localRandom.Populate(); err != nil { return nil, nil, err } return nil, nil, nil } golang-github-pion-dtls-v3-3.0.7/flight1handler.go000066400000000000000000000130331507057460300217370ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "context" "github.com/pion/dtls/v3/pkg/crypto/elliptic" "github.com/pion/dtls/v3/pkg/protocol" "github.com/pion/dtls/v3/pkg/protocol/alert" "github.com/pion/dtls/v3/pkg/protocol/extension" "github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) func flight1Parse( ctx context.Context, conn flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig, ) (flightVal, *alert.Alert, error) { // HelloVerifyRequest can be skipped by the server, // so allow ServerHello during flight1 also seq, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite, handshakeCachePullRule{handshake.TypeHelloVerifyRequest, cfg.initialEpoch, false, true}, handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, true}, ) if !ok { // No valid message received. Keep reading return 0, nil, nil } if _, ok := msgs[handshake.TypeServerHello]; ok { // Flight1 and flight2 were skipped. // Parse as flight3. return flight3Parse(ctx, conn, state, cache, cfg) } if h, ok := msgs[handshake.TypeHelloVerifyRequest].(*handshake.MessageHelloVerifyRequest); ok { // DTLS 1.2 clients must not assume that the server will use the protocol version // specified in HelloVerifyRequest message. 
RFC 6347 Section 4.2.1 if !h.Version.Equal(protocol.Version1_0) && !h.Version.Equal(protocol.Version1_2) { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion } state.cookie = append([]byte{}, h.Cookie...) state.handshakeRecvSequence = seq return flight3, nil, nil } return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil } //nolint:cyclop func flight1Generate( conn flightConn, state *State, _ *handshakeCache, cfg *handshakeConfig, ) ([]*packet, *alert.Alert, error) { var zeroEpoch uint16 state.localEpoch.Store(zeroEpoch) state.remoteEpoch.Store(zeroEpoch) state.namedCurve = defaultNamedCurve state.cookie = nil if err := state.localRandom.Populate(); err != nil { return nil, nil, err } if cfg.helloRandomBytesGenerator != nil { state.localRandom.RandomBytes = cfg.helloRandomBytesGenerator() } extensions := []extension.Extension{ &extension.SupportedSignatureAlgorithms{ SignatureHashAlgorithms: cfg.localSignatureSchemes, }, &extension.RenegotiationInfo{ RenegotiatedConnection: 0, }, } var setEllipticCurveCryptographyClientHelloExtensions bool for _, c := range cfg.localCipherSuites { if c.ECC() { setEllipticCurveCryptographyClientHelloExtensions = true break } } if setEllipticCurveCryptographyClientHelloExtensions { extensions = append(extensions, []extension.Extension{ &extension.SupportedEllipticCurves{ EllipticCurves: cfg.ellipticCurves, }, &extension.SupportedPointFormats{ PointFormats: []elliptic.CurvePointFormat{elliptic.CurvePointFormatUncompressed}, }, }...) 
} if len(cfg.localSRTPProtectionProfiles) > 0 { extensions = append(extensions, &extension.UseSRTP{ ProtectionProfiles: cfg.localSRTPProtectionProfiles, MasterKeyIdentifier: cfg.localSRTPMasterKeyIdentifier, }) } if cfg.extendedMasterSecret == RequestExtendedMasterSecret || cfg.extendedMasterSecret == RequireExtendedMasterSecret { extensions = append(extensions, &extension.UseExtendedMasterSecret{ Supported: true, }) } if len(cfg.serverName) > 0 { extensions = append(extensions, &extension.ServerName{ServerName: cfg.serverName}) } if len(cfg.supportedProtocols) > 0 { extensions = append(extensions, &extension.ALPN{ProtocolNameList: cfg.supportedProtocols}) } if cfg.sessionStore != nil { cfg.log.Tracef("[handshake] try to resume session") if s, err := cfg.sessionStore.Get(conn.sessionKey()); err != nil { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } else if s.ID != nil { cfg.log.Tracef("[handshake] get saved session: %x", s.ID) state.SessionID = s.ID state.masterSecret = s.Secret } } // If we have a connection ID generator, use it. The CID may be zero length, // in which case we are just requesting that the server send us a CID to // use. if cfg.connectionIDGenerator != nil { state.setLocalConnectionID(cfg.connectionIDGenerator()) // The presence of a generator indicates support for connection IDs. We // use the presence of a non-nil local CID in flight 3 to determine // whether we send a CID in the second ClientHello, so we convert any // nil CID returned by a generator to []byte{}. 
if state.getLocalConnectionID() == nil { state.setLocalConnectionID([]byte{}) } extensions = append(extensions, &extension.ConnectionID{CID: state.getLocalConnectionID()}) } clientHello := &handshake.MessageClientHello{ Version: protocol.Version1_2, SessionID: state.SessionID, Cookie: state.cookie, Random: state.localRandom, CipherSuiteIDs: cipherSuiteIDs(cfg.localCipherSuites), CompressionMethods: defaultCompressionMethods(), Extensions: extensions, } var content handshake.Handshake if cfg.clientHelloMessageHook != nil { content = handshake.Handshake{Message: cfg.clientHelloMessageHook(*clientHello)} } else { content = handshake.Handshake{Message: clientHello} } return []*packet{ { record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, }, Content: &content, }, }, }, nil, nil } golang-github-pion-dtls-v3-3.0.7/flight1handler_test.go000066400000000000000000000345421507057460300230060ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "context" "testing" "time" "github.com/pion/dtls/v3/internal/ciphersuite" "github.com/pion/dtls/v3/pkg/protocol/alert" "github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/pion/logging" "github.com/pion/transport/v3/test" "github.com/stretchr/testify/assert" ) type flight1TestMockFlightConn struct{} func (f *flight1TestMockFlightConn) notify(context.Context, alert.Level, alert.Description) error { return nil } func (f *flight1TestMockFlightConn) writePackets(context.Context, []*packet) error { return nil } func (f *flight1TestMockFlightConn) recvHandshake() <-chan recvHandshakeState { return nil } func (f *flight1TestMockFlightConn) setLocalEpoch(uint16) {} func (f *flight1TestMockFlightConn) handleQueuedPackets(context.Context) error { return nil } func (f *flight1TestMockFlightConn) sessionKey() []byte { return nil } type flight1TestMockCipherSuite struct { ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256 t 
*testing.T } func (f *flight1TestMockCipherSuite) IsInitialized() bool { assert.Fail(f.t, "IsInitialized called with Certificate but not CertificateVerify") return true } // When "server hello" arrives later than "certificate", // "server key exchange", "certificate request", "server hello done", // is it normal for the flight1Parse method to handle it. func TestFlight1_Process_ServerHelloLateArrival(t *testing.T) { //nolint:maintidx // Limit runtime in case of deadlocks lim := test.TimeOut(5 * time.Second) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() mockConn := &flight1TestMockFlightConn{} state := &State{ cipherSuite: &flight1TestMockCipherSuite{t: t}, } cache := newHandshakeCache() cfg := &handshakeConfig{ localSRTPProtectionProfiles: []SRTPProtectionProfile{SRTP_AEAD_AES_128_GCM}, localCipherSuites: []CipherSuite{}, } cfg.localCipherSuites = []CipherSuite{cipherSuiteForID(TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, nil)} cfg.log = logging.NewDefaultLoggerFactory().NewLogger("dtls") serverHello := []byte{ 0x02, 0x00, 0x00, 0x62, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x62, 0xfe, 0xfd, 0x07, 0x46, 0xb7, 0xbf, 0xde, 0x78, 0xab, 0x38, 0x69, 0x36, 0x74, 0x10, 0xa6, 0x50, 0x67, 0x7b, 0x4b, 0x85, 0xdf, 0x71, 0x71, 0x62, 0x3a, 0xb1, 0xd7, 0xa4, 0x79, 0x6a, 0x38, 0x13, 0x5e, 0xa1, 0x20, 0xbd, 0x64, 0xaf, 0xb3, 0x36, 0x77, 0x73, 0x8a, 0x62, 0x75, 0xb2, 0x64, 0xbe, 0xf6, 0x2a, 0xb1, 0x6e, 0x7b, 0xf6, 0x00, 0xd6, 0x24, 0xd5, 0xb1, 0x1e, 0x54, 0xa3, 0x76, 0xb3, 0xac, 0x76, 0x8f, 0xc0, 0x2f, 0x00, 0x00, 0x1a, 0xff, 0x01, 0x00, 0x01, 0x00, 0x00, 0x0b, 0x00, 0x04, 0x03, 0x00, 0x01, 0x02, 0x00, 0x0e, 0x00, 0x05, 0x00, 0x02, 0x00, 0x07, 0x00, 0x00, 0x17, 0x00, 0x00, } certificate1 := []byte{ 0x0b, 0x00, 0x05, 0x5b, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x04, 0xe4, 0x00, 0x05, 0x58, 0x00, 0x05, 0x55, 0x30, 0x82, 0x05, 0x51, 0x30, 0x82, 0x04, 0x39, 0xa0, 0x03, 0x02, 0x01, 0x02, 0x02, 0x0c, 0x56, 0x8b, 0xb4, 0x68, 0xed, 0x70, 0xce, 
0xb6, 0x8d, 0x44, 0x65, 0x4b, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x0b, 0x05, 0x00, 0x30, 0x66, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13, 0x02, 0x42, 0x45, 0x31, 0x19, 0x30, 0x17, 0x06, 0x03, 0x55, 0x04, 0x0a, 0x13, 0x10, 0x47, 0x6c, 0x6f, 0x62, 0x61, 0x6c, 0x53, 0x69, 0x67, 0x6e, 0x20, 0x6e, 0x76, 0x2d, 0x73, 0x61, 0x31, 0x3c, 0x30, 0x3a, 0x06, 0x03, 0x55, 0x04, 0x03, 0x13, 0x33, 0x47, 0x6c, 0x6f, 0x62, 0x61, 0x6c, 0x53, 0x69, 0x67, 0x6e, 0x20, 0x4f, 0x72, 0x67, 0x61, 0x6e, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x20, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x20, 0x43, 0x41, 0x20, 0x2d, 0x20, 0x53, 0x48, 0x41, 0x32, 0x35, 0x36, 0x20, 0x2d, 0x20, 0x47, 0x32, 0x30, 0x1e, 0x17, 0x0d, 0x31, 0x37, 0x30, 0x34, 0x32, 0x30, 0x31, 0x31, 0x31, 0x39, 0x35, 0x39, 0x5a, 0x17, 0x0d, 0x31, 0x38, 0x30, 0x34, 0x32, 0x31, 0x31, 0x31, 0x31, 0x39, 0x35, 0x39, 0x5a, 0x30, 0x81, 0x84, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13, 0x02, 0x43, 0x4e, 0x31, 0x12, 0x30, 0x10, 0x06, 0x03, 0x55, 0x04, 0x08, 0x13, 0x09, 0x67, 0x75, 0x61, 0x6e, 0x67, 0x64, 0x6f, 0x6e, 0x67, 0x31, 0x11, 0x30, 0x0f, 0x06, 0x03, 0x55, 0x04, 0x07, 0x13, 0x08, 0x73, 0x68, 0x65, 0x6e, 0x7a, 0x68, 0x65, 0x6e, 0x31, 0x36, 0x30, 0x34, 0x06, 0x03, 0x55, 0x04, 0x0a, 0x13, 0x2d, 0x54, 0x65, 0x6e, 0x63, 0x65, 0x6e, 0x74, 0x20, 0x54, 0x65, 0x63, 0x68, 0x6e, 0x6f, 0x6c, 0x6f, 0x67, 0x79, 0x20, 0x28, 0x53, 0x68, 0x65, 0x6e, 0x7a, 0x68, 0x65, 0x6e, 0x29, 0x20, 0x43, 0x6f, 0x6d, 0x70, 0x61, 0x6e, 0x79, 0x20, 0x4c, 0x69, 0x6d, 0x69, 0x74, 0x65, 0x64, 0x31, 0x16, 0x30, 0x14, 0x06, 0x03, 0x55, 0x04, 0x03, 0x13, 0x0d, 0x77, 0x65, 0x62, 0x72, 0x74, 0x63, 0x2e, 0x71, 0x71, 0x2e, 0x63, 0x6f, 0x6d, 0x30, 0x82, 0x01, 0x22, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x01, 0x05, 0x00, 0x03, 0x82, 0x01, 0x0f, 0x00, 0x30, 0x82, 0x01, 0x0a, 0x02, 0x82, 0x01, 0x01, 0x00, 0xb6, 0x00, 0xa7, 0x09, 0x0a, 0xc4, 0x96, 0x24, 0x72, 
0xa0, 0x09, 0xda, 0xac, 0x63, 0xe4, 0x9a, 0xfe, 0x8b, 0x9b, 0x99, 0x8c, 0xe3, 0xab, 0x4b, 0x7c, 0xbd, 0x4f, 0x31, 0x1e, 0x2f, 0xff, 0x34, 0x54, 0xb5, 0xb0, 0x99, 0xcd, 0x00, 0x7c, 0x5b, 0x12, 0x96, 0xfa, 0x9b, 0x6b, 0x79, 0xc7, 0xfb, 0x00, 0x53, 0xaf, 0xb6, 0x00, 0x45, 0x46, 0x20, 0x7d, 0x95, 0xca, 0x86, 0xcc, 0x4b, 0xe8, 0x25, 0x52, 0x5b, 0x9c, 0xe7, 0x58, 0xcd, 0xd0, 0x8f, 0x4a, 0xd8, 0x77, 0x7d, 0x45, 0xa0, 0x70, 0xe8, 0x16, 0x45, 0x23, 0xfb, 0xbc, 0x43, 0x36, 0xdd, 0x5b, 0x8f, 0x01, 0xc3, 0xc0, 0xa2, 0xab, 0x80, 0xf1, 0x97, 0x72, 0x38, 0xab, 0x6f, 0xa1, 0x28, 0x09, 0xdd, 0x31, 0x7e, 0x50, 0xc8, 0x51, 0xde, 0x8d, 0x05, 0xbc, 0x72, 0x79, 0x94, 0x6e, 0xd4, 0xb7, 0xf0, 0x97, 0xd0, 0x76, 0x9c, 0x9d, 0xb4, 0x34, 0xf1, 0x8a, 0x82, 0x20, 0x9b, 0x24, 0x4b, 0x38, 0xc9, 0x63, 0xe6, 0x02, 0xf5, 0xb2, 0x9b, 0x70, 0xa4, 0x97, 0x9f, 0xaa, 0x1f, 0x36, 0x9c, 0xfd, 0x81, 0x93, 0x81, 0xd7, 0x4e, 0xca, 0xd2, 0xa7, 0x7c, 0x29, 0x9d, 0x28, 0xf2, 0x3e, 0x3b, 0xea, 0xe6, 0x22, 0x51, 0x8f, 0x0b, 0xe7, 0x65, 0xa1, 0x28, 0xdd, 0x55, 0x6a, 0x59, 0x53, 0x67, 0xb6, 0xb3, 0xd2, 0x4c, 0x90, 0x69, 0xd1, 0x1e, 0x62, 0xab, 0x33, 0x47, 0x29, 0x45, 0x18, 0x1f, 0xeb, 0x6d, 0x13, 0xb4, 0x61, 0xf5, 0x15, 0x03, 0xf7, 0x4f, 0x9c, 0x4c, 0x2c, 0xae, 0x5e, 0xde, 0xd2, 0x11, 0x32, 0xb5, 0x17, 0xb5, 0xe8, 0xa3, 0xb2, 0x1f, 0xc3, 0x9f, 0x78, 0xa1, 0xf5, 0x80, 0xb4, 0x96, 0x90, 0x6b, 0x77, 0x9e, 0xe9, 0x39, 0x61, 0x2c, 0x18, 0xf5, 0x7b, 0xab, 0x1e, 0x09, 0x88, 0x7d, 0xc3, 0x75, 0x5e, 0x4d, 0xcf, 0xf3, 0x02, 0x03, 0x01, 0x00, 0x01, 0xa3, 0x82, 0x01, 0xde, 0x30, 0x82, 0x01, 0xda, 0x30, 0x0e, 0x06, 0x03, 0x55, 0x1d, 0x0f, 0x01, 0x01, 0xff, 0x04, 0x04, 0x03, 0x02, 0x05, 0xa0, 0x30, 0x81, 0xa0, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x01, 0x01, 0x04, 0x81, 0x93, 0x30, 0x81, 0x90, 0x30, 0x4d, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x30, 0x02, 0x86, 0x41, 0x68, 0x74, 0x74, 0x70, 0x3a, 0x2f, 0x2f, 0x73, 0x65, 0x63, 0x75, 0x72, 0x65, 0x2e, 0x67, 0x6c, 0x6f, 0x62, 0x61, 0x6c, 0x73, 0x69, 0x67, 0x6e, 
0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x63, 0x61, 0x63, 0x65, 0x72, 0x74, 0x2f, 0x67, 0x73, 0x6f, 0x72, 0x67, 0x61, 0x6e, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x76, 0x61, 0x6c, 0x73, 0x68, 0x61, 0x32, 0x67, 0x32, 0x72, 0x31, 0x2e, 0x63, 0x72, 0x74, 0x30, 0x3f, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x30, 0x01, 0x86, 0x33, 0x68, 0x74, 0x74, 0x70, 0x3a, 0x2f, 0x2f, 0x6f, 0x63, 0x73, 0x70, 0x32, 0x2e, 0x67, 0x6c, 0x6f, 0x62, 0x61, 0x6c, 0x73, 0x69, 0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x67, 0x73, 0x6f, 0x72, 0x67, 0x61, 0x6e, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x76, 0x61, 0x6c, 0x73, 0x68, 0x61, 0x32, 0x67, 0x32, 0x30, 0x56, 0x06, 0x03, 0x55, 0x1d, 0x20, 0x04, 0x4f, 0x30, 0x4d, 0x30, 0x41, 0x06, 0x09, 0x2b, 0x06, 0x01, 0x04, 0x01, 0xa0, 0x32, 0x01, 0x14, 0x30, 0x34, 0x30, 0x32, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x02, 0x01, 0x16, 0x26, 0x68, 0x74, 0x74, 0x70, 0x73, 0x3a, 0x2f, 0x2f, 0x77, 0x77, 0x77, 0x2e, 0x67, 0x6c, 0x6f, 0x62, 0x61, 0x6c, 0x73, 0x69, 0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x72, 0x65, 0x70, 0x6f, 0x73, 0x69, 0x74, 0x6f, 0x72, 0x79, 0x2f, 0x30, 0x08, 0x06, 0x06, 0x67, 0x81, 0x0c, 0x01, 0x02, 0x02, 0x30, 0x09, 0x06, 0x03, 0x55, 0x1d, 0x13, 0x04, 0x02, 0x30, 0x00, 0x30, 0x49, 0x06, 0x03, 0x55, 0x1d, 0x1f, 0x04, 0x42, 0x30, 0x40, 0x30, 0x3e, 0xa0, 0x3c, 0xa0, 0x3a, 0x86, 0x38, 0x68, 0x74, 0x74, 0x70, 0x3a, 0x2f, 0x2f, 0x63, 0x72, 0x6c, 0x2e, 0x67, 0x6c, 0x6f, 0x62, 0x61, 0x6c, 0x73, 0x69, 0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x67, 0x73, 0x2f, 0x67, 0x73, 0x6f, 0x72, 0x67, 0x61, 0x6e, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x76, 0x61, 0x6c, 0x73, 0x68, 0x61, 0x32, 0x67, 0x32, 0x2e, 0x63, 0x72, 0x6c, 0x30, 0x18, 0x06, 0x03, 0x55, 0x1d, 0x11, 0x04, 0x11, 0x30, 0x0f, 0x82, 0x0d, 0x77, 0x65, 0x62, 0x72, 0x74, 0x63, 0x2e, 0x71, 0x71, 0x2e, 0x63, 0x6f, 0x6d, 0x30, 0x1d, 0x06, 0x03, 0x55, 0x1d, 0x25, 0x04, 0x16, 0x30, 0x14, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x01, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 
0x05, 0x07, 0x03, 0x02, 0x30, 0x1d, 0x06, 0x03, 0x55, 0x1d, 0x0e, 0x04, 0x16, 0x04, 0x14, 0x28, 0xff, 0xe2, 0x97, 0xf3, 0x6f, 0x2a, 0xef, 0x0f, 0xbc, 0x4c, 0x61, 0x9b, 0xd9, 0x23, 0x7b, 0x3a, 0xef, 0xc2, 0xe7, 0x30, 0x1f, 0x06, 0x03, 0x55, 0x1d, 0x23, 0x04, 0x18, 0x30, 0x16, 0x80, 0x14, 0x96, 0xde, 0x61, 0xf1, 0xbd, 0x1c, 0x16, 0x29, 0x53, 0x1c, 0xc0, 0xcc, 0x7d, 0x3b, 0x83, 0x00, 0x40, 0xe6, 0x1a, 0x7c, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x0b, 0x05, 0x00, 0x03, 0x82, 0x01, 0x01, 0x00, 0x30, 0xc1, 0xcc, 0xd6, 0x97, 0xf7, 0xf5, 0xa7, 0x93, 0xa5, 0x78, 0xc8, 0xcb, 0x81, 0x44, 0xd4, 0x1f, 0x2a, 0xa6, 0xc1, 0x48, 0xa8, 0x1a, 0xbd, 0x17, 0x10, 0x0e, 0xdf, 0x21, 0xea, 0x02, 0x3e, 0xb3, 0xbd, 0x45, 0x1e, 0x64, 0x85, 0x3f, 0x04, 0x9a, 0xc0, 0x78, 0xf4, 0x81, 0x2e, 0x38, 0x39, 0x3a, 0x04, 0x2d, 0x5f, 0xec, 0xc4, 0x10, 0x57, 0xfb, 0x1b, 0x32, 0xe0, 0x8e, 0xfc, 0xe3, 0x6d, 0x4b, 0xc6, 0xf0, 0x07, 0xb7, 0xc6, 0x19, 0xd7, 0x99, 0x93, 0xbd, 0x60, 0x58, 0xad, 0xbb, 0x94, 0xcf, 0xd8, 0x05, 0x5c, 0x14, 0x70, 0xec, 0x2e, 0xb7, 0x60, 0x52, 0x3c, 0xd3, 0x03, 0xf8, 0xcd, 0xe5, 0x4e, 0x84, 0xcf, 0xef, 0x2f, 0x12, 0xdd, 0x74, 0xfd, 0x95, 0x9d, 0x03, 0xa9, 0x81, 0x18, 0x3a, 0x6e, 0xe6, 0xc2, 0xdd, 0x07, 0x1e, 0xea, 0x8c, 0xe6, 0xd9, 0x31, 0x72, 0x63, 0x25, 0xcd, 0xf2, 0x19, 0xf2, 0x4e, 0x3c, 0x18, 0xfb, 0xb2, 0x74, } certificate2 := []byte{ 0x0b, 0x00, 0x05, 0x5b, 0x00, 0x01, 0x00, 0x04, 0xe4, 0x00, 0x00, 0x77, 0xc1, 0x6b, 0x67, 0xec, 0x34, 0x05, 0xe8, 0x63, 0xfc, 0x74, 0x4b, 0x11, 0x3f, 0x3a, 0xe4, 0x4e, 0x06, 0x89, 0x96, 0x24, 0x3c, 0x15, 0x83, 0xc5, 0x1d, 0xeb, 0xc0, 0x19, 0x71, 0x35, 0x6c, 0xfa, 0xf1, 0x51, 0x06, 0x0e, 0x8e, 0xfb, 0x9b, 0x4e, 0xaa, 0x50, 0x24, 0x77, 0xac, 0x86, 0x14, 0x50, 0x52, 0x35, 0x68, 0x15, 0x9b, 0xdd, 0x8b, 0xdb, 0x83, 0x1d, 0xed, 0x45, 0x05, 0x78, 0x53, 0xd6, 0xc4, 0x21, 0xaf, 0x68, 0x45, 0x91, 0xe7, 0x30, 0x36, 0x4c, 0xb1, 0xfb, 0xf1, 0x65, 0x9a, 0xe4, 0x49, 0x90, 0x1c, 0x0c, 0xa8, 0x63, 0xe9, 0x04, 0xe3, 0x17, 0x61, 0x8d, 
0x20, 0x29, 0xca, 0x41, 0xa6, 0x8b, 0x32, 0x53, 0xa5, 0x84, 0x29, 0x5a, 0x62, 0xe7, 0x84, 0x38, 0x32, 0x56, 0xbb, 0x8b, 0xbc, 0x25, 0xc7, 0xa3, 0x28, 0x3b, 0x35, } serverKeyExchange := []byte{ 0x0c, 0x00, 0x01, 0x28, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x01, 0x28, 0x03, 0x00, 0x1d, 0x20, 0x59, 0xa2, 0x0f, 0xc4, 0x7b, 0xd8, 0x03, 0xf6, 0xb0, 0xcf, 0x5d, 0xf0, 0x45, 0x7f, 0x7e, 0xf2, 0x98, 0xab, 0xc0, 0x24, 0xf1, 0xdf, 0xba, 0x63, 0x3e, 0xfb, 0xe5, 0x02, 0x31, 0xcf, 0xd1, 0x05, 0x04, 0x01, 0x01, 0x00, 0x7b, 0x52, 0x9c, 0xe7, 0x54, 0x8b, 0xb0, 0xc9, 0xfd, 0xaf, 0xe2, 0x91, 0x19, 0x9d, 0x6c, 0xb8, 0xbe, 0xa5, 0xe1, 0x48, 0xa0, 0xfd, 0xc5, 0x76, 0x62, 0x47, 0xf2, 0xd1, 0x35, 0x76, 0x4e, 0x33, 0xf4, 0xa1, 0xf1, 0x58, 0xdc, 0xd5, 0x45, 0x3f, 0x76, 0x64, 0x40, 0xba, 0x32, 0xe3, 0x07, 0xb7, 0x4b, 0xbe, 0xe2, 0x77, 0x99, 0xad, 0x11, 0x73, 0x54, 0xe6, 0xbb, 0xfb, 0xd4, 0xb1, 0x83, 0x9f, 0xc6, 0x50, 0xc6, 0xd8, 0xbb, 0x92, 0x0d, 0x93, 0xf9, 0x63, 0x29, 0xf9, 0xc3, 0xce, 0x24, 0x40, 0x29, 0x95, 0x43, 0xf0, 0x32, 0x00, 0x21, 0xde, 0xdf, 0x64, 0xfe, 0xb6, 0x11, 0xa0, 0x11, 0x44, 0x12, 0x2a, 0x1c, 0x96, 0x44, 0x4b, 0x79, 0x31, 0x23, 0x46, 0x4e, 0xe8, 0x16, 0x5b, 0xf5, 0x9a, 0x5f, 0x51, 0x10, 0x5b, 0x11, 0xa3, 0xb8, 0x1f, 0xb7, 0xf1, 0x11, 0xad, 0x05, 0x82, 0x2b, 0xc3, 0x65, 0x8c, 0x41, 0xb4, 0x8e, 0x60, 0x42, 0x89, 0x92, 0xd1, 0x83, 0x73, 0xe7, 0x35, 0xb4, 0xc9, 0xd1, 0xbc, 0x5c, 0x84, 0x5b, 0xdb, 0x44, 0x34, 0xea, 0xd8, 0x06, 0xe4, 0xfb, 0xbd, 0x40, 0x35, 0x18, 0x60, 0x33, 0xb6, 0xed, 0xbc, 0x9b, 0x3a, 0xff, 0x2f, 0xa1, 0xe8, 0x5d, 0x5c, 0xbb, 0xe8, 0xe1, 0xa6, 0xbb, 0x84, 0x0f, 0x50, 0x51, 0x0d, 0xa5, 0x8f, 0x96, 0xb6, 0x35, 0x37, 0x7b, 0x58, 0xaf, 0x4f, 0x77, 0x9d, 0x5d, 0xb2, 0xff, 0x5f, 0xd6, 0xb8, 0x82, 0x64, 0x5f, 0x79, 0xd0, 0x06, 0x44, 0x6d, 0x3a, 0x82, 0x25, 0x21, 0xca, 0xbb, 0xa0, 0x79, 0xdd, 0x6e, 0x15, 0xb6, 0x57, 0x9b, 0x04, 0x84, 0x63, 0x88, 0x1d, 0x41, 0xff, 0xe1, 0x20, 0x61, 0xd5, 0x3f, 0xc7, 0xca, 0x0c, 0xd9, 0xe0, 0x74, 0x86, 0x78, 0xed, 0x60, 0x18, 0x2d, 0x9e, 
0x69, 0x66, 0x77, 0xf7, 0xd0, 0xe9, 0x9c, } certificateRequest := []byte{ 0x0d, 0x00, 0x00, 0x26, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x26, 0x03, 0x01, 0x02, 0x40, 0x00, 0x1e, 0x06, 0x01, 0x06, 0x02, 0x06, 0x03, 0x05, 0x01, 0x05, 0x02, 0x05, 0x03, 0x04, 0x01, 0x04, 0x02, 0x04, 0x03, 0x03, 0x01, 0x03, 0x02, 0x03, 0x03, 0x02, 0x01, 0x02, 0x02, 0x02, 0x03, 0x00, 0x00, } serverHelloDone := []byte{ 0x0e, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, } cache.push(certificate2, 0, 2, handshake.TypeCertificate, false) cache.push(serverKeyExchange, 0, 3, handshake.TypeServerKeyExchange, false) cache.push(certificateRequest, 0, 4, handshake.TypeCertificateRequest, false) cache.push(serverHelloDone, 0, 5, handshake.TypeServerHelloDone, false) _, alt, err := flight1Parse(context.TODO(), mockConn, state, cache, cfg) assert.NoError(t, err) assert.Nil(t, alt) cache.push(serverHello, 0, 0, handshake.TypeServerHello, false) cache.push(certificate1, 0, 1, handshake.TypeCertificate, false) _, alt, err = flight1Parse(context.TODO(), mockConn, state, cache, cfg) assert.NoError(t, err) assert.Nil(t, alt) } golang-github-pion-dtls-v3-3.0.7/flight2handler.go000066400000000000000000000037611507057460300217470ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "bytes" "context" "github.com/pion/dtls/v3/pkg/protocol" "github.com/pion/dtls/v3/pkg/protocol/alert" "github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) func flight2Parse( ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig, ) (flightVal, *alert.Alert, error) { seq, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite, handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false}, ) if !ok { // Client may retransmit the first ClientHello when HelloVerifyRequest is dropped. 
// Parse as flight 0 in this case. return flight0Parse(ctx, c, state, cache, cfg) } state.handshakeRecvSequence = seq var clientHello *handshake.MessageClientHello // Validate type if clientHello, ok = msgs[handshake.TypeClientHello].(*handshake.MessageClientHello); !ok { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil } if !clientHello.Version.Equal(protocol.Version1_2) { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion } if len(clientHello.Cookie) == 0 { return 0, nil, nil } if !bytes.Equal(state.cookie, clientHello.Cookie) { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.AccessDenied}, errCookieMismatch } return flight4, nil, nil } func flight2Generate( _ flightConn, state *State, _ *handshakeCache, _ *handshakeConfig, ) ([]*packet, *alert.Alert, error) { state.handshakeSendSequence = 0 return []*packet{ { record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, }, Content: &handshake.Handshake{ Message: &handshake.MessageHelloVerifyRequest{ Version: protocol.Version1_2, Cookie: state.cookie, }, }, }, }, }, nil, nil } golang-github-pion-dtls-v3-3.0.7/flight3handler.go000066400000000000000000000311401507057460300217400ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "bytes" "context" "github.com/pion/dtls/v3/internal/ciphersuite/types" "github.com/pion/dtls/v3/pkg/crypto/elliptic" "github.com/pion/dtls/v3/pkg/crypto/prf" "github.com/pion/dtls/v3/pkg/protocol" "github.com/pion/dtls/v3/pkg/protocol/alert" "github.com/pion/dtls/v3/pkg/protocol/extension" "github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) //nolint:gocognit,gocyclo,maintidx,cyclop func flight3Parse( ctx context.Context, conn flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig, ) (flightVal, *alert.Alert, 
error) { // Clients may receive multiple HelloVerifyRequest messages with different cookies. // Clients SHOULD handle this by sending a new ClientHello with a cookie in response // to the new HelloVerifyRequest. RFC 6347 Section 4.2.1 seq, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite, handshakeCachePullRule{handshake.TypeHelloVerifyRequest, cfg.initialEpoch, false, true}, ) if ok { if h, msgOk := msgs[handshake.TypeHelloVerifyRequest].(*handshake.MessageHelloVerifyRequest); msgOk { // DTLS 1.2 clients must not assume that the server will use the protocol version // specified in HelloVerifyRequest message. RFC 6347 Section 4.2.1 if !h.Version.Equal(protocol.Version1_0) && !h.Version.Equal(protocol.Version1_2) { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion } state.cookie = append([]byte{}, h.Cookie...) state.handshakeRecvSequence = seq return flight3, nil, nil } } _, msgs, ok = cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite, handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false}, ) if !ok { // Don't have enough messages. 
Keep reading return 0, nil, nil } if serverHelloMsg, msgOk := msgs[handshake.TypeServerHello].(*handshake.MessageServerHello); msgOk { //nolint:nestif if !serverHelloMsg.Version.Equal(protocol.Version1_2) { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion } for _, v := range serverHelloMsg.Extensions { switch ext := v.(type) { case *extension.UseSRTP: profile, found := findMatchingSRTPProfile(ext.ProtectionProfiles, cfg.localSRTPProtectionProfiles) if !found { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, errClientNoMatchingSRTPProfile } state.setSRTPProtectionProfile(profile) state.remoteSRTPMasterKeyIdentifier = ext.MasterKeyIdentifier case *extension.UseExtendedMasterSecret: if cfg.extendedMasterSecret != DisableExtendedMasterSecret { state.extendedMasterSecret = true } case *extension.ALPN: if len(ext.ProtocolNameList) > 1 { // This should be exactly 1, the zero case is handle when unmarshalling return 0, &alert.Alert{ Level: alert.Fatal, Description: alert.InternalError, }, extension.ErrALPNInvalidFormat // Meh, internal error? } state.NegotiatedProtocol = ext.ProtocolNameList[0] case *extension.ConnectionID: // Only set connection ID to be sent if client supports connection // IDs. if cfg.connectionIDGenerator != nil { state.remoteConnectionID = ext.CID } } } // If the server doesn't support connection IDs, the client should not // expect one to be sent. 
if state.remoteConnectionID == nil {
			state.setLocalConnectionID(nil)
		}

		if cfg.extendedMasterSecret == RequireExtendedMasterSecret && !state.extendedMasterSecret {
			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errClientRequiredButNoServerEMS
		}
		if len(cfg.localSRTPProtectionProfiles) > 0 && state.getSRTPProtectionProfile() == 0 {
			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errRequestedButNoSRTPExtension
		}

		remoteCipherSuite := cipherSuiteForID(CipherSuiteID(*serverHelloMsg.CipherSuiteID), cfg.customCipherSuites)
		if remoteCipherSuite == nil {
			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errCipherSuiteNoIntersection
		}

		// The suite chosen by the server must also be one we offered.
		selectedCipherSuite, found := findMatchingCipherSuite([]CipherSuite{remoteCipherSuite}, cfg.localCipherSuites)
		if !found {
			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errInvalidCipherSuite
		}

		state.cipherSuite = selectedCipherSuite
		state.remoteRandom = serverHelloMsg.Random
		cfg.log.Tracef("[handshake] use cipher suite: %s", selectedCipherSuite.String())

		// The server echoed the session ID carried in our ClientHello: take the
		// abbreviated (session resumption) path instead of a full handshake.
		if len(serverHelloMsg.SessionID) > 0 && bytes.Equal(state.SessionID, serverHelloMsg.SessionID) {
			return handleResumption(ctx, conn, state, cache, cfg)
		}

		// Full handshake: the session we offered was not resumed, so remove it
		// from the session store.
		if len(state.SessionID) > 0 {
			cfg.log.Tracef("[handshake] clean old session : %s", state.SessionID)
			if err := cfg.sessionStore.Del(state.SessionID); err != nil {
				return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
			}
		}

		// Only keep the server's session ID when a session store is configured.
		if cfg.sessionStore == nil {
			state.SessionID = []byte{}
		} else {
			state.SessionID = serverHelloMsg.SessionID
		}

		state.masterSecret = []byte{}
	}

	// With a PSK callback only ServerKeyExchange and ServerHelloDone are pulled;
	// otherwise Certificate and CertificateRequest are pulled as well.
	if cfg.localPSKCallback != nil {
		seq, msgs, ok = cache.fullPullMap(state.handshakeRecvSequence+1, state.cipherSuite,
			handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, true},
			handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false},
		)
	} else {
		seq, msgs, ok = cache.fullPullMap(state.handshakeRecvSequence+1, state.cipherSuite,
			handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, true},
			handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false},
			handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, true},
			handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false},
		)
	}
	if !ok {
		// Don't have enough messages. Keep reading
		return 0, nil, nil
	}
	state.handshakeRecvSequence = seq

	if h, ok := msgs[handshake.TypeCertificate].(*handshake.MessageCertificate); ok {
		state.PeerCertificates = h.Certificate
	} else if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate {
		return 0, &alert.Alert{Level: alert.Fatal, Description: alert.NoCertificate}, errInvalidCertificate
	}

	if h, ok := msgs[handshake.TypeServerKeyExchange].(*handshake.MessageServerKeyExchange); ok {
		alertPtr, err := handleServerKeyExchange(conn, state, cfg, h)
		if err != nil {
			return 0, alertPtr, err
		}
	}

	if creq, ok := msgs[handshake.TypeCertificateRequest].(*handshake.MessageCertificateRequest); ok {
		state.remoteCertRequestAlgs = creq.SignatureHashAlgorithms
		state.remoteRequestedCertificate = true
	}

	return flight5, nil, nil
}

// handleResumption finishes an abbreviated handshake on the client after the
// server echoed our session ID. It initializes the cipher suite from the
// current state (state.masterSecret is presumably restored from the session
// store before this point; that is not visible here), processes packets that
// were queued until decryption became possible, and checks the server's
// Finished message against verify_data computed over ClientHello and
// ServerHello. On success the key is written to the key log.
//
// It returns flight5b on success, (0, nil, nil) while the Finished message has
// not arrived yet, and a fatal alert (usually with an error) otherwise.
func handleResumption(
	ctx context.Context,
	c flightConn,
	state *State,
	cache *handshakeCache,
	cfg *handshakeConfig,
) (flightVal, *alert.Alert, error) {
	if err := state.initCipherSuite(); err != nil {
		return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
	}

	// Now, encrypted packets can be handled
	if err := c.handleQueuedPackets(ctx); err != nil {
		return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
	}

	_, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence+1, state.cipherSuite,
		handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, false, false},
	)
	if !ok {
		// No valid message received. Keep reading
		return 0, nil, nil
	}

	var finished *handshake.MessageFinished
	if finished, ok = msgs[handshake.TypeFinished].(*handshake.MessageFinished); !ok {
		return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil
	}
	// In the abbreviated handshake the server's Finished covers only the two hellos.
	plainText := cache.pullAndMerge(
		handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
		handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false},
	)

	expectedVerifyData, err := prf.VerifyDataServer(state.masterSecret, plainText, state.cipherSuite.HashFunc())
	if err != nil {
		return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
	}
	if !bytes.Equal(expectedVerifyData, finished.VerifyData) {
		return 0, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errVerifyDataMismatch
	}

	clientRandom := state.localRandom.MarshalFixed()
	cfg.writeKeyLog(keyLogLabelTLS12, clientRandom[:], state.masterSecret)

	return flight5b, nil, nil
}

// handleServerKeyExchange computes the client's pre-master secret from the
// server's ServerKeyExchange message and stores it in state.preMasterSecret.
//
// When a PSK callback is configured, the PSK for the server's identity hint is
// looked up (and the hint saved in state.IdentityHint). The PSK is then used
// either as a plain PSK pre-master secret, or combined with an ECDHE exchange
// for ECDHE-PSK suites, which also generates a local keypair on the server's
// named curve. Without a PSK callback, a local keypair is generated on the
// server's named curve and an ECDHE pre-master secret is derived.
// The flightConn argument is unused.
//
// NOTE(review): the server's signature over the key exchange parameters is not
// checked here; presumably that happens in a later flight. Confirm.
//
// It returns (nil, nil) on success, or a fatal alert and the causing error.
//
//nolint:cyclop
func handleServerKeyExchange(
	_ flightConn,
	state *State,
	cfg *handshakeConfig,
	keyExchangeMessage *handshake.MessageServerKeyExchange,
) (*alert.Alert, error) {
	var err error
	if state.cipherSuite == nil {
		return &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errInvalidCipherSuite
	}
	if cfg.localPSKCallback != nil { //nolint:nestif
		var psk []byte
		if psk, err = cfg.localPSKCallback(keyExchangeMessage.IdentityHint); err != nil {
			return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
		}
		state.IdentityHint = keyExchangeMessage.IdentityHint
		switch state.cipherSuite.KeyExchangeAlgorithm() {
		case types.KeyExchangeAlgorithmPsk:
			state.preMasterSecret = prf.PSKPreMasterSecret(psk)
		case (types.KeyExchangeAlgorithmEcdhe | types.KeyExchangeAlgorithmPsk):
			if state.localKeypair, err = elliptic.GenerateKeypair(keyExchangeMessage.NamedCurve); err != nil {
				return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
			}
			state.preMasterSecret, err = prf.EcdhePSKPreMasterSecret(
				psk,
				keyExchangeMessage.PublicKey,
				state.localKeypair.PrivateKey,
				state.localKeypair.Curve,
			)
			if err != nil {
				return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
			}
		default:
			return &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errInvalidCipherSuite
		}
	} else {
		if state.localKeypair, err = elliptic.GenerateKeypair(keyExchangeMessage.NamedCurve); err != nil {
			return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
		}

		if state.preMasterSecret, err = prf.PreMasterSecret(
			keyExchangeMessage.PublicKey,
			state.localKeypair.PrivateKey,
			state.localKeypair.Curve,
		); err != nil {
			return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
		}
	}

	return nil, nil //nolint:nilnil
}

// flight3Generate builds the client's second ClientHello as a single
// plaintext record. The message is filled from the handshake state
// (state.SessionID, state.cookie, state.localRandom); the cookie is presumably
// the one returned by the server's HelloVerifyRequest. Elliptic-curve
// extensions are added only when state.namedCurve is set, and
// cfg.clientHelloMessageHook, if configured, may replace the message.
func flight3Generate(
	_ flightConn,
	state *State,
	_ *handshakeCache,
	cfg *handshakeConfig,
) ([]*packet, *alert.Alert, error) {
	extensions := []extension.Extension{
		&extension.SupportedSignatureAlgorithms{
			SignatureHashAlgorithms: cfg.localSignatureSchemes,
		},
		&extension.RenegotiationInfo{
			RenegotiatedConnection: 0,
		},
	}
	if state.namedCurve != 0 {
		extensions = append(extensions, []extension.Extension{
			&extension.SupportedEllipticCurves{
				EllipticCurves: cfg.ellipticCurves,
			},
			&extension.SupportedPointFormats{
				PointFormats: []elliptic.CurvePointFormat{elliptic.CurvePointFormatUncompressed},
			},
		}...)
} if len(cfg.localSRTPProtectionProfiles) > 0 { extensions = append(extensions, &extension.UseSRTP{ ProtectionProfiles: cfg.localSRTPProtectionProfiles, }) } if cfg.extendedMasterSecret == RequestExtendedMasterSecret || cfg.extendedMasterSecret == RequireExtendedMasterSecret { extensions = append(extensions, &extension.UseExtendedMasterSecret{ Supported: true, }) } if len(cfg.serverName) > 0 { extensions = append(extensions, &extension.ServerName{ServerName: cfg.serverName}) } if len(cfg.supportedProtocols) > 0 { extensions = append(extensions, &extension.ALPN{ProtocolNameList: cfg.supportedProtocols}) } // If we sent a connection ID on the first ClientHello, send it on the // second. if state.getLocalConnectionID() != nil { extensions = append(extensions, &extension.ConnectionID{CID: state.getLocalConnectionID()}) } clientHello := &handshake.MessageClientHello{ Version: protocol.Version1_2, SessionID: state.SessionID, Cookie: state.cookie, Random: state.localRandom, CipherSuiteIDs: cipherSuiteIDs(cfg.localCipherSuites), CompressionMethods: defaultCompressionMethods(), Extensions: extensions, } var content handshake.Handshake if cfg.clientHelloMessageHook != nil { content = handshake.Handshake{Message: cfg.clientHelloMessageHook(*clientHello)} } else { content = handshake.Handshake{Message: clientHello} } return []*packet{ { record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, }, Content: &content, }, }, }, nil, nil } golang-github-pion-dtls-v3-3.0.7/flight3handler_test.go000066400000000000000000000053431507057460300230050ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "context" "math/rand" "testing" "time" "github.com/pion/dtls/v3/pkg/crypto/elliptic" dtlsnet "github.com/pion/dtls/v3/pkg/net" "github.com/pion/dtls/v3/pkg/protocol/extension" "github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/pion/dtls/v3/pkg/protocol/recordlayer" 
"github.com/pion/transport/v3/dpipe" "github.com/pion/transport/v3/test" "github.com/stretchr/testify/assert" ) // Assert that SupportedEllipticCurves is only sent when a ECC CipherSuite is available. func TestSupportedEllipticCurves(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() expectedCurves := defaultCurves var actualCurves []elliptic.Curve rand.Shuffle(len(expectedCurves), func(i, j int) { expectedCurves[i], expectedCurves[j] = expectedCurves[j], expectedCurves[i] }) clientErr := make(chan error, 1) ca, cb := dpipe.Pipe() caAnalyzer := &connWithCallback{Conn: ca} caAnalyzer.onWrite = func(in []byte) { messages, err := recordlayer.UnpackDatagram(in) assert.NoError(t, err) for i := range messages { h := &handshake.Handshake{} _ = h.Unmarshal(messages[i][recordlayer.FixedHeaderSize:]) if h.Header.Type == handshake.TypeClientHello { //nolint:nestif clientHello := &handshake.MessageClientHello{} msg, err := h.Message.Marshal() assert.NoError(t, err) assert.NoError(t, clientHello.Unmarshal(msg)) for _, e := range clientHello.Extensions { if e.TypeValue() == extension.SupportedEllipticCurvesTypeValue { if c, ok := e.(*extension.SupportedEllipticCurves); ok { actualCurves = c.EllipticCurves } } } } } } go func() { conf := &Config{ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, EllipticCurves: expectedCurves, } if client, err := testClient( ctx, dtlsnet.PacketConnFromConn(caAnalyzer), caAnalyzer.RemoteAddr(), conf, false, ); err != nil { clientErr <- err } else { clientErr <- client.Close() // nolint:errcheck,contextcheck } }() config := &Config{ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, } server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), config, true) 
assert.NoError(t, err) assert.NoError(t, server.Close()) assert.NoError(t, <-clientErr) for i := range expectedCurves { assert.Equal(t, expectedCurves[i], actualCurves[i], "curves in SupportedEllipticCurves mismatch") } } golang-github-pion-dtls-v3-3.0.7/flight4bhandler.go000066400000000000000000000116251507057460300221110ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "bytes" "context" "github.com/pion/dtls/v3/pkg/crypto/prf" "github.com/pion/dtls/v3/pkg/protocol" "github.com/pion/dtls/v3/pkg/protocol/alert" "github.com/pion/dtls/v3/pkg/protocol/extension" "github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) func flight4bParse( _ context.Context, _ flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig, ) (flightVal, *alert.Alert, error) { _, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite, handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false}, ) if !ok { // No valid message received. 
Keep reading return 0, nil, nil } var finished *handshake.MessageFinished if finished, ok = msgs[handshake.TypeFinished].(*handshake.MessageFinished); !ok { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil } plainText := cache.pullAndMerge( handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false}, handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, false, false}, ) expectedVerifyData, err := prf.VerifyDataClient(state.masterSecret, plainText, state.cipherSuite.HashFunc()) if err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } if !bytes.Equal(expectedVerifyData, finished.VerifyData) { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errVerifyDataMismatch } // Other party may re-transmit the last flight. Keep state to be flight4b. return flight4b, nil, nil } //nolint:cyclop func flight4bGenerate( _ flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig, ) ([]*packet, *alert.Alert, error) { var pkts []*packet extensions := []extension.Extension{&extension.RenegotiationInfo{ RenegotiatedConnection: 0, }} if (cfg.extendedMasterSecret == RequestExtendedMasterSecret || cfg.extendedMasterSecret == RequireExtendedMasterSecret) && state.extendedMasterSecret { extensions = append(extensions, &extension.UseExtendedMasterSecret{ Supported: true, }) } if state.getSRTPProtectionProfile() != 0 { extensions = append(extensions, &extension.UseSRTP{ ProtectionProfiles: []SRTPProtectionProfile{state.getSRTPProtectionProfile()}, MasterKeyIdentifier: cfg.localSRTPMasterKeyIdentifier, }) } selectedProto, err := extension.ALPNProtocolSelection(cfg.supportedProtocols, state.peerSupportedProtocols) if err != nil { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.NoApplicationProtocol}, err } if selectedProto != "" { extensions = 
append(extensions, &extension.ALPN{
			ProtocolNameList: []string{selectedProto},
		})
		state.NegotiatedProtocol = selectedProto
	}

	cipherSuiteID := uint16(state.cipherSuite.ID())
	var serverHello handshake.Handshake

	serverHelloMessage := &handshake.MessageServerHello{
		Version:           protocol.Version1_2,
		Random:            state.localRandom,
		SessionID:         state.SessionID,
		CipherSuiteID:     &cipherSuiteID,
		CompressionMethod: defaultCompressionMethods()[0],
		Extensions:        extensions,
	}

	if cfg.serverHelloMessageHook != nil {
		serverHello = handshake.Handshake{Message: cfg.serverHelloMessageHook(*serverHelloMessage)}
	} else {
		serverHello = handshake.Handshake{Message: serverHelloMessage}
	}

	serverHello.Header.MessageSequence = uint16(state.handshakeSendSequence) //nolint:gosec // G115

	// verify_data is computed only once and cached, so retransmissions reuse it.
	if len(state.localVerifyData) == 0 {
		plainText := cache.pullAndMerge(
			handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
		)
		// The ServerHello built above is appended to the transcript by hand
		// (presumably because it is not in the handshake cache yet).
		raw, err := serverHello.Marshal()
		if err != nil {
			return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
		}
		plainText = append(plainText, raw...)

		state.localVerifyData, err = prf.VerifyDataServer(state.masterSecret, plainText, state.cipherSuite.HashFunc())
		if err != nil {
			return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
		}
	}

	pkts = append(pkts,
		&packet{
			record: &recordlayer.RecordLayer{
				Header: recordlayer.Header{
					Version: protocol.Version1_2,
				},
				Content: &serverHello,
			},
		},
		&packet{
			record: &recordlayer.RecordLayer{
				Header: recordlayer.Header{
					Version: protocol.Version1_2,
				},
				Content: &protocol.ChangeCipherSpec{},
			},
		},
		&packet{
			record: &recordlayer.RecordLayer{
				Header: recordlayer.Header{
					Version: protocol.Version1_2,
					Epoch:   1,
				},
				Content: &handshake.Handshake{
					Message: &handshake.MessageFinished{
						VerifyData: state.localVerifyData,
					},
				},
			},
			shouldEncrypt:            true,
			resetLocalSequenceNumber: true,
		},
	)

	return pkts, nil, nil
}
golang-github-pion-dtls-v3-3.0.7/flight4handler.go000066400000000000000000000417041507057460300217500ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community
// SPDX-License-Identifier: MIT

package dtls

import (
	"context"
	"crypto"
	"crypto/rand"
	"crypto/x509"

	"github.com/pion/dtls/v3/internal/ciphersuite"
	"github.com/pion/dtls/v3/pkg/crypto/clientcertificate"
	"github.com/pion/dtls/v3/pkg/crypto/elliptic"
	"github.com/pion/dtls/v3/pkg/crypto/prf"
	"github.com/pion/dtls/v3/pkg/crypto/signaturehash"
	"github.com/pion/dtls/v3/pkg/protocol"
	"github.com/pion/dtls/v3/pkg/protocol/alert"
	"github.com/pion/dtls/v3/pkg/protocol/extension"
	"github.com/pion/dtls/v3/pkg/protocol/handshake"
	"github.com/pion/dtls/v3/pkg/protocol/recordlayer"
)

// flight4Parse processes the client's second flight on the server:
// an optional Certificate, the ClientKeyExchange, an optional
// CertificateVerify and, once the cipher suite is initialized, the client's
// encrypted Finished message.
//
// It does the following, in order:
//   - checks the CertificateVerify signature and, depending on cfg.clientAuth
//     and cfg.verifyPeerCertificate, the client's certificate chain;
//   - derives the master secret (plain, PSK, ECDHE-PSK, with or without the
//     extended master secret) and initializes the cipher suite;
//   - saves the session when a session ID was issued;
//   - enforces the client-authentication policy and cfg.verifyConnection
//     before moving to flight6.
//
// It returns (0, nil, nil) while required messages are still missing.
//
//nolint:gocognit,gocyclo,lll,cyclop,maintidx
func flight4Parse(
	ctx context.Context,
	conn flightConn,
	state *State,
	cache *handshakeCache,
	cfg *handshakeConfig,
) (flightVal, *alert.Alert, error) {
	seq, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite,
		handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, true},
		handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false},
		handshakeCachePullRule{handshake.TypeCertificateVerify, cfg.initialEpoch, true, true},
	)
	if !ok {
		// No valid message received. Keep reading
		return 0, nil, nil
	}

	// Validate type
	var clientKeyExchange *handshake.MessageClientKeyExchange
	if clientKeyExchange, ok = msgs[handshake.TypeClientKeyExchange].(*handshake.MessageClientKeyExchange); !ok {
		return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil
	}

	if h, hasCert := msgs[handshake.TypeCertificate].(*handshake.MessageCertificate); hasCert {
		state.PeerCertificates = h.Certificate
		// If the client offers its certificate, just disable session resumption.
		// Otherwise, we would have to store the certificate identification and expiry time,
		// and check whether this certificate expired, was revoked or changed.
		//
		// https://curl.se/docs/CVE-2016-5419.html
		state.SessionID = nil
	}

	//nolint:nestif
	if verify, hasVerify := msgs[handshake.TypeCertificateVerify].(*handshake.MessageCertificateVerify); hasVerify {
		if state.PeerCertificates == nil {
			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.NoCertificate}, errCertificateVerifyNoCertificate
		}

		// The CertificateVerify signature covers every handshake message up to
		// and including the client's ClientKeyExchange.
		plainText := cache.pullAndMerge(
			handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
			handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false},
			handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, false},
			handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false},
			handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false},
			handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false},
			handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, false},
			handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false},
		)

		// Verify that the pair of hash algorithm and signature is listed.
		var validSignatureScheme bool
		for _, ss := range cfg.localSignatureSchemes {
			if ss.Hash == verify.HashAlgorithm && ss.Signature == verify.SignatureAlgorithm {
				validSignatureScheme = true

				break
			}
		}
		if !validSignatureScheme {
			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errNoAvailableSignatureSchemes
		}

		if err := verifyCertificateVerify(
			plainText,
			verify.HashAlgorithm,
			verify.Signature,
			state.PeerCertificates,
		); err != nil {
			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
		}
		var chains [][]*x509.Certificate
		var err error
		var verified bool
		if cfg.clientAuth >= VerifyClientCertIfGiven {
			if chains, err = verifyClientCert(state.PeerCertificates, cfg.clientCAs); err != nil {
				return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
			}
			verified = true
		}
		if cfg.verifyPeerCertificate != nil {
			if err := cfg.verifyPeerCertificate(state.PeerCertificates, chains); err != nil {
				return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
			}
		}
		state.peerCertificatesVerified = verified
	} else if state.PeerCertificates != nil {
		// A certificate was received, but we haven't seen a CertificateVerify yet;
		// keep reading until we receive one.
		return 0, nil, nil
	}

	// Derive keys only once; a retransmitted flight finds the suite initialized.
	if !state.cipherSuite.IsInitialized() { //nolint:nestif
		serverRandom := state.localRandom.MarshalFixed()
		clientRandom := state.remoteRandom.MarshalFixed()

		var err error
		var preMasterSecret []byte
		if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypePreSharedKey {
			var psk []byte
			if psk, err = cfg.localPSKCallback(clientKeyExchange.IdentityHint); err != nil {
				return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
			}
			state.IdentityHint = clientKeyExchange.IdentityHint
			switch state.cipherSuite.KeyExchangeAlgorithm() {
			case CipherSuiteKeyExchangeAlgorithmPsk:
				preMasterSecret = prf.PSKPreMasterSecret(psk)
			case (CipherSuiteKeyExchangeAlgorithmPsk |
				CipherSuiteKeyExchangeAlgorithmEcdhe):
				if preMasterSecret, err = prf.EcdhePSKPreMasterSecret(
					psk,
					clientKeyExchange.PublicKey,
					state.localKeypair.PrivateKey,
					state.localKeypair.Curve,
				); err != nil {
					return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
				}
			default:
				return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, errInvalidCipherSuite
			}
		} else {
			preMasterSecret, err = prf.PreMasterSecret(
				clientKeyExchange.PublicKey,
				state.localKeypair.PrivateKey,
				state.localKeypair.Curve,
			)
			if err != nil {
				return 0, &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, err
			}
		}

		if state.extendedMasterSecret {
			var sessionHash []byte
			sessionHash, err = cache.sessionHash(state.cipherSuite.HashFunc(), cfg.initialEpoch)
			if err != nil {
				return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
			}

			state.masterSecret, err = prf.ExtendedMasterSecret(preMasterSecret, sessionHash, state.cipherSuite.HashFunc())
			if err != nil {
				return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
			}
		} else {
			state.masterSecret, err = prf.MasterSecret(
				preMasterSecret,
				clientRandom[:],
				serverRandom[:],
				state.cipherSuite.HashFunc(),
			)
			if err != nil {
				return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
			}
		}

		if err := state.cipherSuite.Init(state.masterSecret, clientRandom[:], serverRandom[:], false); err != nil {
			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
		}
		cfg.writeKeyLog(keyLogLabelTLS12, clientRandom[:], state.masterSecret)
	}

	if len(state.SessionID) > 0 {
		s := Session{
			ID:     state.SessionID,
			Secret: state.masterSecret,
		}
		cfg.log.Tracef("[handshake] save new session: %x", s.ID)
		if err := cfg.sessionStore.Set(state.SessionID, s); err != nil {
			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
		}
	}

	// Now, encrypted packets can be handled
	if err := conn.handleQueuedPackets(ctx); err != nil {
		return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
	}

	seq, msgs, ok = cache.fullPullMap(seq, state.cipherSuite,
		handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false},
	)
	if !ok {
		// No valid message received. Keep reading
		return 0, nil, nil
	}
	state.handshakeRecvSequence = seq

	if _, ok = msgs[handshake.TypeFinished].(*handshake.MessageFinished); !ok {
		return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil
	}

	// Anonymous suites skip the client-certificate policy checks below.
	if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeAnonymous { //nolint:nestif
		if cfg.verifyConnection != nil {
			stateClone, err := state.clone()
			if err != nil {
				return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
			}
			if err := cfg.verifyConnection(stateClone); err != nil {
				return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
			}
		}

		return flight6, nil, nil
	}

	switch cfg.clientAuth {
	case RequireAnyClientCert:
		if state.PeerCertificates == nil {
			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.NoCertificate}, errClientCertificateRequired
		}
	case VerifyClientCertIfGiven:
		if state.PeerCertificates != nil && !state.peerCertificatesVerified {
			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, errClientCertificateNotVerified
		}
	case RequireAndVerifyClientCert:
		if state.PeerCertificates == nil {
			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.NoCertificate}, errClientCertificateRequired
		}
		if !state.peerCertificatesVerified {
			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, errClientCertificateNotVerified
		}
	case NoClientCert, RequestClientCert:
		// go to flight6
	}
	if cfg.verifyConnection != nil {
		stateClone, err := state.clone()
		if err != nil {
			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
		}
		if err := cfg.verifyConnection(stateClone); err != nil {
			return 0, &alert.Alert{Level: alert.Fatal,
				Description: alert.BadCertificate}, err
		}
	}

	return flight6, nil, nil
}

// flight4Generate builds the server's reply to the client's ClientHello:
// ServerHello, then, for certificate suites, Certificate, a signed
// ServerKeyExchange and (when cfg.clientAuth > NoClientCert) a
// CertificateRequest. For other suites it sends a ServerKeyExchange only
// when an identity hint is configured or the suite uses ECDHE. The flight
// always ends with ServerHelloDone.
//
//nolint:gocognit,cyclop,maintidx
func flight4Generate(
	_ flightConn,
	state *State,
	_ *handshakeCache,
	cfg *handshakeConfig,
) ([]*packet, *alert.Alert, error) {
	extensions := []extension.Extension{}
	if (cfg.extendedMasterSecret == RequestExtendedMasterSecret ||
		cfg.extendedMasterSecret == RequireExtendedMasterSecret) && state.extendedMasterSecret {
		extensions = append(extensions, &extension.UseExtendedMasterSecret{
			Supported: true,
		})
	}
	if state.getSRTPProtectionProfile() != 0 {
		extensions = append(extensions, &extension.UseSRTP{
			ProtectionProfiles:  []SRTPProtectionProfile{state.getSRTPProtectionProfile()},
			MasterKeyIdentifier: cfg.localSRTPMasterKeyIdentifier,
		})
	}
	if state.remoteSupportsRenegotiation {
		extensions = append(extensions, &extension.RenegotiationInfo{
			RenegotiatedConnection: 0,
		})
	}
	if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate {
		extensions = append(extensions, &extension.SupportedPointFormats{
			PointFormats: []elliptic.CurvePointFormat{elliptic.CurvePointFormatUncompressed},
		})
	}

	selectedProto, err := extension.ALPNProtocolSelection(cfg.supportedProtocols, state.peerSupportedProtocols)
	if err != nil {
		return nil, &alert.Alert{Level: alert.Fatal, Description: alert.NoApplicationProtocol}, err
	}
	if selectedProto != "" {
		extensions = append(extensions, &extension.ALPN{
			ProtocolNameList: []string{selectedProto},
		})
		state.NegotiatedProtocol = selectedProto
	}

	// If we have a connection ID generator, we are willing to use connection
	// IDs. We already know whether the client supports connection IDs from
	// parsing the ClientHello, so avoid setting local connection ID if the
	// client won't send it.
if cfg.connectionIDGenerator != nil && state.remoteConnectionID != nil {
		state.setLocalConnectionID(cfg.connectionIDGenerator())
		extensions = append(extensions, &extension.ConnectionID{CID: state.getLocalConnectionID()})
	}

	var pkts []*packet
	cipherSuiteID := uint16(state.cipherSuite.ID())

	// A fresh random session ID is issued only when a session store is configured.
	if cfg.sessionStore != nil {
		state.SessionID = make([]byte, sessionLength)
		if _, err := rand.Read(state.SessionID); err != nil {
			return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
		}
	}

	serverHello := &handshake.MessageServerHello{
		Version:           protocol.Version1_2,
		Random:            state.localRandom,
		SessionID:         state.SessionID,
		CipherSuiteID:     &cipherSuiteID,
		CompressionMethod: defaultCompressionMethods()[0],
		Extensions:        extensions,
	}

	var content handshake.Handshake
	if cfg.serverHelloMessageHook != nil {
		content = handshake.Handshake{Message: cfg.serverHelloMessageHook(*serverHello)}
	} else {
		content = handshake.Handshake{Message: serverHello}
	}

	pkts = append(pkts, &packet{
		record: &recordlayer.RecordLayer{
			Header: recordlayer.Header{
				Version: protocol.Version1_2,
			},
			Content: &content,
		},
	})

	switch {
	case state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate:
		certificate, err := cfg.getCertificate(&ClientHelloInfo{
			ServerName:   state.serverName,
			CipherSuites: []ciphersuite.ID{state.cipherSuite.ID()},
			RandomBytes:  state.remoteRandom.RandomBytes,
		})
		if err != nil {
			return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, err
		}

		pkts = append(pkts, &packet{
			record: &recordlayer.RecordLayer{
				Header: recordlayer.Header{
					Version: protocol.Version1_2,
				},
				Content: &handshake.Handshake{
					Message: &handshake.MessageCertificate{
						Certificate: certificate.Certificate,
					},
				},
			},
		})

		serverRandom := state.localRandom.MarshalFixed()
		clientRandom := state.remoteRandom.MarshalFixed()

		signer, ok := certificate.PrivateKey.(crypto.Signer)
		if !ok {
			return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError},
				errInvalidPrivateKey
		}

		// Find compatible signature scheme
		signatureHashAlgo, err := signaturehash.SelectSignatureScheme(cfg.localSignatureSchemes, signer)
		if err != nil {
			return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, err
		}

		// Sign the local key-exchange public key and curve together with both randoms.
		signature, err := generateKeySignature(
			clientRandom[:],
			serverRandom[:],
			state.localKeypair.PublicKey,
			state.namedCurve,
			signer,
			signatureHashAlgo.Hash,
		)
		if err != nil {
			return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
		}
		state.localKeySignature = signature

		pkts = append(pkts, &packet{
			record: &recordlayer.RecordLayer{
				Header: recordlayer.Header{
					Version: protocol.Version1_2,
				},
				Content: &handshake.Handshake{
					Message: &handshake.MessageServerKeyExchange{
						EllipticCurveType:  elliptic.CurveTypeNamedCurve,
						NamedCurve:         state.namedCurve,
						PublicKey:          state.localKeypair.PublicKey,
						HashAlgorithm:      signatureHashAlgo.Hash,
						SignatureAlgorithm: signatureHashAlgo.Signature,
						Signature:          state.localKeySignature,
					},
				},
			},
		})

		if cfg.clientAuth > NoClientCert {
			// An empty list of certificateAuthorities signals to
			// the client that it may send any certificate in response
			// to our request. When we know the CAs we trust, then
			// we can send them down, so that the client can choose
			// an appropriate certificate to give to us.
			var certificateAuthorities [][]byte
			if cfg.clientCAs != nil {
				// nolint:staticcheck // ignoring tlsCert.RootCAs.Subjects is deprecated ERR
				// because cert does not come from SystemCertPool and it's ok if certificate
				// authorities is empty.
				certificateAuthorities = cfg.clientCAs.Subjects()
			}

			certReq := &handshake.MessageCertificateRequest{
				CertificateTypes:            []clientcertificate.Type{clientcertificate.RSASign, clientcertificate.ECDSASign},
				SignatureHashAlgorithms:     cfg.localSignatureSchemes,
				CertificateAuthoritiesNames: certificateAuthorities,
			}

			var content handshake.Handshake
			if cfg.certificateRequestMessageHook != nil {
				content = handshake.Handshake{Message: cfg.certificateRequestMessageHook(*certReq)}
			} else {
				content = handshake.Handshake{Message: certReq}
			}

			pkts = append(pkts, &packet{
				record: &recordlayer.RecordLayer{
					Header: recordlayer.Header{
						Version: protocol.Version1_2,
					},
					Content: &content,
				},
			})
		}
	case cfg.localPSKIdentityHint != nil || state.cipherSuite.KeyExchangeAlgorithm().Has(CipherSuiteKeyExchangeAlgorithmEcdhe):
		// To help the client in selecting which identity to use, the server
		// can provide a "PSK identity hint" in the ServerKeyExchange message.
		// If no hint is provided and cipher suite doesn't use elliptic curve,
		// the ServerKeyExchange message is omitted.
		//
		// https://tools.ietf.org/html/rfc4279#section-2
		srvExchange := &handshake.MessageServerKeyExchange{
			IdentityHint: cfg.localPSKIdentityHint,
		}
		if state.cipherSuite.KeyExchangeAlgorithm().Has(CipherSuiteKeyExchangeAlgorithmEcdhe) {
			srvExchange.EllipticCurveType = elliptic.CurveTypeNamedCurve
			srvExchange.NamedCurve = state.namedCurve
			srvExchange.PublicKey = state.localKeypair.PublicKey
		}
		pkts = append(pkts, &packet{
			record: &recordlayer.RecordLayer{
				Header: recordlayer.Header{
					Version: protocol.Version1_2,
				},
				Content: &handshake.Handshake{
					Message: srvExchange,
				},
			},
		})
	}

	pkts = append(pkts, &packet{
		record: &recordlayer.RecordLayer{
			Header: recordlayer.Header{
				Version: protocol.Version1_2,
			},
			Content: &handshake.Handshake{
				Message: &handshake.MessageServerHelloDone{},
			},
		},
	})

	return pkts, nil, nil
}
golang-github-pion-dtls-v3-3.0.7/flight4handler_test.go000066400000000000000000000150611507057460300230040ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community
// SPDX-License-Identifier: MIT

package dtls

import (
	"context"
	"crypto/tls"
	"testing"
	"time"

	"github.com/pion/dtls/v3/internal/ciphersuite"
	"github.com/pion/dtls/v3/pkg/crypto/elliptic"
	"github.com/pion/dtls/v3/pkg/crypto/selfsign"
	"github.com/pion/dtls/v3/pkg/crypto/signaturehash"
	"github.com/pion/dtls/v3/pkg/protocol/alert"
	"github.com/pion/dtls/v3/pkg/protocol/handshake"
	"github.com/pion/transport/v3/test"
	"github.com/stretchr/testify/assert"
)

// flight4TestMockFlightConn is a no-op flightConn. Every method returns nil or
// does nothing, so flight4 handlers can be driven without a real connection.
type flight4TestMockFlightConn struct{}

func (f *flight4TestMockFlightConn) notify(context.Context, alert.Level, alert.Description) error {
	return nil
}

func (f *flight4TestMockFlightConn) writePackets(context.Context, []*packet) error { return nil }

func (f *flight4TestMockFlightConn) recvHandshake() <-chan recvHandshakeState { return nil }

func (f *flight4TestMockFlightConn) setLocalEpoch(uint16) {}

func (f *flight4TestMockFlightConn) handleQueuedPackets(context.Context) error { return nil }

func (f *flight4TestMockFlightConn) sessionKey() []byte { return nil }

// flight4TestMockCipherSuite embeds a real cipher suite but fails the test if
// IsInitialized is ever called.
type flight4TestMockCipherSuite struct {
	ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256

	t *testing.T
}

func (f *flight4TestMockCipherSuite) IsInitialized() bool {
	assert.Fail(f.t, "IsInitialized called with Certificate but not CertificateVerify")

	return true
}

// Assert that if a Client sends a certificate they
// must also send a CertificateVerify message.
// The flight4handler must not interact with the CipherSuite
// if the CertificateVerify is missing.
func TestFlight4_Process_CertificateVerify(t *testing.T) {
	// Limit runtime in case of deadlocks
	lim := test.TimeOut(5 * time.Second)
	defer lim.Stop()

	// Check for leaking routines
	report := test.CheckRoutines(t)
	defer report()

	mockConn := &flight4TestMockFlightConn{}
	state := &State{
		cipherSuite: &flight4TestMockCipherSuite{t: t},
	}
	cache := newHandshakeCache()
	cfg := &handshakeConfig{}

	// Raw Certificate handshake message (type byte 0x0b) carrying a single
	// self-signed certificate.
	rawCertificate := []byte{
		0x0b, 0x00, 0x01, 0x9b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x9b, 0x00, 0x01, 0x98, 0x00,
		0x01, 0x95, 0x30, 0x82, 0x01, 0x91, 0x30, 0x82, 0x01, 0x38, 0xa0, 0x03, 0x02, 0x01, 0x02, 0x02,
		0x11, 0x01, 0x65, 0x03, 0x3f, 0x4d, 0x0b, 0x9a, 0x62, 0x91, 0xdb, 0x4d, 0x28, 0x2c, 0x1f, 0xd6,
		0x73, 0x32, 0x30, 0x0a, 0x06, 0x08, 0x2a, 0x86, 0x48, 0xce, 0x3d, 0x04, 0x03, 0x02, 0x30, 0x00,
		0x30, 0x1e, 0x17, 0x0d, 0x32, 0x32, 0x30, 0x35, 0x31, 0x35, 0x31, 0x38, 0x34, 0x33, 0x35, 0x35,
		0x5a, 0x17, 0x0d, 0x32, 0x32, 0x30, 0x36, 0x31, 0x35, 0x31, 0x38, 0x34, 0x33, 0x35, 0x35, 0x5a,
		0x30, 0x00, 0x30, 0x59, 0x30, 0x13, 0x06, 0x07, 0x2a, 0x86, 0x48, 0xce, 0x3d, 0x02, 0x01, 0x06,
		0x08, 0x2a, 0x86, 0x48, 0xce, 0x3d, 0x03, 0x01, 0x07, 0x03, 0x42, 0x00, 0x04, 0xc3, 0xb7, 0x13,
		0x1a, 0x0a, 0xfc, 0xd0, 0x82, 0xf8, 0x94, 0x5e, 0xc0, 0x77, 0x07, 0x81, 0x28, 0xc9, 0xcb, 0x08,
		0x84, 0x50, 0x6b, 0xf0, 0x22, 0xe8, 0x79, 0xb9, 0x15, 0x33, 0xc4, 0x56, 0xa1, 0xd3, 0x1b, 0x24,
		0xe3, 0x61, 0xbd, 0x4d, 0x65, 0x80, 0x6b, 0x5d, 0x96, 0x48, 0xa2, 0x44, 0x9e, 0xce, 0xe8, 0x65,
		0xd6, 0x3c, 0xe0, 0x9b, 0x6b, 0xa1, 0x36, 0x34, 0xb2, 0x39, 0xe2, 0x03, 0x00, 0xa3, 0x81, 0x92,
		0x30, 0x81, 0x8f, 0x30, 0x0e, 0x06, 0x03, 0x55, 0x1d, 0x0f, 0x01, 0x01, 0xff, 0x04, 0x04, 0x03,
		0x02, 0x02, 0xa4, 0x30, 0x1d, 0x06, 0x03, 0x55, 0x1d, 0x25, 0x04, 0x16, 0x30, 0x14, 0x06, 0x08,
		0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x02, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07,
		0x03, 0x01, 0x30, 0x0f, 0x06, 0x03, 0x55, 0x1d, 0x13, 0x01, 0x01, 0xff, 0x04, 0x05, 0x30, 0x03,
		0x01, 0x01, 0xff, 0x30, 0x1d, 0x06, 0x03, 0x55, 0x1d, 0x0e, 0x04, 0x16, 0x04, 0x14, 0xb1, 0x1a,
		0xe3, 0xeb, 0x6f, 0x7c, 0xc3, 0x8f, 0xba, 0x6f, 0x1c, 0xe8, 0xf0, 0x23, 0x08, 0x50, 0x8d, 0x3c,
		0xea, 0x31, 0x30, 0x2e, 0x06, 0x03, 0x55, 0x1d, 0x11, 0x01, 0x01, 0xff, 0x04, 0x24, 0x30, 0x22,
		0x82, 0x20, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30,
		0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30,
		0x30, 0x30, 0x30, 0x0a, 0x06, 0x08, 0x2a, 0x86, 0x48, 0xce, 0x3d, 0x04, 0x03, 0x02, 0x03, 0x47,
		0x00, 0x30, 0x44, 0x02, 0x20, 0x06, 0x31, 0x43, 0xac, 0x03, 0x45, 0x79, 0x3c, 0xd7, 0x5f, 0x6e,
		0x6a, 0xf8, 0x0e, 0xfd, 0x35, 0x49, 0xee, 0x1b, 0xbc, 0x47, 0xce, 0xe3, 0x39, 0xec, 0xe4, 0x62,
		0xe1, 0x30, 0x1a, 0xa1, 0x89, 0x02, 0x20, 0x35, 0xcd, 0x7a, 0x15, 0x68, 0x09, 0x50, 0x49, 0x9e,
		0x3e, 0x05, 0xd7, 0xc2, 0x69, 0x3f, 0x9c, 0x0c, 0x98, 0x92, 0x65, 0xec, 0xae, 0x44, 0xfe, 0xe5,
		0x68, 0xb8, 0x09, 0x78, 0x7f, 0x6b, 0x77,
	}
	// Raw ClientKeyExchange handshake message (type byte 0x10) with a 32-byte
	// public key.
	rawClientKeyExchange := []byte{
		0x10, 0x00, 0x00, 0x21, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x21, 0x20, 0x96, 0xed, 0x0c,
		0xee, 0xf3, 0x11, 0xb1, 0x9d, 0x8b, 0x1c, 0x02, 0x7f, 0x06, 0x7c, 0x57, 0x7a, 0x14, 0xa6, 0x41,
		0xde, 0x63, 0x57, 0x9e, 0xcd, 0x34, 0x54, 0xba, 0x37, 0x4d, 0x34, 0x15, 0x18,
	}

	// Push Certificate and ClientKeyExchange but deliberately no CertificateVerify.
	cache.push(rawCertificate, 0, 0, handshake.TypeCertificate, true)
	cache.push(rawClientKeyExchange, 0, 1, handshake.TypeClientKeyExchange, true)

	_, _, err := flight4Parse(context.TODO(), mockConn, state, cache, cfg)
	assert.NoError(t, err)
}
func TestFlight4_CertificateRequestHook(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(5 * time.Second) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() localKeypair, err := elliptic.GenerateKeypair(elliptic.P256) assert.NoError(t, err) mockConn := &flight4TestMockFlightConn{} state := &State{ cipherSuite: &flight4TestMockCipherSuite{t: t}, localKeypair: localKeypair, } cert, err := selfsign.GenerateSelfSignedWithDNS("localhost") assert.NoError(t, err) cfg := &handshakeConfig{ localCertificates: []tls.Certificate{cert}, localSignatureSchemes: signaturehash.Algorithms(), clientAuth: 1, certificateRequestMessageHook: func(mcr handshake.MessageCertificateRequest) handshake.Message { mcr.SignatureHashAlgorithms = []signaturehash.Algorithm{} return &mcr }, } pkts, _, err := flight4Generate(mockConn, state, nil, cfg) assert.NoError(t, err) for _, p := range pkts { if h, ok := p.record.Content.(*handshake.Handshake); ok { //nolint:nestif if h.Message.Type() == handshake.TypeCertificateRequest { mcr := &handshake.MessageCertificateRequest{} msg, err := h.Message.Marshal() assert.NoError(t, err) assert.NoError(t, mcr.Unmarshal(msg)) if len(mcr.SignatureHashAlgorithms) == 0 { return } } } } assert.Fail(t, "hook failed to modify SignatureHashAlgorithms") } golang-github-pion-dtls-v3-3.0.7/flight5bhandler.go000066400000000000000000000045401507057460300221100ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "context" "github.com/pion/dtls/v3/pkg/crypto/prf" "github.com/pion/dtls/v3/pkg/protocol" "github.com/pion/dtls/v3/pkg/protocol/alert" "github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) func flight5bParse( _ context.Context, _ flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig, ) (flightVal, *alert.Alert, error) { _, msgs, ok := 
cache.fullPullMap(state.handshakeRecvSequence-1, state.cipherSuite, handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, false, false}, ) if !ok { // No valid message received. Keep reading return 0, nil, nil } if _, ok = msgs[handshake.TypeFinished].(*handshake.MessageFinished); !ok { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil } // Other party may re-transmit the last flight. Keep state to be flight5b. return flight5b, nil, nil } func flight5bGenerate( _ flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig, ) ([]*packet, *alert.Alert, error) { //nolint:gocognit var pkts []*packet pkts = append(pkts, &packet{ record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, }, Content: &protocol.ChangeCipherSpec{}, }, }) if len(state.localVerifyData) == 0 { plainText := cache.pullAndMerge( handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false}, handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, false, false}, ) var err error state.localVerifyData, err = prf.VerifyDataClient(state.masterSecret, plainText, state.cipherSuite.HashFunc()) if err != nil { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } } pkts = append(pkts, &packet{ record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, Epoch: 1, }, Content: &handshake.Handshake{ Message: &handshake.MessageFinished{ VerifyData: state.localVerifyData, }, }, }, shouldEncrypt: true, resetLocalSequenceNumber: true, }) return pkts, nil, nil } golang-github-pion-dtls-v3-3.0.7/flight5handler.go000066400000000000000000000337321507057460300217530ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "bytes" "context" "crypto" "crypto/x509" 
"github.com/pion/dtls/v3/pkg/crypto/prf" "github.com/pion/dtls/v3/pkg/crypto/signaturehash" "github.com/pion/dtls/v3/pkg/protocol" "github.com/pion/dtls/v3/pkg/protocol/alert" "github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) func flight5Parse( _ context.Context, conn flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig, ) (flightVal, *alert.Alert, error) { _, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite, handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, false, false}, ) if !ok { // No valid message received. Keep reading return 0, nil, nil } var finished *handshake.MessageFinished if finished, ok = msgs[handshake.TypeFinished].(*handshake.MessageFinished); !ok { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil } plainText := cache.pullAndMerge( handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false}, handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, false}, handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false}, handshakeCachePullRule{handshake.TypeCertificateVerify, cfg.initialEpoch, true, false}, handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false}, ) expectedVerifyData, err := prf.VerifyDataServer(state.masterSecret, plainText, state.cipherSuite.HashFunc()) if err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } if 
!bytes.Equal(expectedVerifyData, finished.VerifyData) { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errVerifyDataMismatch } if len(state.SessionID) > 0 { s := Session{ ID: state.SessionID, Secret: state.masterSecret, } cfg.log.Tracef("[handshake] save new session: %x", s.ID) if err := cfg.sessionStore.Set(conn.sessionKey(), s); err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } } return flight5, nil, nil } //nolint:gocognit,cyclop,maintidx func flight5Generate( conn flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig, ) ([]*packet, *alert.Alert, error) { var signer crypto.Signer var pkts []*packet if state.remoteRequestedCertificate { //nolint:nestif _, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence-2, state.cipherSuite, handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false}) if !ok { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errClientCertificateRequired } reqInfo := CertificateRequestInfo{} if r, ok2 := msgs[handshake.TypeCertificateRequest].(*handshake.MessageCertificateRequest); ok2 { reqInfo.AcceptableCAs = r.CertificateAuthoritiesNames } else { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errClientCertificateRequired } certificate, err := cfg.getClientCertificate(&reqInfo) if err != nil { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, err } if certificate == nil { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errNotAcceptableCertificateChain } if certificate.Certificate != nil { signer, ok = certificate.PrivateKey.(crypto.Signer) if !ok { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errInvalidPrivateKey } } pkts = append(pkts, &packet{ record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, }, 
Content: &handshake.Handshake{ Message: &handshake.MessageCertificate{ Certificate: certificate.Certificate, }, }, }, }) } clientKeyExchange := &handshake.MessageClientKeyExchange{} if cfg.localPSKCallback == nil { clientKeyExchange.PublicKey = state.localKeypair.PublicKey } else { clientKeyExchange.IdentityHint = cfg.localPSKIdentityHint } if state != nil && state.localKeypair != nil && len(state.localKeypair.PublicKey) > 0 { clientKeyExchange.PublicKey = state.localKeypair.PublicKey } pkts = append(pkts, &packet{ record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, }, Content: &handshake.Handshake{ Message: clientKeyExchange, }, }, }) serverKeyExchangeData := cache.pullAndMerge( handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false}, ) serverKeyExchange := &handshake.MessageServerKeyExchange{} // handshakeMessageServerKeyExchange is optional for PSK if len(serverKeyExchangeData) == 0 { alertPtr, err := handleServerKeyExchange(conn, state, cfg, &handshake.MessageServerKeyExchange{}) if err != nil { return nil, alertPtr, err } } else { rawHandshake := &handshake.Handshake{ KeyExchangeAlgorithm: state.cipherSuite.KeyExchangeAlgorithm(), } err := rawHandshake.Unmarshal(serverKeyExchangeData) if err != nil { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, err } switch h := rawHandshake.Message.(type) { case *handshake.MessageServerKeyExchange: serverKeyExchange = h default: return nil, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, errInvalidContentType } } // Append not-yet-sent packets merged := []byte{} seqPred := uint16(state.handshakeSendSequence) //nolint:gosec // G115 for _, p := range pkts { h, ok := p.record.Content.(*handshake.Handshake) if !ok { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, errInvalidContentType } h.Header.MessageSequence = seqPred seqPred++ raw, err := h.Marshal() if err != 
nil { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } merged = append(merged, raw...) } if alertPtr, err := initializeCipherSuite(state, cache, cfg, serverKeyExchange, merged); err != nil { return nil, alertPtr, err } // If the client has sent a certificate with signing ability, a digitally-signed // CertificateVerify message is sent to explicitly verify possession of the // private key in the certificate. if state.remoteRequestedCertificate && signer != nil { plainText := append(cache.pullAndMerge( handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false}, handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, false}, handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false}, ), merged...) 
// Find compatible signature scheme signatureHashAlgo, err := signaturehash.SelectSignatureScheme(state.remoteCertRequestAlgs, signer) if err != nil { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, err } certVerify, err := generateCertificateVerify(plainText, signer, signatureHashAlgo.Hash) if err != nil { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } state.localCertificatesVerify = certVerify pkt := &packet{ record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, }, Content: &handshake.Handshake{ Message: &handshake.MessageCertificateVerify{ HashAlgorithm: signatureHashAlgo.Hash, SignatureAlgorithm: signatureHashAlgo.Signature, Signature: state.localCertificatesVerify, }, }, }, } pkts = append(pkts, pkt) h, ok := pkt.record.Content.(*handshake.Handshake) if !ok { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, errInvalidContentType } h.Header.MessageSequence = seqPred // seqPred++ // this is the last use of seqPred raw, err := h.Marshal() if err != nil { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } merged = append(merged, raw...) 
} pkts = append(pkts, &packet{ record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, }, Content: &protocol.ChangeCipherSpec{}, }, }) if len(state.localVerifyData) == 0 { plainText := cache.pullAndMerge( handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false}, handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, false}, handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false}, handshakeCachePullRule{handshake.TypeCertificateVerify, cfg.initialEpoch, true, false}, handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false}, ) var err error state.localVerifyData, err = prf.VerifyDataClient( state.masterSecret, append(plainText, merged...), state.cipherSuite.HashFunc(), ) if err != nil { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } } pkts = append(pkts, &packet{ record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, Epoch: 1, }, Content: &handshake.Handshake{ Message: &handshake.MessageFinished{ VerifyData: state.localVerifyData, }, }, }, shouldWrapCID: len(state.remoteConnectionID) > 0, shouldEncrypt: true, resetLocalSequenceNumber: true, }) return pkts, nil, nil } //nolint:gocognit,cyclop func initializeCipherSuite( state *State, cache *handshakeCache, cfg *handshakeConfig, handshakeKeyExchange *handshake.MessageServerKeyExchange, sendingPlainText []byte, ) (*alert.Alert, error) { if state.cipherSuite.IsInitialized() { return nil, 
nil //nolint } clientRandom := state.localRandom.MarshalFixed() serverRandom := state.remoteRandom.MarshalFixed() var err error if state.extendedMasterSecret { var sessionHash []byte sessionHash, err = cache.sessionHash(state.cipherSuite.HashFunc(), cfg.initialEpoch, sendingPlainText) if err != nil { return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } state.masterSecret, err = prf.ExtendedMasterSecret(state.preMasterSecret, sessionHash, state.cipherSuite.HashFunc()) if err != nil { return &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, err } } else { state.masterSecret, err = prf.MasterSecret( state.preMasterSecret, clientRandom[:], serverRandom[:], state.cipherSuite.HashFunc(), ) if err != nil { return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } } if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate { //nolint:nestif // Verify that the pair of hash algorithm and signiture is listed. var validSignatureScheme bool for _, ss := range cfg.localSignatureSchemes { if ss.Hash == handshakeKeyExchange.HashAlgorithm && ss.Signature == handshakeKeyExchange.SignatureAlgorithm { validSignatureScheme = true break } } if !validSignatureScheme { return &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errNoAvailableSignatureSchemes } expectedMsg := valueKeyMessage( clientRandom[:], serverRandom[:], handshakeKeyExchange.PublicKey, handshakeKeyExchange.NamedCurve, ) if err = verifyKeySignature( expectedMsg, handshakeKeyExchange. 
Signature, handshakeKeyExchange.HashAlgorithm, state.PeerCertificates, ); err != nil { return &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err } var chains [][]*x509.Certificate if !cfg.insecureSkipVerify { if chains, err = verifyServerCert(state.PeerCertificates, cfg.rootCAs, cfg.serverName); err != nil { return &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err } } if cfg.verifyPeerCertificate != nil { if err = cfg.verifyPeerCertificate(state.PeerCertificates, chains); err != nil { return &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err } } } if cfg.verifyConnection != nil { stateClone, errC := state.clone() if errC != nil { return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, errC } if errC = cfg.verifyConnection(stateClone); errC != nil { return &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, errC } } if err = state.cipherSuite.Init(state.masterSecret, clientRandom[:], serverRandom[:], true); err != nil { return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } cfg.writeKeyLog(keyLogLabelTLS12, clientRandom[:], state.masterSecret) return nil, nil //nolint } golang-github-pion-dtls-v3-3.0.7/flight6handler.go000066400000000000000000000057771507057460300217640ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "context" "github.com/pion/dtls/v3/pkg/crypto/prf" "github.com/pion/dtls/v3/pkg/protocol" "github.com/pion/dtls/v3/pkg/protocol/alert" "github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) func flight6Parse( _ context.Context, _ flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig, ) (flightVal, *alert.Alert, error) { _, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence-1, state.cipherSuite, handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, 
false}, ) if !ok { // No valid message received. Keep reading return 0, nil, nil } if _, ok = msgs[handshake.TypeFinished].(*handshake.MessageFinished); !ok { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil } // Other party may re-transmit the last flight. Keep state to be flight6. return flight6, nil, nil } func flight6Generate( _ flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig, ) ([]*packet, *alert.Alert, error) { var pkts []*packet pkts = append(pkts, &packet{ record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, }, Content: &protocol.ChangeCipherSpec{}, }, }) if len(state.localVerifyData) == 0 { plainText := cache.pullAndMerge( handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false}, handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, false}, handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false}, handshakeCachePullRule{handshake.TypeCertificateVerify, cfg.initialEpoch, true, false}, handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false}, ) var err error state.localVerifyData, err = prf.VerifyDataServer(state.masterSecret, plainText, state.cipherSuite.HashFunc()) if err != nil { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } } pkts = append(pkts, &packet{ record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, Epoch: 1, }, Content: &handshake.Handshake{ Message: 
&handshake.MessageFinished{ VerifyData: state.localVerifyData, }, }, }, shouldWrapCID: len(state.remoteConnectionID) > 0, shouldEncrypt: true, resetLocalSequenceNumber: true, }, ) return pkts, nil, nil } golang-github-pion-dtls-v3-3.0.7/flighthandler.go000066400000000000000000000033611507057460300216610ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "context" "github.com/pion/dtls/v3/pkg/protocol/alert" ) // Parse received handshakes and return next flightVal. type flightParser func( context.Context, flightConn, *State, *handshakeCache, *handshakeConfig, ) (flightVal, *alert.Alert, error) // Generate flights. type flightGenerator func(flightConn, *State, *handshakeCache, *handshakeConfig) ([]*packet, *alert.Alert, error) func (f flightVal) getFlightParser() (flightParser, error) { //nolint:cyclop switch f { case flight0: return flight0Parse, nil case flight1: return flight1Parse, nil case flight2: return flight2Parse, nil case flight3: return flight3Parse, nil case flight4: return flight4Parse, nil case flight4b: return flight4bParse, nil case flight5: return flight5Parse, nil case flight5b: return flight5bParse, nil case flight6: return flight6Parse, nil default: return nil, errInvalidFlight } } func (f flightVal) getFlightGenerator() (gen flightGenerator, retransmit bool, err error) { //nolint:cyclop switch f { case flight0: return flight0Generate, true, nil case flight1: return flight1Generate, true, nil case flight2: // https://tools.ietf.org/html/rfc6347#section-3.2.1 // HelloVerifyRequests must not be retransmitted. 
return flight2Generate, false, nil case flight3: return flight3Generate, true, nil case flight4: return flight4Generate, true, nil case flight4b: return flight4bGenerate, true, nil case flight5: return flight5Generate, true, nil case flight5b: return flight5bGenerate, true, nil case flight6: return flight6Generate, true, nil default: return nil, false, errInvalidFlight } } golang-github-pion-dtls-v3-3.0.7/fragment_buffer.go000066400000000000000000000075111507057460300222030ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "github.com/pion/dtls/v3/pkg/protocol" "github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) // 2 megabytes. const fragmentBufferMaxSize = 2000000 type fragment struct { recordLayerHeader recordlayer.Header handshakeHeader handshake.Header data []byte } type fragmentBuffer struct { // map of MessageSequenceNumbers that hold slices of fragments cache map[uint16][]*fragment currentMessageSequenceNumber uint16 } func newFragmentBuffer() *fragmentBuffer { return &fragmentBuffer{cache: map[uint16][]*fragment{}} } // current total size of buffer. func (f *fragmentBuffer) size() int { size := 0 for i := range f.cache { for j := range f.cache[i] { size += len(f.cache[i][j].data) } } return size } // Attempts to push a DTLS packet to the fragmentBuffer // when it returns true it means the fragmentBuffer has inserted and the buffer shouldn't be handled // when an error returns it is fatal, and the DTLS connection should be stopped. 
func (f *fragmentBuffer) push(buf []byte) (isHandshake, isRetransmit bool, err error) { if f.size()+len(buf) >= fragmentBufferMaxSize { return false, false, errFragmentBufferOverflow } frag := new(fragment) if err := frag.recordLayerHeader.Unmarshal(buf); err != nil { return false, false, err } // fragment isn't a handshake, we don't need to handle it if frag.recordLayerHeader.ContentType != protocol.ContentTypeHandshake { return false, false, nil } for buf = buf[recordlayer.FixedHeaderSize:]; len(buf) != 0; frag = new(fragment) { if err := frag.handshakeHeader.Unmarshal(buf); err != nil { return false, false, err } // Fragment is a retransmission. We have already assembled it before successfully isRetransmit = frag.handshakeHeader.FragmentOffset == 0 && frag.handshakeHeader.MessageSequence < f.currentMessageSequenceNumber if _, ok := f.cache[frag.handshakeHeader.MessageSequence]; !ok { f.cache[frag.handshakeHeader.MessageSequence] = []*fragment{} } // end index should be the length of handshake header but if the handshake // was fragmented, we should keep them all end := int(handshake.HeaderLength + frag.handshakeHeader.Length) if size := len(buf); end > size { end = size } // Discard all headers, when rebuilding the packet we will re-build frag.data = append([]byte{}, buf[handshake.HeaderLength:end]...) 
f.cache[frag.handshakeHeader.MessageSequence] = append(f.cache[frag.handshakeHeader.MessageSequence], frag) buf = buf[end:] } return true, isRetransmit, nil } func (f *fragmentBuffer) pop() (content []byte, epoch uint16) { frags, ok := f.cache[f.currentMessageSequenceNumber] if !ok { return nil, 0 } // Go doesn't support recursive lambdas var appendMessage func(targetOffset uint32) bool rawMessage := []byte{} appendMessage = func(targetOffset uint32) bool { for _, f := range frags { if f.handshakeHeader.FragmentOffset == targetOffset { fragmentEnd := (f.handshakeHeader.FragmentOffset + f.handshakeHeader.FragmentLength) if fragmentEnd != f.handshakeHeader.Length && f.handshakeHeader.FragmentLength != 0 { if !appendMessage(fragmentEnd) { return false } } rawMessage = append(f.data, rawMessage...) return true } } return false } // Recursively collect up if !appendMessage(0) { return nil, 0 } firstHeader := frags[0].handshakeHeader firstHeader.FragmentOffset = 0 firstHeader.FragmentLength = firstHeader.Length rawHeader, err := firstHeader.Marshal() if err != nil { return nil, 0 } messageEpoch := frags[0].recordLayerHeader.Epoch delete(f.cache, f.currentMessageSequenceNumber) f.currentMessageSequenceNumber++ return append(rawHeader, rawMessage...), messageEpoch } golang-github-pion-dtls-v3-3.0.7/fragment_buffer_test.go000066400000000000000000000127511507057460300232440ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "testing" "github.com/stretchr/testify/assert" ) func TestFragmentBuffer(t *testing.T) { for _, test := range []struct { Name string In [][]byte Expected [][]byte Epoch uint16 }{ { Name: "Single Fragment", In: [][]byte{ { 0x16, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0F, 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0xfe, 0xff, 0x00, }, }, Expected: [][]byte{ {0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 
0x00, 0x03, 0xfe, 0xff, 0x00}, }, Epoch: 0, }, { Name: "Single Fragment Epoch 3", In: [][]byte{ { 0x16, 0xfe, 0xff, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0F, 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0xfe, 0xff, 0x00, }, }, Expected: [][]byte{ {0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0xfe, 0xff, 0x00}, }, Epoch: 3, }, { Name: "Multiple Fragments", In: [][]byte{ { 0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x81, 0x0b, 0x00, 0x00, 0x0F, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x01, 0x02, 0x03, 0x04, }, { 0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x81, 0x0b, 0x00, 0x00, 0x0F, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x05, 0x05, 0x06, 0x07, 0x08, 0x09, }, { 0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x81, 0x0b, 0x00, 0x00, 0x0F, 0x00, 0x00, 0x00, 0x00, 0x0A, 0x00, 0x00, 0x05, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, }, }, Expected: [][]byte{ { 0x0b, 0x00, 0x00, 0x0f, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0f, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, }, }, Epoch: 0, }, { Name: "Multiple Unordered Fragments", In: [][]byte{ { 0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x81, 0x0b, 0x00, 0x00, 0x0F, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x01, 0x02, 0x03, 0x04, }, { 0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x81, 0x0b, 0x00, 0x00, 0x0F, 0x00, 0x00, 0x00, 0x00, 0x0A, 0x00, 0x00, 0x05, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, }, { 0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x81, 0x0b, 0x00, 0x00, 0x0F, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x05, 0x05, 0x06, 0x07, 0x08, 0x09, }, }, Expected: [][]byte{ { 0x0b, 0x00, 0x00, 0x0f, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0f, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 
0x0e, }, }, Epoch: 0, }, { Name: "Multiple Handshakes in Single Fragment", In: [][]byte{ { 0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x30, /* record header */ 0x03, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0xfe, 0xff, 0x01, 0x01, /*handshake msg 1*/ 0x03, 0x00, 0x00, 0x04, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0xfe, 0xff, 0x01, 0x01, /*handshake msg 2*/ 0x03, 0x00, 0x00, 0x04, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0xfe, 0xff, 0x01, 0x01, /*handshake msg 3*/ }, }, Expected: [][]byte{ {0x03, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0xfe, 0xff, 0x01, 0x01}, {0x03, 0x00, 0x00, 0x04, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0xfe, 0xff, 0x01, 0x01}, {0x03, 0x00, 0x00, 0x04, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0xfe, 0xff, 0x01, 0x01}, }, Epoch: 0, }, // Assert that a zero length fragment doesn't cause the fragmentBuffer to enter an infinite loop { Name: "Zero Length Fragment", In: [][]byte{ { 0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, }, }, Expected: [][]byte{ {0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00}, }, Epoch: 0, }, } { fragmentBuffer := newFragmentBuffer() for _, frag := range test.In { status, _, err := fragmentBuffer.push(frag) assert.NoError(t, err) assert.Truef(t, status, "fragmentBuffer didn't accept fragments for '%s'", test.Name) } for _, expected := range test.Expected { out, epoch := fragmentBuffer.pop() assert.Equalf(t, expected, out, "fragmentBuffer '%s' pop should return expected output", test.Name) assert.Equalf(t, test.Epoch, epoch, "fragmentBuffer returend wrong epoch") } frag, _ := fragmentBuffer.pop() assert.Nilf(t, frag, "fragmentBuffer '%s' pop should return nil when no more fragments are available", test.Name) } } func TestFragmentBuffer_Overflow(t *testing.T) { fragmentBuffer := 
newFragmentBuffer() // Push a buffer that doesn't exceed size limits _, _, err := fragmentBuffer.push([]byte{ 0x16, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0F, 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0xfe, 0xff, 0x00, }) assert.NoError(t, err) // Allocate a buffer that exceeds cache size largeBuffer := make([]byte, fragmentBufferMaxSize) _, _, err = fragmentBuffer.push(largeBuffer) assert.ErrorIs(t, err, errFragmentBufferOverflow, "Pushing a large buffer should return an overflow error") } golang-github-pion-dtls-v3-3.0.7/fuzz_test.go000066400000000000000000000010631507057460300211000ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "os" "testing" ) func FuzzUnmarshalBinary(f *testing.F) { TestResumeClient, err := os.ReadFile("testdata/seed/TestResumeClient.raw") if err != nil { return } f.Add(TestResumeClient) TestResumeServer, err := os.ReadFile("testdata/seed/TestResumeServer.raw") if err != nil { return } f.Add(TestResumeServer) f.Fuzz(func(_ *testing.T, data []byte) { deserialized := &State{} _ = deserialized.UnmarshalBinary(data) }) } golang-github-pion-dtls-v3-3.0.7/go.mod000066400000000000000000000005511507057460300176230ustar00rootroot00000000000000module github.com/pion/dtls/v3 require ( github.com/pion/logging v0.2.4 github.com/pion/transport/v3 v3.0.7 github.com/stretchr/testify v1.10.0 golang.org/x/crypto v0.32.0 golang.org/x/net v0.34.0 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) go 1.20 golang-github-pion-dtls-v3-3.0.7/go.sum000066400000000000000000000027771507057460300176640ustar00rootroot00000000000000github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/pion/logging v0.2.4 
h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8= github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so= github.com/pion/transport/v3 v3.0.7 h1:iRbMH05BzSNwhILHoBoAPxoB9xQgOaJk+591KC9P1o0= github.com/pion/transport/v3 v3.0.7/go.mod h1:YleKiTZ4vqNxVwh77Z0zytYi7rXHl7j6uPLGhhz9rwo= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc= golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc= golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0= golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= golang-github-pion-dtls-v3-3.0.7/handshake_cache.go000066400000000000000000000114161507057460300221170ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "sync" "github.com/pion/dtls/v3/pkg/crypto/prf" "github.com/pion/dtls/v3/pkg/protocol/handshake" ) type handshakeCacheItem struct { typ handshake.Type isClient bool epoch uint16 messageSequence uint16 data []byte } type handshakeCachePullRule struct { typ handshake.Type epoch uint16 isClient bool optional bool } type handshakeCache struct { cache []*handshakeCacheItem mu sync.Mutex } func newHandshakeCache() *handshakeCache { 
return &handshakeCache{} } func (h *handshakeCache) push(data []byte, epoch, messageSequence uint16, typ handshake.Type, isClient bool) { h.mu.Lock() defer h.mu.Unlock() h.cache = append(h.cache, &handshakeCacheItem{ data: append([]byte{}, data...), epoch: epoch, messageSequence: messageSequence, typ: typ, isClient: isClient, }) } // returns a list handshakes that match the requested rules // the list will contain null entries for rules that can't be satisfied // multiple entries may match a rule, but only the last match is returned (ie ClientHello with cookies). func (h *handshakeCache) pull(rules ...handshakeCachePullRule) []*handshakeCacheItem { h.mu.Lock() defer h.mu.Unlock() out := make([]*handshakeCacheItem, len(rules)) for i, r := range rules { for _, c := range h.cache { if c.typ == r.typ && c.isClient == r.isClient && c.epoch == r.epoch { switch { case out[i] == nil: out[i] = c case out[i].messageSequence < c.messageSequence: out[i] = c } } } } return out } // fullPullMap pulls all handshakes between rules[0] to rules[len(rules)-1] as map. // //nolint:cyclop func (h *handshakeCache) fullPullMap( startSeq int, cipherSuite CipherSuite, rules ...handshakeCachePullRule, ) (int, map[handshake.Type]handshake.Message, bool) { h.mu.Lock() defer h.mu.Unlock() ci := make(map[handshake.Type]*handshakeCacheItem) for _, rule := range rules { var item *handshakeCacheItem for _, c := range h.cache { if c.typ == rule.typ && c.isClient == rule.isClient && c.epoch == rule.epoch { switch { case item == nil: item = c case item.messageSequence < c.messageSequence: item = c } } } if !rule.optional && item == nil { // Missing mandatory message. 
return startSeq, nil, false } ci[rule.typ] = item } out := make(map[handshake.Type]handshake.Message) seq := startSeq ok := false for _, r := range rules { typ := r.typ i := ci[typ] if i == nil { continue } var keyExchangeAlgorithm CipherSuiteKeyExchangeAlgorithm if cipherSuite != nil { keyExchangeAlgorithm = cipherSuite.KeyExchangeAlgorithm() } rawHandshake := &handshake.Handshake{ KeyExchangeAlgorithm: keyExchangeAlgorithm, } if err := rawHandshake.Unmarshal(i.data); err != nil { return startSeq, nil, false } if uint16(seq) != rawHandshake.Header.MessageSequence { //nolint:gosec // G115 // There is a gap. Some messages are not arrived. return startSeq, nil, false } seq++ ok = true out[typ] = rawHandshake.Message } if !ok { return seq, nil, false } return seq, out, true } // pullAndMerge calls pull and then merges the results, ignoring any null entries. func (h *handshakeCache) pullAndMerge(rules ...handshakeCachePullRule) []byte { merged := []byte{} for _, p := range h.pull(rules...) { if p != nil { merged = append(merged, p.data...) 
} } return merged } // sessionHash returns the session hash for Extended Master Secret support // https://tools.ietf.org/html/draft-ietf-tls-session-hash-06#section-4 func (h *handshakeCache) sessionHash(hf prf.HashFunc, epoch uint16, additional ...[]byte) ([]byte, error) { merged := []byte{} // Order defined by https://tools.ietf.org/html/rfc5246#section-7.3 handshakeBuffer := h.pull( handshakeCachePullRule{handshake.TypeClientHello, epoch, true, false}, handshakeCachePullRule{handshake.TypeServerHello, epoch, false, false}, handshakeCachePullRule{handshake.TypeCertificate, epoch, false, false}, handshakeCachePullRule{handshake.TypeServerKeyExchange, epoch, false, false}, handshakeCachePullRule{handshake.TypeCertificateRequest, epoch, false, false}, handshakeCachePullRule{handshake.TypeServerHelloDone, epoch, false, false}, handshakeCachePullRule{handshake.TypeCertificate, epoch, true, false}, handshakeCachePullRule{handshake.TypeClientKeyExchange, epoch, true, false}, ) for _, p := range handshakeBuffer { if p == nil { continue } merged = append(merged, p.data...) } for _, a := range additional { merged = append(merged, a...) 
} hash := hf() if _, err := hash.Write(merged); err != nil { return []byte{}, err } return hash.Sum(nil), nil } golang-github-pion-dtls-v3-3.0.7/handshake_cache_test.go000066400000000000000000000160061507057460300231560ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "testing" "github.com/pion/dtls/v3/internal/ciphersuite" "github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/stretchr/testify/assert" ) func TestHandshakeCacheSinglePush(t *testing.T) { for _, test := range []struct { Name string Rule []handshakeCachePullRule Input []handshakeCacheItem Expected []byte }{ { Name: "Single Push", Input: []handshakeCacheItem{ {0, true, 0, 0, []byte{0x00}}, }, Rule: []handshakeCachePullRule{ {0, 0, true, false}, }, Expected: []byte{0x00}, }, { Name: "Multi Push", Input: []handshakeCacheItem{ {0, true, 0, 0, []byte{0x00}}, {1, true, 0, 1, []byte{0x01}}, {2, true, 0, 2, []byte{0x02}}, }, Rule: []handshakeCachePullRule{ {0, 0, true, false}, {1, 0, true, false}, {2, 0, true, false}, }, Expected: []byte{0x00, 0x01, 0x02}, }, { Name: "Multi Push, Rules set order", Input: []handshakeCacheItem{ {2, true, 0, 2, []byte{0x02}}, {0, true, 0, 0, []byte{0x00}}, {1, true, 0, 1, []byte{0x01}}, }, Rule: []handshakeCachePullRule{ {0, 0, true, false}, {1, 0, true, false}, {2, 0, true, false}, }, Expected: []byte{0x00, 0x01, 0x02}, }, { Name: "Multi Push, Dupe Seqnum", Input: []handshakeCacheItem{ {0, true, 0, 0, []byte{0x00}}, {1, true, 0, 1, []byte{0x01}}, {1, true, 0, 1, []byte{0x01}}, }, Rule: []handshakeCachePullRule{ {0, 0, true, false}, {1, 0, true, false}, }, Expected: []byte{0x00, 0x01}, }, { Name: "Multi Push, Dupe Seqnum Client/Server", Input: []handshakeCacheItem{ {0, true, 0, 0, []byte{0x00}}, {1, true, 0, 1, []byte{0x01}}, {1, false, 0, 1, []byte{0x02}}, }, Rule: []handshakeCachePullRule{ {0, 0, true, false}, {1, 0, true, false}, {1, 0, false, false}, }, Expected: []byte{0x00, 0x01, 
0x02}, }, { Name: "Multi Push, Dupe Seqnum with Unique HandshakeType", Input: []handshakeCacheItem{ {1, true, 0, 0, []byte{0x00}}, {2, true, 0, 1, []byte{0x01}}, {3, false, 0, 0, []byte{0x02}}, }, Rule: []handshakeCachePullRule{ {1, 0, true, false}, {2, 0, true, false}, {3, 0, false, false}, }, Expected: []byte{0x00, 0x01, 0x02}, }, { Name: "Multi Push, Wrong epoch", Input: []handshakeCacheItem{ {1, true, 0, 0, []byte{0x00}}, {2, true, 1, 1, []byte{0x01}}, {2, true, 0, 2, []byte{0x11}}, {3, false, 0, 0, []byte{0x02}}, {3, false, 1, 0, []byte{0x12}}, {3, false, 2, 0, []byte{0x12}}, }, Rule: []handshakeCachePullRule{ {1, 0, true, false}, {2, 1, true, false}, {3, 0, false, false}, }, Expected: []byte{0x00, 0x01, 0x02}, }, } { h := newHandshakeCache() for _, i := range test.Input { h.push(i.data, i.epoch, i.messageSequence, i.typ, i.isClient) } verifyData := h.pullAndMerge(test.Rule...) assert.Equal(t, test.Expected, verifyData) } } func TestHandshakeCacheSessionHash(t *testing.T) { for _, test := range []struct { Name string Rule []handshakeCachePullRule Input []handshakeCacheItem Expected []byte }{ { Name: "Standard Handshake", Input: []handshakeCacheItem{ {handshake.TypeClientHello, true, 0, 0, []byte{0x00}}, {handshake.TypeServerHello, false, 0, 1, []byte{0x01}}, {handshake.TypeCertificate, false, 0, 2, []byte{0x02}}, {handshake.TypeServerKeyExchange, false, 0, 3, []byte{0x03}}, {handshake.TypeServerHelloDone, false, 0, 4, []byte{0x04}}, {handshake.TypeClientKeyExchange, true, 0, 5, []byte{0x05}}, }, Expected: []byte{ 0x17, 0xe8, 0x8d, 0xb1, 0x87, 0xaf, 0xd6, 0x2c, 0x16, 0xe5, 0xde, 0xbf, 0x3e, 0x65, 0x27, 0xcd, 0x00, 0x6b, 0xc0, 0x12, 0xbc, 0x90, 0xb5, 0x1a, 0x81, 0x0c, 0xd8, 0x0c, 0x2d, 0x51, 0x1f, 0x43, }, }, { Name: "Handshake With Client Cert Request", Input: []handshakeCacheItem{ {handshake.TypeClientHello, true, 0, 0, []byte{0x00}}, {handshake.TypeServerHello, false, 0, 1, []byte{0x01}}, {handshake.TypeCertificate, false, 0, 2, []byte{0x02}}, 
{handshake.TypeServerKeyExchange, false, 0, 3, []byte{0x03}}, {handshake.TypeCertificateRequest, false, 0, 4, []byte{0x04}}, {handshake.TypeServerHelloDone, false, 0, 5, []byte{0x05}}, {handshake.TypeClientKeyExchange, true, 0, 6, []byte{0x06}}, }, Expected: []byte{ 0x57, 0x35, 0x5a, 0xc3, 0x30, 0x3c, 0x14, 0x8f, 0x11, 0xae, 0xf7, 0xcb, 0x17, 0x94, 0x56, 0xb9, 0x23, 0x2c, 0xde, 0x33, 0xa8, 0x18, 0xdf, 0xda, 0x2c, 0x2f, 0xcb, 0x93, 0x25, 0x74, 0x9a, 0x6b, }, }, { Name: "Handshake Ignores after ClientKeyExchange", Input: []handshakeCacheItem{ {handshake.TypeClientHello, true, 0, 0, []byte{0x00}}, {handshake.TypeServerHello, false, 0, 1, []byte{0x01}}, {handshake.TypeCertificate, false, 0, 2, []byte{0x02}}, {handshake.TypeServerKeyExchange, false, 0, 3, []byte{0x03}}, {handshake.TypeCertificateRequest, false, 0, 4, []byte{0x04}}, {handshake.TypeServerHelloDone, false, 0, 5, []byte{0x05}}, {handshake.TypeClientKeyExchange, true, 0, 6, []byte{0x06}}, {handshake.TypeCertificateVerify, true, 0, 7, []byte{0x07}}, {handshake.TypeFinished, true, 1, 7, []byte{0x08}}, {handshake.TypeFinished, false, 1, 7, []byte{0x09}}, }, Expected: []byte{ 0x57, 0x35, 0x5a, 0xc3, 0x30, 0x3c, 0x14, 0x8f, 0x11, 0xae, 0xf7, 0xcb, 0x17, 0x94, 0x56, 0xb9, 0x23, 0x2c, 0xde, 0x33, 0xa8, 0x18, 0xdf, 0xda, 0x2c, 0x2f, 0xcb, 0x93, 0x25, 0x74, 0x9a, 0x6b, }, }, { Name: "Handshake Ignores wrong epoch", Input: []handshakeCacheItem{ {handshake.TypeClientHello, true, 0, 0, []byte{0x00}}, {handshake.TypeServerHello, false, 0, 1, []byte{0x01}}, {handshake.TypeCertificate, false, 0, 2, []byte{0x02}}, {handshake.TypeServerKeyExchange, false, 0, 3, []byte{0x03}}, {handshake.TypeCertificateRequest, false, 0, 4, []byte{0x04}}, {handshake.TypeServerHelloDone, false, 0, 5, []byte{0x05}}, {handshake.TypeClientKeyExchange, true, 0, 6, []byte{0x06}}, {handshake.TypeCertificateVerify, true, 0, 7, []byte{0x07}}, {handshake.TypeFinished, true, 0, 7, []byte{0xf0}}, {handshake.TypeFinished, false, 0, 7, []byte{0xf1}}, 
{handshake.TypeFinished, true, 1, 7, []byte{0x08}}, {handshake.TypeFinished, false, 1, 7, []byte{0x09}}, {handshake.TypeFinished, true, 0, 7, []byte{0xf0}}, {handshake.TypeFinished, false, 0, 7, []byte{0xf1}}, }, Expected: []byte{ 0x57, 0x35, 0x5a, 0xc3, 0x30, 0x3c, 0x14, 0x8f, 0x11, 0xae, 0xf7, 0xcb, 0x17, 0x94, 0x56, 0xb9, 0x23, 0x2c, 0xde, 0x33, 0xa8, 0x18, 0xdf, 0xda, 0x2c, 0x2f, 0xcb, 0x93, 0x25, 0x74, 0x9a, 0x6b, }, }, } { h := newHandshakeCache() for _, i := range test.Input { h.push(i.data, i.epoch, i.messageSequence, i.typ, i.isClient) } cipherSuite := ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{} verifyData, err := h.sessionHash(cipherSuite.HashFunc(), 0) assert.NoError(t, err) assert.Equal(t, test.Expected, verifyData, "handshakeCacheSessionHash") } } golang-github-pion-dtls-v3-3.0.7/handshake_test.go000066400000000000000000000034041507057460300220310ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "testing" "time" "github.com/pion/dtls/v3/pkg/protocol" "github.com/pion/dtls/v3/pkg/protocol/extension" "github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/stretchr/testify/assert" ) func TestHandshakeMessage(t *testing.T) { rawHandshakeMessage := []byte{ 0x01, 0x00, 0x00, 0x29, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x29, 0xfe, 0xfd, 0xb6, 0x2f, 0xce, 0x5c, 0x42, 0x54, 0xff, 0x86, 0xe1, 0x24, 0x41, 0x91, 0x42, 0x62, 0x15, 0xad, 0x16, 0xc9, 0x15, 0x8d, 0x95, 0x71, 0x8a, 0xbb, 0x22, 0xd7, 0x47, 0xec, 0xd8, 0x3d, 0xdc, 0x4b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, } parsedHandshake := &handshake.Handshake{ Header: handshake.Header{ Length: 0x29, FragmentLength: 0x29, Type: handshake.TypeClientHello, }, Message: &handshake.MessageClientHello{ Version: protocol.Version{Major: 0xFE, Minor: 0xFD}, Random: handshake.Random{ GMTUnixTime: time.Unix(3056586332, 0), RandomBytes: [28]byte{ 0x42, 0x54, 0xff, 0x86, 0xe1, 0x24, 0x41, 0x91, 0x42, 0x62, 0x15, 0xad, 0x16, 0xc9, 
0x15, 0x8d, 0x95, 0x71, 0x8a, 0xbb, 0x22, 0xd7, 0x47, 0xec, 0xd8, 0x3d, 0xdc, 0x4b, }, }, SessionID: []byte{}, Cookie: []byte{}, CipherSuiteIDs: []uint16{}, CompressionMethods: []*protocol.CompressionMethod{}, Extensions: []extension.Extension{}, }, } h := &handshake.Handshake{} assert.NoError(t, h.Unmarshal(rawHandshakeMessage)) assert.Equal(t, parsedHandshake, h, "handshakeMessageClientHello unmarshal") raw, err := h.Marshal() assert.NoError(t, err) assert.Equal(t, rawHandshakeMessage, raw, "handshakeMessageClientHello marshal") } golang-github-pion-dtls-v3-3.0.7/handshaker.go000066400000000000000000000252601507057460300211600ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "context" "crypto/tls" "crypto/x509" "fmt" "io" "sync" "time" "github.com/pion/dtls/v3/pkg/crypto/elliptic" "github.com/pion/dtls/v3/pkg/crypto/signaturehash" "github.com/pion/dtls/v3/pkg/protocol/alert" "github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/pion/logging" ) // [RFC6347 Section-4.2.4] // +-----------+ // +---> | PREPARING | <--------------------+ // | +-----------+ | // | | | // | | Buffer next flight | // | | | // | \|/ | // | +-----------+ | // | | SENDING |<------------------+ | Send // | +-----------+ | | HelloRequest // Receive | | | | // next | | Send flight | | or // flight | +--------+ | | // | | | Set retransmit timer | | Receive // | | \|/ | | HelloRequest // | | +-----------+ | | Send // +--)--| WAITING |-------------------+ | ClientHello // | | +-----------+ Timer expires | | // | | | | | // | | +------------------------+ | // Receive | | Send Read retransmit | // last | | last | // flight | | flight | // | | | // \|/\|/ | // +-----------+ | // | FINISHED | -------------------------------+ // +-----------+ // | /|\ // | | // +---+ // Read retransmit // Retransmit last flight type handshakeState uint8 const ( handshakeErrored handshakeState = iota handshakePreparing 
handshakeSending handshakeWaiting handshakeFinished ) func (s handshakeState) String() string { switch s { case handshakeErrored: return "Errored" case handshakePreparing: return "Preparing" case handshakeSending: return "Sending" case handshakeWaiting: return "Waiting" case handshakeFinished: return "Finished" default: return "Unknown" } } type handshakeFSM struct { currentFlight flightVal flights []*packet retransmit bool retransmitInterval time.Duration state *State cache *handshakeCache cfg *handshakeConfig closed chan struct{} } type handshakeConfig struct { localPSKCallback PSKCallback localPSKIdentityHint []byte localCipherSuites []CipherSuite // Available CipherSuites localSignatureSchemes []signaturehash.Algorithm // Available signature schemes extendedMasterSecret ExtendedMasterSecretType // Policy for the Extended Master Support extension localSRTPProtectionProfiles []SRTPProtectionProfile // Available SRTPProtectionProfiles, if empty no SRTP support localSRTPMasterKeyIdentifier []byte serverName string supportedProtocols []string clientAuth ClientAuthType // If we are a client should we request a client certificate localCertificates []tls.Certificate nameToCertificate map[string]*tls.Certificate insecureSkipVerify bool verifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error verifyConnection func(*State) error sessionStore SessionStore rootCAs *x509.CertPool clientCAs *x509.CertPool initialRetransmitInterval time.Duration disableRetransmitBackoff bool customCipherSuites func() []CipherSuite ellipticCurves []elliptic.Curve insecureSkipHelloVerify bool connectionIDGenerator func() []byte helloRandomBytesGenerator func() [handshake.RandomBytesLength]byte onFlightState func(flightVal, handshakeState) log logging.LeveledLogger keyLogWriter io.Writer localGetCertificate func(*ClientHelloInfo) (*tls.Certificate, error) localGetClientCertificate func(*CertificateRequestInfo) (*tls.Certificate, error) initialEpoch uint16 mu 
sync.Mutex clientHelloMessageHook func(handshake.MessageClientHello) handshake.Message serverHelloMessageHook func(handshake.MessageServerHello) handshake.Message certificateRequestMessageHook func(handshake.MessageCertificateRequest) handshake.Message resumeState *State } type flightConn interface { notify(ctx context.Context, level alert.Level, desc alert.Description) error writePackets(context.Context, []*packet) error recvHandshake() <-chan recvHandshakeState setLocalEpoch(epoch uint16) handleQueuedPackets(context.Context) error sessionKey() []byte } func (c *handshakeConfig) writeKeyLog(label string, clientRandom, secret []byte) { if c.keyLogWriter == nil { return } c.mu.Lock() defer c.mu.Unlock() _, err := c.keyLogWriter.Write([]byte(fmt.Sprintf("%s %x %x\n", label, clientRandom, secret))) if err != nil { c.log.Debugf("failed to write key log file: %s", err) } } func srvCliStr(isClient bool) string { if isClient { return "client" } return "server" } func newHandshakeFSM( s *State, cache *handshakeCache, cfg *handshakeConfig, initialFlight flightVal, ) *handshakeFSM { return &handshakeFSM{ currentFlight: initialFlight, state: s, cache: cache, cfg: cfg, retransmitInterval: cfg.initialRetransmitInterval, closed: make(chan struct{}), } } func (s *handshakeFSM) Run(ctx context.Context, conn flightConn, initialState handshakeState) error { state := initialState defer func() { close(s.closed) }() for { s.cfg.log.Tracef("[handshake:%s] %s: %s", srvCliStr(s.state.isClient), s.currentFlight.String(), state.String()) if s.cfg.onFlightState != nil { s.cfg.onFlightState(s.currentFlight, state) } var err error switch state { case handshakePreparing: state, err = s.prepare(ctx, conn) case handshakeSending: state, err = s.send(ctx, conn) case handshakeWaiting: state, err = s.wait(ctx, conn) case handshakeFinished: state, err = s.finish(ctx, conn) default: return errInvalidFSMTransition } if err != nil { return err } } } func (s *handshakeFSM) Done() <-chan struct{} { return 
s.closed } func (s *handshakeFSM) prepare(ctx context.Context, conn flightConn) (handshakeState, error) { s.flights = nil // Prepare flights var ( dtlsAlert *alert.Alert err error pkts []*packet ) gen, retransmit, errFlight := s.currentFlight.getFlightGenerator() if errFlight != nil { err = errFlight dtlsAlert = &alert.Alert{Level: alert.Fatal, Description: alert.InternalError} } else { pkts, dtlsAlert, err = gen(conn, s.state, s.cache, s.cfg) s.retransmit = retransmit } if dtlsAlert != nil { if alertErr := conn.notify(ctx, dtlsAlert.Level, dtlsAlert.Description); alertErr != nil { if err != nil { err = alertErr } } } if err != nil { return handshakeErrored, err } s.flights = pkts epoch := s.cfg.initialEpoch nextEpoch := epoch for _, p := range s.flights { p.record.Header.Epoch += epoch if p.record.Header.Epoch > nextEpoch { nextEpoch = p.record.Header.Epoch } if h, ok := p.record.Content.(*handshake.Handshake); ok { h.Header.MessageSequence = uint16(s.state.handshakeSendSequence) //nolint:gosec // G115 s.state.handshakeSendSequence++ } } if epoch != nextEpoch { s.cfg.log.Tracef("[handshake:%s] -> changeCipherSpec (epoch: %d)", srvCliStr(s.state.isClient), nextEpoch) conn.setLocalEpoch(nextEpoch) } return handshakeSending, nil } func (s *handshakeFSM) send(ctx context.Context, c flightConn) (handshakeState, error) { // Send flights if err := c.writePackets(ctx, s.flights); err != nil { return handshakeErrored, err } if s.currentFlight.isLastSendFlight() { return handshakeFinished, nil } return handshakeWaiting, nil } func (s *handshakeFSM) wait(ctx context.Context, conn flightConn) (handshakeState, error) { //nolint:gocognit,cyclop parse, errFlight := s.currentFlight.getFlightParser() if errFlight != nil { if alertErr := conn.notify(ctx, alert.Fatal, alert.InternalError); alertErr != nil { return handshakeErrored, alertErr } return handshakeErrored, errFlight } retransmitTimer := time.NewTimer(s.retransmitInterval) for { select { case state := 
<-conn.recvHandshake(): if state.isRetransmit { close(state.done) return handshakeSending, nil } nextFlight, alert, err := parse(ctx, conn, s.state, s.cache, s.cfg) s.retransmitInterval = s.cfg.initialRetransmitInterval close(state.done) if alert != nil { if alertErr := conn.notify(ctx, alert.Level, alert.Description); alertErr != nil { if err != nil { err = alertErr } } } if err != nil { return handshakeErrored, err } if nextFlight == 0 { break } s.cfg.log.Tracef( "[handshake:%s] %s -> %s", srvCliStr(s.state.isClient), s.currentFlight.String(), nextFlight.String(), ) if nextFlight.isLastRecvFlight() && s.currentFlight == nextFlight { return handshakeFinished, nil } s.currentFlight = nextFlight return handshakePreparing, nil case <-retransmitTimer.C: if !s.retransmit { return handshakeWaiting, nil } // RFC 4347 4.2.4.1: // Implementations SHOULD use an initial timer value of 1 second (the minimum defined in RFC 2988 [RFC2988]) // and double the value at each retransmission, up to no less than the RFC 2988 maximum of 60 seconds. 
if !s.cfg.disableRetransmitBackoff { s.retransmitInterval *= 2 } if s.retransmitInterval > time.Second*60 { s.retransmitInterval = time.Second * 60 } return handshakeSending, nil case <-ctx.Done(): s.retransmitInterval = s.cfg.initialRetransmitInterval return handshakeErrored, ctx.Err() } } } func (s *handshakeFSM) finish(ctx context.Context, c flightConn) (handshakeState, error) { select { case state := <-c.recvHandshake(): close(state.done) if s.state.isClient { return handshakeFinished, nil } else { return handshakeSending, nil } case <-ctx.Done(): return handshakeErrored, ctx.Err() } } golang-github-pion-dtls-v3-3.0.7/handshaker_test.go000066400000000000000000000275361507057460300222270ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "bytes" "context" "crypto/tls" "errors" "sync" "testing" "time" "github.com/pion/dtls/v3/pkg/crypto/selfsign" "github.com/pion/dtls/v3/pkg/crypto/signaturehash" "github.com/pion/dtls/v3/pkg/protocol/alert" "github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/pion/dtls/v3/pkg/protocol/recordlayer" "github.com/pion/logging" "github.com/pion/transport/v3/test" "github.com/stretchr/testify/assert" ) const nonZeroRetransmitInterval = 100 * time.Millisecond // Test that writes to the key log are in the correct format and only applies // when a key log writer is given. func TestWriteKeyLog(t *testing.T) { var buf bytes.Buffer cfg := handshakeConfig{ keyLogWriter: &buf, } cfg.writeKeyLog("LABEL", []byte{0xAA, 0xBB, 0xCC}, []byte{0xDD, 0xEE, 0xFF}) // Secrets follow the format