pax_global_header00006660000000000000000000000064151726773470014535gustar00rootroot0000000000000052 comment=e9ad547c02f8040b4a4c08db79c4bf1241b8062e golang-github-olekukonko-errors-1.3.0/000077500000000000000000000000001517267734700177565ustar00rootroot00000000000000golang-github-olekukonko-errors-1.3.0/.github/000077500000000000000000000000001517267734700213165ustar00rootroot00000000000000golang-github-olekukonko-errors-1.3.0/.github/workflows/000077500000000000000000000000001517267734700233535ustar00rootroot00000000000000golang-github-olekukonko-errors-1.3.0/.github/workflows/go.yml000066400000000000000000000013671517267734700245120ustar00rootroot00000000000000name: Go on: push: branches: [ "main" ] pull_request: branches: [ "main" ] jobs: test: name: Test (Go ${{ matrix.go-version }}) runs-on: ubuntu-latest strategy: fail-fast: false matrix: go-version: - '1.21' # Represents the pool_below_1_24 path (SetFinalizer) - '1.24' # Represents the pool_above_1_24 path (AddCleanup) steps: - uses: actions/checkout@v4 - name: Set up Go ${{ matrix.go-version }} uses: actions/setup-go@v4 with: go-version: ${{ matrix.go-version }} cache: false - name: Build run: go build -v ./... 
- name: Test (race + cover) run: go test -count=1 -race -timeout 120s -cover -v ./...golang-github-olekukonko-errors-1.3.0/.gitignore000077500000000000000000000000721517267734700217500ustar00rootroot00000000000000# Created by .ignore support plugin (hsz.mobi) .idea/ tmp/golang-github-olekukonko-errors-1.3.0/LICENSE000066400000000000000000000020541517267734700207640ustar00rootroot00000000000000MIT License Copyright (c) 2025 Oleku Konko Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
golang-github-olekukonko-errors-1.3.0/README.md000066400000000000000000000415101517267734700212360ustar00rootroot00000000000000# errors — production-grade error handling for Go [![Go Reference](https://pkg.go.dev/badge/github.com/olekukonko/errors.svg)](https://pkg.go.dev/github.com/olekukonko/errors) [![Go Report Card](https://goreportcard.com/badge/github.com/olekukonko/errors)](https://goreportcard.com/report/github.com/olekukonko/errors) [![License](https://img.shields.io/badge/license-MIT-blue.svg)](LICENSE) [![Go 1.21+](https://img.shields.io/badge/go-1.21+-blue.svg)](https://golang.org/dl/) A feature-complete error handling library for Go. Fully compatible with `errors.Is`, `errors.As`, and `errors.Unwrap`. Optimised for high-throughput systems with object pooling, hybrid context storage, and inlining-immune stack capture. --- ## Contents - [Installation](#installation) - [Package overview](#package-overview) - [Core — `errors`](#core--errors) - [Creating errors](#creating-errors) - [Stack traces](#stack-traces) - [Context](#context) - [Wrapping and chaining](#wrapping-and-chaining) - [Sentinel errors](#sentinel-errors) - [Type assertions — Is / As](#type-assertions--is--as) - [Multi-error aggregation](#multi-error-aggregation) - [Retry](#retry) - [Chain execution](#chain-execution) - [Channel utilities and streaming](#channel-utilities-and-streaming) - [HTTP helpers](#http-helpers) - [Concurrent group](#concurrent-group) - [Inspect](#inspect) - [slog integration](#slog-integration) - [Pool management](#pool-management) - [Management — `errmgr`](#management--errmgr) - [Performance](#performance) - [Migration guide](#migration-guide) - [FAQ](#faq) --- ## Installation ```bash go get github.com/olekukonko/errors@latest ``` Requires Go 1.21 or later. 
--- ## Package overview | Package | Purpose | |---|---| | `errors` | Core error type, wrapping, context, stack traces, retry, chain, multi-error, channel utilities | | `errmgr` | Parameterised error templates, occurrence monitoring, threshold alerting | --- ## Core — `errors` ### Creating errors ```go // Fast — no stack trace, 0 allocations with pooling err := errors.New("connection failed") // Formatted — full fmt verb support including %w err := errors.Newf("user %s not found", "alice") err := errors.Errorf("query failed: %w", cause) // alias of Newf // With stack trace err := errors.Trace("critical issue") err := errors.Tracef("query %s failed: %w", query, cause) // Named — useful for sentinel-style matching err := errors.Named("AuthError") // Standard library compatible err := errors.Std("connection failed") // returns plain error err := errors.Stdf("error %s", "detail") // formatted plain error ``` ### Stack traces ```go // Capture at creation err := errors.Trace("critical issue") // Add to an existing error err = err.WithStack() // Read frames for _, frame := range err.Stack() { fmt.Println(frame) // "main.go:42 main.main" } // Lightweight version (file:line only, no function names) for _, frame := range err.FastStack() { fmt.Println(frame) } ``` Stack capture is immune to compiler inlining — frames are collected from the physical call stack and trimmed by slice arithmetic, not by skip count. ### Context ```go err := errors.New("processing failed"). With("user_id", "123"). With("attempt", 3). With("retryable", true) // Read back ctx := errors.Context(err) // map[user_id:123 attempt:3 retryable:true] // Check for a key if err.HasContextKey("user_id") { ... } // Variadic bulk attach err.With("k1", v1, "k2", v2) // Semantic helpers err.WithCode(500) err.WithCategory("network") err.WithTimeout() err.WithRetryable() ``` The first four context items are stored in a fixed-size array (no allocation). Items beyond four spill to a map. 
### Wrapping and chaining ```go lowErr := errors.New("connection timeout").With("server", "db01") bizErr := errors.New("failed to load user").Wrap(lowErr) apiErr := errors.Wrapf(bizErr, "request failed: %w", bizErr) // Traverse for i, e := range errors.UnwrapAll(apiErr) { fmt.Printf("%d. %s\n", i+1, e) } // 1. request failed: ... // 2. failed to load user // 3. connection timeout ``` ### Sentinel errors `Const` creates a stable, pointer-comparable sentinel safe for package-level variables. ```go var ( ErrNotFound = errors.Const("not_found", "resource not found") ErrForbidden = errors.Const("forbidden", "access denied") ) // Match anywhere in a chain if errors.Is(err, ErrNotFound) { ... } // Add call-site context without losing the sentinel err := ErrNotFound.With("user 42 not found") errors.Is(err, ErrNotFound) // true — sentinel is the cause // JSON and slog work automatically b, _ := json.Marshal(ErrNotFound) // {"error":"resource not found","code":"not_found"} slog.Error("lookup failed", "err", ErrNotFound) ``` > **`Const` vs `errmgr.Define`** > `errors.Const` — static comparable value for `errors.Is` matching. > `errmgr.Define` — parameterised factory that creates new `*Error` instances from a format template. ### Type assertions — Is / As ```go // Is — checks identity or name match err := errors.Named("AuthError") wrapped := errors.Wrapf(err, "login failed") errors.Is(wrapped, err) // true // As — extract the first matching *Error from the chain var target *errors.Error if errors.As(wrapped, &target) { fmt.Println(target.Name()) // "AuthError" } // Generic helpers (Go 1.18+) if e, ok := errors.AsType[*MyError](err); ok { ... } if errors.IsType[*MyError](err) { ... 
} found, ok := errors.FindType(err, func(e *MyError) bool { return e.Code() == 404 }) codes := errors.Map(err, func(e *MyError) int { return e.Code() }) errors.Filter[*MyError](err) // [] *MyError from chain errors.FirstOfType[*MyError](err) // first *MyError ``` > **`Is()` string-equality note** — `(*Error).Is` falls back to string comparison as a convenience for matching stdlib errors by message. For strict identity matching use `Const()`. ### Multi-error aggregation ```go // Basic m := errors.NewMultiError() m.Add(errors.New("name required")) m.Add(errors.New("email invalid")) fmt.Println(m.Count()) // 2 // With limits and sampling m := errors.NewMultiError( errors.WithLimit(100), errors.WithSampling(10), // 10% sample rate ) // Custom formatter m := errors.NewMultiError( errors.WithFormatter(func(errs []error) string { return fmt.Sprintf("%d errors", len(errs)) }), ) // Inspect m.First() // first error m.Last() // last error m.Errors() // []error snapshot m.Has() // bool m.Single() // nil | first error | *MultiError // Filter networkErrs := m.Filter(func(e error) bool { return strings.Contains(e.Error(), "network") }) // Merge two MultiErrors m.Merge(other) // Join is a convenience that collapses errors to *MultiError or nil err := errors.Join(err1, err2, err3) ``` ### Retry ```go retry := errors.NewRetry( errors.WithMaxAttempts(5), errors.WithDelay(200*time.Millisecond), errors.WithMaxDelay(2*time.Second), errors.WithJitter(true), errors.WithBackoff(errors.ExponentialBackoff{}), errors.WithRetryIf(errors.IsRetryable), errors.WithOnRetry(func(attempt int, err error) { log.Printf("attempt %d: %v", attempt, err) }), ) err := retry.Execute(func() error { return callExternalService() }) // Generic version — preserves return value result, err := errors.ExecuteReply[string](retry, func() (string, error) { return fetchData() }) // Context-aware ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() retry2 := 
retry.Transform(errors.WithContext(ctx)) err = retry2.Execute(fn) // Backoff strategies errors.ConstantBackoff{} errors.LinearBackoff{} errors.ExponentialBackoff{} ``` ### Chain execution Sequential steps with per-step retry, timeout, tagging, and optional steps. ```go chain := errors.NewChain( errors.ChainWithTimeout(10*time.Second), errors.ChainWithLogHandler(slog.Default().Handler()), ). Step(validateInput).Tag("validation"). Step(verifyKYC).Tag("kyc"). Step(processPayment).Tag("billing").Code(402). Retry(3, 100*time.Millisecond, errors.WithRetryIf(errors.IsRetryable)). Step(sendNotification).Tag("notification").Optional() if err := chain.Run(); err != nil { errors.Inspect(err, os.Stderr) } // Run all steps, collect every error if err := chain.RunAll(); err != nil { errors.Inspect(err, os.Stderr) } ``` `StepCtx` passes the chain-level context (with its deadline) to the step, so blocking calls like HTTP or database queries respect the chain timeout: ```go chain.StepCtx(func(ctx context.Context) error { req, _ := http.NewRequestWithContext(ctx, "GET", url, nil) _, err := http.DefaultClient.Do(req) return err }) ``` ### Channel utilities and streaming #### `<-chan error` utilities These compose with the standard Go `(chan T, chan error)` idiom rather than replacing it. 
```go // Drain — block until channel closes, collect into *MultiError err := errors.Drain(errs) // First — return first non-nil error; ctx for deadline only, caller owns cancel err := errors.First(ctx, errs) if err != nil { cancel() // caller decides to stop siblings } // Collect — bounded sample; wraps ErrLimitReached when n is hit err := errors.Collect(ctx, errs, 10) if errors.Is(err, errors.ErrLimitReached) { log.Warn("more than 10 errors — some dropped") } // Fan — merge multiple error channels; caller must drain or cancel to avoid leak merged := errors.Fan(ctx, validateErrs, enrichErrs) for err := range merged { log.Println(err) } ``` #### Stream — concurrent item processing ```go // Process items concurrently, collect all errors s := errors.NewStream(ctx, urls, func(url string) error { return fetch(url) }, 8) // 8 workers; omit for len(items) workers // Option A — block until done if err := s.Wait(); err != nil { errors.Inspect(err, os.Stderr) } // Option B — process errors as they arrive s.Each(func(err error) { log.Println(err) }) // Stop early (drains channel to avoid goroutine leak) s.Stop() ``` `Wait` and `Each` are mutually exclusive. Calling either a second time panics immediately. ### HTTP helpers ```go // Resolve HTTP status from an *Error's code status := errors.HTTPStatusCode(err, http.StatusInternalServerError) // Write HTTP error response errors.HTTPError(w, err) // plain text, status from err.Code() // With options errors.HTTPError(w, err, errors.WithFallbackCode(http.StatusBadGateway), errors.WithBody(false), // header only errors.WithBodyFunc(func(e error) string { return fmt.Sprintf(`{"error":%q}`, e.Error()) }), ) ``` ### Concurrent group `Group` collects all errors from concurrent goroutines — unlike `errgroup` which stops at the first. 
```go g := errors.NewGroup() g.Go(func() error { return validateUser(id) }) g.Go(func() error { return validatePerms(id) }) if err := g.Wait(); err != nil { // err is *MultiError containing every failure errors.Inspect(err, os.Stderr) } // Context-aware g := errors.NewGroup( errors.GroupWithContext(ctx, true), // cancelOnFirst=true errors.GroupWithLimit(50), ) g.GoCtx(func(ctx context.Context) error { return longRunningCheck(ctx) }) _ = g.Wait() ``` ### Inspect ```go // Default — writes to os.Stderr errors.Inspect(err) // Targeted output var buf bytes.Buffer errors.Inspect(err, &buf) // Multiple destinations errors.Inspect(err, os.Stderr, logFile) // Options errors.Inspect(err, os.Stderr, errors.WithStackFrames(5), errors.WithMaxDepth(20), ) // *Error-specific convenience errors.InspectError(err, os.Stderr) ``` `Inspect` handles `*Error`, `*MultiError`, and any stdlib error. It writes to the supplied `io.Writer` values (merged via `io.MultiWriter`) and never touches stdout. ### slog integration Both `*Error` and `*Sentinel` implement `slog.LogValuer`: ```go slog.Error("request failed", "err", err) // produces structured group: err.message, err.name, err.code, err.category, err.context, err.cause slog.Error("lookup failed", "err", errors.ErrNotFound) // produces: err.error="resource not found", err.code="not_found" ``` ### Pool management ```go // Pre-warm (called automatically at init with 100 instances) errors.WarmPool(1000) errors.WarmStackPool(500) // Tune global config errors.Configure(errors.Config{ StackDepth: 32, ContextSize: 4, DisablePooling: false, FilterInternal: true, AutoFree: false, // opt-in GC-based pool return }) // Explicit pool return (preferred) err := errors.New("temp") defer err.Free() // Copy without affecting original copied := err.Copy().With("extra", "data") // Transform (non-destructive) enriched := errors.Transform(err, func(e *errors.Error) { e.WithCode(500).With("env", "prod").WithStack() }) ``` --- ## Management — `errmgr` ### 
Parameterised error templates ```go // Define a reusable template var ErrDBQuery = errmgr.Define("DBQuery", "database query failed: %s") // Instantiate with arguments err := ErrDBQuery("SELECT timed out") fmt.Println(err) // "database query failed: SELECT timed out" fmt.Println(err.Category()) // "database" ``` ### Predefined errors ```go err := errmgr.ErrNotFound fmt.Println(err.Code()) // 404 err := errmgr.ErrDBQuery("SELECT failed") ``` ### Threshold monitoring ```go netErr := errmgr.Define("NetError", "network issue: %s") monitor := errmgr.NewMonitor("NetError") errmgr.SetThreshold("NetError", 3) defer monitor.Close() go func() { for alert := range monitor.Alerts() { fmt.Printf("alert: %s (count: %d)\n", alert, alert.Count()) } }() err := netErr("timeout") err.Free() ``` --- Key design decisions: - **Pool** — `New` and `Wrap` reuse `*Error` instances from `sync.Pool` (12 ns/op, 0 allocs). - **Hybrid context** — up to 4 key-value pairs in a fixed array; overflow to map. Avoids heap allocation for the common case. - **Stack capture** — `captureStack` is inlining-immune: it always starts from `runtime.Callers` frame 1 and trims by array slicing, so the compiler's inlining decisions never corrupt the skip count. - **Pool capacity preservation** — the pool buffer is trimmed in-place (`copy(buf, buf[trimmed:n])`), not re-allocated. Prevents progressive capacity shrinkage under repeated `Free()` cycles. - **`MarshalJSON`** — bytes are copied out of the pool buffer before returning it, eliminating the race between concurrent JSON serialisations. - **`With()`** — the mutex is acquired once at entry, eliminating the TOCTOU race in the former optimistic read-then-lock path. --- ## Migration guide ### From standard library ```go // Before err := fmt.Errorf("user %s not found: %w", username, cause) // After — same output, plus context, code, and chain traversal err := errors.Newf("user %s not found: %w", username, cause). With("username", username). 
WithCode(404) ``` ### From `pkg/errors` ```go // Before err := pkgerrors.Wrap(cause, "operation failed") // After err := errors.New("operation failed").Wrap(cause).WithStack() ``` ### Stdlib `errors.Is` / `errors.As` compatibility ```go // Fully compatible — no changes needed if errors.Is(err, io.EOF) { ... } var target *errors.Error if errors.As(err, &target) { fmt.Println(target.Name()) } ``` --- ## FAQ **When should I use `Const` vs `Named`?** `Const` — package-level sentinel for `errors.Is` matching. Returns the same pointer every call, so pointer equality works. `Named` — creates a new `*Error` instance each call; useful for structured errors with context but not for `==` comparison. **When should I use `Const` vs `errmgr.Define`?** `errors.Const("not_found", "resource not found")` creates a static sentinel. `errmgr.Define("DBQuery", "query failed: %s")` creates a parameterised factory — you call it with arguments to produce a new `*Error` each time. **When should I call `Free()`?** In hot paths where the error is short-lived and you want to return it to the pool immediately. For most application code, letting the GC handle it is fine. If `AutoFree` is enabled in `Config`, the GC returns the error automatically — but `defer err.Free()` is more predictable. **Why does `First` not cancel the context?** `context.Context` is immutable — only `context.WithCancel` produces a cancellable context. `First` accepts `ctx` for deadline support only. The pattern is: call `First`, then call `cancel()` yourself if you want to stop siblings. **Why do `Each` and `Wait` on `Stream` panic on second call?** Consuming the same channel twice silently splits errors between two callers. The panic surfaces the bug immediately rather than letting it produce subtly wrong results in production. 
**How do I debug a deep error chain?** ```go errors.Inspect(err, os.Stderr, errors.WithMaxDepth(30), errors.WithStackFrames(10)) ``` **How do I write to both stderr and a log file?** ```go errors.Inspect(err, os.Stderr, logFile) // io.MultiWriter internally ``` --- ## Contributing Fork → branch → commit → PR. Please include tests for new behaviour and run `go test -count=10 -race ./...` before opening a PR. ## License MIT — see [LICENSE](LICENSE).golang-github-olekukonko-errors-1.3.0/_examples/000077500000000000000000000000001517267734700217335ustar00rootroot00000000000000golang-github-olekukonko-errors-1.3.0/_examples/basic.go000066400000000000000000000026641517267734700233530ustar00rootroot00000000000000// Package main demonstrates basic usage of the errors package from github.com/olekukonko/errors. // It showcases creating simple errors, formatted errors, errors with stack traces, and named errors, // highlighting the package’s enhanced error handling capabilities. package main import ( "fmt" "github.com/olekukonko/errors" ) // main is the entry point of the program, illustrating various ways to create and use errors // from the errors package, printing their outputs to demonstrate their behavior. func main() { // Simple error (no stack trace, fast) // Creates a lightweight error without capturing a stack trace for optimal performance. err := errors.New("connection failed") fmt.Println(err) // Outputs: "connection failed" // Formatted error // Creates an error with a formatted message using a printf-style syntax, similar to fmt.Errorf. err = errors.Newf("user %s not found", "bob") fmt.Println(err) // Outputs: "user bob not found" // Error with stack trace // Creates an error and captures a stack trace at the point of creation for debugging purposes. 
err = errors.Trace("critical issue") fmt.Println(err) // Outputs: "critical issue" fmt.Println(err.Stack()) // Outputs stack trace, e.g., ["main.go:15", "caller.go:42"] // Named error // Creates an error with a specific name and stack trace, useful for categorizing errors. err = errors.Named("InputError") fmt.Println(err.Name()) // Outputs: "InputError" } golang-github-olekukonko-errors-1.3.0/_examples/chains.go000066400000000000000000000056331517267734700235360ustar00rootroot00000000000000// Package main demonstrates error wrapping and chaining using the errors package from // github.com/olekukonko/errors. It simulates a layered application (API, business logic, // database) where errors are created, enriched with context, and wrapped, then inspected // to show the full error chain and specific error checks. package main import ( "fmt" "github.com/olekukonko/errors" ) // databaseQuery simulates a database operation that fails, returning an error with context. // It represents the lowest layer where an error originates, enriched with database-specific details. func databaseQuery() error { // Create a base error with database failure details return errors.New("connection timeout"). With("timeout_sec", 5). // Add timeout duration context With("server", "db01.prod") // Add server identifier context } // businessLogic processes a user request and handles database errors, wrapping them with business context. // It represents an intermediate layer that adds its own context without modifying the original error. func businessLogic(userID string) error { err := databaseQuery() if err != nil { // Create a new error specific to business logic failure return errors.New("failed to process user "+userID). With("user_id", userID). // Add user ID context With("stage", "processing"). // Add processing stage context Wrap(err) // Wrap the database error for chaining } return nil } // apiHandler simulates an API request handler that wraps business logic errors with API context. 
// It represents the top layer, adding a status code and stack trace for debugging. func apiHandler() error { err := businessLogic("12345") if err != nil { // Create a new error specific to API failure return errors.New("API request failed"). WithCode(500). // Set HTTP-like status code WithStack(). // Capture stack trace at API level Wrap(err) // Wrap the business logic error } return nil } // main is the entry point, demonstrating error creation, wrapping, and inspection. // It prints the combined error message, unwraps the error chain, and checks for a specific error. func main() { err := apiHandler() // Print the full error message combining all wrapped errors fmt.Println("=== Combined Message ===") fmt.Println(err) // Unwrap and display each error in the chain with its details fmt.Println("\n=== Error Chain ===") for i, e := range errors.UnwrapAll(err) { fmt.Printf("%d. %T\n", i+1, e) // Show error index and type if err, ok := e.(*errors.Error); ok { fmt.Println(err.Format()) // Print formatted details for custom errors } else { fmt.Println(e) // Print standard error message for non-custom errors } } // Check if the error chain contains a specific error fmt.Println("\n=== Error Checks ===") if errors.Is(err, errors.New("connection timeout")) { fmt.Println("✓ Matches connection timeout error") // Confirm match with database error } } golang-github-olekukonko-errors-1.3.0/_examples/context.go000066400000000000000000000057731517267734700237620ustar00rootroot00000000000000// Package main demonstrates the use of context in the errors package from // github.com/olekukonko/errors. It showcases adding context to custom errors, // accessing context through wrapped and converted errors, handling standard library // errors, and working with complex context data structures. package main import ( "fmt" "github.com/olekukonko/errors" ) // processData simulates a data processing operation that fails, returning an error with context. 
// It attaches retry-related metadata to the error for demonstration purposes. func processData(id string, attempt int) error { // Create an error with processing-specific context return errors.New("processing failed"). With("id", id). // Add data identifier With("attempt", attempt). // Add attempt number With("retryable", true) // Mark as retryable } // main is the entry point, illustrating various ways to work with error context. // It demonstrates basic context addition, context preservation through wrapping, // handling standard errors, and managing complex context data. func main() { // 1. Basic context example // Create and display an error with simple key-value context err := processData("123", 3) fmt.Println("Error:", err) // Print error message fmt.Println("Full context:", errors.Context(err)) // Print all context as a map // 2. Accessing context through conversion // Wrap the error with fmt.Errorf and show context preservation rawErr := fmt.Errorf("wrapped: %w", err) fmt.Println("\nAfter wrapping with fmt.Errorf:") fmt.Println("Direct context access:", errors.Context(rawErr)) // Show context is unavailable directly e := errors.Convert(rawErr) fmt.Println("After conversion - context:", e.Context()) // Show context restored via conversion // 3. Standard library errors // Demonstrate that standard errors lack context stdErr := fmt.Errorf("standard error") if errors.Context(stdErr) == nil { fmt.Println("\nStandard library errors have no context") // Confirm no context exists } // 4. Adding context to standard errors // Convert a standard error and enrich it with context converted := errors.Convert(stdErr). With("source", "legacy"). // Add source information With("severity", "high") // Add severity level fmt.Println("\nConverted standard error:") fmt.Println("Message:", converted.Error()) // Print original message fmt.Println("Context:", converted.Context()) // Print added context // 5. 
Complex context example // Create an error with nested and varied context data complexErr := errors.New("database operation failed"). With("query", "SELECT * FROM users"). // Add SQL query string With("params", map[string]interface{}{ "limit": 100, // Nested parameter: result limit "offset": 0, // Nested parameter: result offset }). With("duration_ms", 45.2) // Add execution time in milliseconds fmt.Println("\nComplex error context:") for k, v := range errors.Context(complexErr) { fmt.Printf("%s: %v (%T)\n", k, v, v) // Print each context key-value pair with type } } golang-github-olekukonko-errors-1.3.0/_examples/multi_error.go000066400000000000000000000167111517267734700246330ustar00rootroot00000000000000// Package main demonstrates the use of MultiError from github.com/olekukonko/errors to handle // multiple validation and system errors. It showcases form validation with custom formatting, // error filtering, and system error aggregation with retryable conditions, illustrating error // management in a user registration and system monitoring context. package main import ( "fmt" "net/mail" "strings" "time" "github.com/olekukonko/errors" ) // UserForm represents a user registration form with fields to validate. type UserForm struct { Name string Email string Password string Birthday string } // validateUser validates a UserForm and returns a MultiError containing all validation failures. // It checks name, email, password, and birthday fields, accumulating errors with a custom limit. 
func validateUser(form UserForm) *errors.MultiError { // Initialize a MultiError with a limit of 10 errors and custom formatting multi := errors.NewMultiError( errors.WithLimit(10), // Cap the number of errors at 10 errors.WithFormatter(customFormat), // Use custom validation error format ) // Name validation if form.Name == "" { multi.Add(errors.New("name is required")) // Add error for empty name } else if len(form.Name) > 50 { multi.Add(errors.New("name cannot exceed 50 characters")) // Add error for long name } // Email validation if form.Email == "" { multi.Add(errors.New("email is required")) // Add error for empty email } else { if _, err := mail.ParseAddress(form.Email); err != nil { multi.Add(errors.New("invalid email format")) // Add error for invalid email } if !strings.Contains(form.Email, "@") { multi.Add(errors.New("email must contain @ symbol")) // Add error for missing @ } } // Password validation if len(form.Password) < 8 { multi.Add(errors.New("password must be at least 8 characters")) // Add error for short password } if !strings.ContainsAny(form.Password, "0123456789") { multi.Add(errors.New("password must contain at least one number")) // Add error for no digits } if !strings.ContainsAny(form.Password, "!@#$%^&*") { multi.Add(errors.New("password must contain at least one special character")) // Add error for no special chars } // Birthday validation if form.Birthday != "" { if _, err := time.Parse("2006-01-02", form.Birthday); err != nil { multi.Add(errors.New("birthday must be in YYYY-MM-DD format")) // Add error for invalid date format } else if bday, _ := time.Parse("2006-01-02", form.Birthday); time.Since(bday).Hours()/24/365 < 13 { multi.Add(errors.New("must be at least 13 years old")) // Add error for age under 13 } } return multi } // customFormat formats a slice of validation errors into a user-friendly string. // It adds a header, numbered list, and total count for display purposes. 
func customFormat(errs []error) string { var sb strings.Builder sb.WriteString("🚨 Validation Errors:\n") // Add header with emoji for i, err := range errs { sb.WriteString(fmt.Sprintf(" %d. %s\n", i+1, err)) // List each error with number } sb.WriteString(fmt.Sprintf("\nTotal issues found: %d\n", len(errs))) // Append total count return sb.String() } // main is the entry point, demonstrating MultiError usage for validation and system errors. // It validates a user form, analyzes errors, and aggregates system errors with retryable filtering. func main() { fmt.Println("=== User Registration Validation ===") // Define a user form with intentional validation failures user := UserForm{ Name: "", // Empty name to trigger error Email: "invalid-email", // Invalid email format Password: "weak", // Weak password Birthday: "2015-01-01", // Date making user under 13 } // Generate and display validation errors validationErrors := validateUser(user) if validationErrors.Has() { fmt.Println(validationErrors) // Print all validation errors // Detailed error analysis fmt.Println("\n🔍 Error Analysis:") fmt.Printf("Total errors: %d\n", validationErrors.Count()) // Show total error count fmt.Printf("First error: %v\n", validationErrors.First()) // Show first error fmt.Printf("Last error: %v\n", validationErrors.Last()) // Show last error // Categorize and display errors with consistent formatting fmt.Println("\n📋 Error Categories:") if emailErrors := validationErrors.Filter(contains("email")); emailErrors.Has() { fmt.Println("Email Issues:") if emailErrors.Count() == 1 { fmt.Println(customFormat([]error{emailErrors.First()})) // Format single email error } else { fmt.Println(emailErrors) // Print multiple email errors } } if pwErrors := validationErrors.Filter(contains("password")); pwErrors.Has() { fmt.Println("Password Issues:") if pwErrors.Count() == 1 { fmt.Println(customFormat([]error{pwErrors.First()})) // Format single password error } else { fmt.Println(pwErrors) // Print 
multiple password errors } } if ageErrors := validationErrors.Filter(contains("13 years")); ageErrors.Has() { fmt.Println("Age Restriction:") if ageErrors.Count() == 1 { fmt.Println(customFormat([]error{ageErrors.First()})) // Format single age error } else { fmt.Println(ageErrors) // Print multiple age errors } } } // System Error Aggregation Example fmt.Println("\n=== System Error Aggregation ===") // Initialize a MultiError for system errors with a limit and custom format systemErrors := errors.NewMultiError( errors.WithLimit(5), // Cap at 5 errors errors.WithFormatter(systemErrorFormat), // Use system error formatting ) // Simulate various system errors systemErrors.Add(errors.New("database connection timeout").WithRetryable()) // Add retryable DB error systemErrors.Add(errors.New("API rate limit exceeded").WithRetryable()) // Add retryable API error systemErrors.Add(errors.New("disk space low")) // Add non-retryable error systemErrors.Add(errors.New("database connection timeout").WithRetryable()) // Add duplicate DB error systemErrors.Add(errors.New("cache miss")) // Add another error systemErrors.Add(errors.New("database connection timeout").WithRetryable()) // Add over limit, ignored fmt.Println(systemErrors) // Print system errors fmt.Printf("\nSystem Status: %d active issues\n", systemErrors.Count()) // Show active error count // Filter and display retryable errors if retryable := systemErrors.Filter(errors.IsRetryable); retryable.Has() { fmt.Println("\n🔄 Retryable Errors:") fmt.Println(retryable) // Print only retryable errors } } // systemErrorFormat formats a slice of system errors with retryable indicators. // It creates a numbered list with a header, marking retryable errors explicitly. func systemErrorFormat(errs []error) string { var sb strings.Builder sb.WriteString("⚠️ System Alerts:\n") // Add header with emoji for i, err := range errs { sb.WriteString(fmt.Sprintf(" %d. 
%s", i+1, err)) // List each error with number if errors.IsRetryable(err) { sb.WriteString(" (retryable)") // Mark as retryable if applicable } sb.WriteString("\n") } return sb.String() } // contains returns a predicate function to filter errors containing a substring. // It’s used to categorize errors based on their message content. func contains(substr string) func(error) bool { return func(err error) bool { return strings.Contains(err.Error(), substr) // Check if error message contains substring } } golang-github-olekukonko-errors-1.3.0/_examples/multi_error_formatting.go000066400000000000000000000126251517267734700270650ustar00rootroot00000000000000// Package main demonstrates the use of MultiError with sampling from github.com/olekukonko/errors. // It generates a large number of errors, applies sampling with a limit, and analyzes the results, // showcasing error collection, custom formatting, and statistical reporting in a simulated error-heavy scenario. package main import ( "fmt" "math/rand" "strings" "time" "github.com/olekukonko/errors" ) // main is the entry point, simulating error generation with sampling and reporting statistics. // It creates a MultiError, populates it with sampled errors, and displays detailed analysis. 
func main() { // Configuration totalErrors := 1000 // Total number of errors to generate sampleRate := 10 // Target sampling rate (10%) errorLimit := 50 // Maximum number of errors to store // Initialize with reproducible seed for demo purposes r := rand.New(rand.NewSource(42)) // Create a seeded random source for consistency start := time.Now() // Record start time for performance measurement // Create MultiError with sampling // Configure MultiError with sampling rate, limit, random source, and custom formatter multi := errors.NewMultiError( errors.WithSampling(uint32(sampleRate)), // Set sampling rate to 10% errors.WithLimit(errorLimit), // Cap stored errors at 50 errors.WithRand(r), // Use seeded random number generator errors.WithFormatter(createFormatter(totalErrors)), // Apply custom formatter with total ) // Generate errors for i := 0; i < totalErrors; i++ { multi.Add(errors.Newf("operation %d failed", i)) // Add formatted error for each iteration } // Calculate statistics duration := time.Since(start) // Calculate elapsed time sampledCount := multi.Count() // Get number of sampled errors actualRate := float64(sampledCount) / float64(totalErrors) * 100 // Compute actual sampling percentage // Print results fmt.Println(multi) // Display sampled errors with custom format printStatistics(totalErrors, sampledCount, sampleRate, actualRate, duration) // Show statistical summary printErrorDistribution(multi, 5) // Show distribution of first 5 errors } // createFormatter returns a formatter for MultiError that includes total error count. // It generates a header for the error report, showing sampled vs. total errors. func createFormatter(total int) errors.ErrorFormatter { return func(errs []error) string { var sb strings.Builder sb.WriteString(fmt.Sprintf("Sampled Error Report (%d/%d):\n", len(errs), total)) // Report sampled vs. 
total sb.WriteString("══════════════════════════════\n") // Add separator line return sb.String() } } // printStatistics displays statistical summary of error sampling. // It reports total errors, sampled count, rates, duration, and analysis notes. func printStatistics(total, sampled, targetRate int, actualRate float64, duration time.Duration) { fmt.Printf("\nStatistics:\n") fmt.Printf("├─ Total errors generated: %d\n", total) // Show total errors created fmt.Printf("├─ Errors captured: %d (limit: %d)\n", sampled, 50) // Show sampled errors and limit fmt.Printf("├─ Target sampling rate: %d%%\n", targetRate) // Show intended sampling rate fmt.Printf("├─ Actual sampling rate: %.1f%%\n", actualRate) // Show achieved sampling rate fmt.Printf("├─ Processing time: %v\n", duration) // Show time taken // Analyze sampling accuracy and limits switch { case sampled == 50 && actualRate < float64(targetRate): fmt.Printf("└─ Note: Hit storage limit - actual rate would be ~%.1f%% without limit\n", float64(targetRate)) // Note when limit caps sampling case actualRate < float64(targetRate)*0.8 || actualRate > float64(targetRate)*1.2: fmt.Printf("└─ ⚠️ Warning: Significant sampling deviation\n") // Warn on large deviation default: fmt.Printf("└─ Sampling within expected range\n") // Confirm normal sampling } } // printErrorDistribution displays a subset of errors with a progress bar visualization. // It shows up to maxDisplay errors, indicating remaining count if truncated. func printErrorDistribution(m *errors.MultiError, maxDisplay int) { errs := m.Errors() // Get all sampled errors if len(errs) == 0 { return // Skip if no errors } fmt.Printf("\nError Distribution (showing first %d):\n", maxDisplay) // Announce display limit for i, err := range errs { if i >= maxDisplay { fmt.Printf("└─ ... 
and %d more\n", len(errs)-maxDisplay) // Indicate remaining errors break } fmt.Printf("%s %v\n", getProgressBar(i, len(errs)), err) // Print error with progress bar } } // getProgressBar generates a visual progress bar for error distribution. // It creates a fixed-width bar based on the index relative to total errors. func getProgressBar(index, total int) string { const width = 10 // Set bar width to 10 characters pos := int(float64(index) / float64(total) * width) // Calculate filled portion return fmt.Sprintf("├─%s%s┤", strings.Repeat("■", pos), strings.Repeat(" ", width-pos)) // Build bar with ■ and spaces } golang-github-olekukonko-errors-1.3.0/_examples/null.go000066400000000000000000000051021517267734700232320ustar00rootroot00000000000000// Package main demonstrates the use of the IsNull method from github.com/olekukonko/errors. // It tests various scenarios involving nil errors, empty errors, errors with null or non-null // context, and MultiError instances, showcasing how IsNull determines nullability based on content // and context, particularly with SQL null types. package main import ( "database/sql" "fmt" "github.com/olekukonko/errors" ) // main is the entry point, illustrating the behavior of IsNull across different error cases. // It checks nil errors, empty errors, context with SQL null values, and MultiError instances. 
func main() { // Case 1: Nil error // Test if a nil error is considered null var err error = nil if errors.IsNull(err) { fmt.Println("Nil error is null") // Expect true: nil errors are always null } // Case 2: Empty error // Test if an empty error (no message) is considered null err = errors.New("") if errors.IsNull(err) { fmt.Println("Empty error is null") } else { fmt.Println("Empty error is not null") // Expect false: empty message but no null context } // Case 3: Error with null context // Test if an error with a null SQL context value is considered null nullString := sql.NullString{Valid: false} err = errors.New("").With("data", nullString) if errors.IsNull(err) { fmt.Println("Error with null context is null") // Expect true: all context is null } // Case 4: Error with non-null context // Test if an error with a valid SQL context value is not null validString := sql.NullString{String: "test", Valid: true} err = errors.New("").With("data", validString) if errors.IsNull(err) { fmt.Println("Error with valid context is null") } else { fmt.Println("Error with valid context is not null") // Expect false: valid context present } // Case 5: Empty MultiError // Test if an empty MultiError is considered null multi := errors.NewMultiError() if multi.IsNull() { fmt.Println("Empty MultiError is null") // Expect true: no errors in MultiError } // Case 6: MultiError with null error // Test if a MultiError containing a null error is considered null multi.Add(errors.New("").With("data", nullString)) if multi.IsNull() { fmt.Println("MultiError with null error is null") // Expect true: only null errors } // Case 7: MultiError with non-null error // Test if a MultiError with mixed errors (null and non-null) is not null multi.Add(errors.New("real error")) if multi.IsNull() { fmt.Println("MultiError with mixed errors is null") } else { fmt.Println("MultiError with mixed errors is not null") // Expect false: contains non-null error } } 
golang-github-olekukonko-errors-1.3.0/_examples/retry.go000066400000000000000000000130771517267734700234370ustar00rootroot00000000000000// Package main demonstrates the retry functionality of github.com/olekukonko/errors. // It simulates flaky database and external service operations with configurable retries, // exponential backoff, jitter, and context timeouts, showcasing error handling, retry policies, // and result capturing in various failure scenarios. package main import ( "context" "fmt" "math/rand" "time" "github.com/olekukonko/errors" ) // DatabaseClient simulates a flaky database connection with a recovery point. // It fails until a specified number of attempts is reached, then succeeds. type DatabaseClient struct { healthyAfterAttempt int // Number of attempts before becoming healthy } // Query attempts a database operation, failing until healthyAfterAttempt reaches zero. // It returns a retryable error with remaining attempts context during failure. func (db *DatabaseClient) Query() error { if db.healthyAfterAttempt > 0 { db.healthyAfterAttempt-- // Decrement failure counter return errors.New("database connection failed"). With("attempt_remaining", db.healthyAfterAttempt). // Add remaining attempts context WithRetryable() // Mark error as retryable } return nil // Success when attempts exhausted } // ExternalService simulates an unreliable external API with random failures. // It fails 30% of the time, returning a retryable error with a 503 status code. func ExternalService() error { if rand.Intn(100) < 30 { // 30% failure probability return errors.New("service unavailable"). WithCode(503). // Set HTTP 503 Service Unavailable status WithRetryable() // Mark error as retryable } return nil // Success on remaining 70% } // main is the entry point, demonstrating retry scenarios with database, external service, and timeout. // It configures retries with backoff, jitter, and context, executing operations and reporting outcomes. 
func main() { // Configure retry with exponential backoff and jitter // Set up a retry policy with custom parameters and logging retry := errors.NewRetry( errors.WithMaxAttempts(5), // Allow up to 5 attempts errors.WithDelay(200*time.Millisecond), // Base delay of 200ms errors.WithMaxDelay(2*time.Second), // Cap delay at 2s errors.WithJitter(true), // Add randomness to delays errors.WithBackoff(errors.ExponentialBackoff{}), // Use exponential backoff strategy errors.WithOnRetry(func(attempt int, err error) { // Callback on each retry // Calculate delay for logging, mirroring Execute logic baseDelay := 200 * time.Millisecond maxDelay := 2 * time.Second delay := errors.ExponentialBackoff{}.Backoff(attempt, baseDelay) if delay > maxDelay { delay = maxDelay } fmt.Printf("Attempt %d failed: %v (retrying in %v)\n", attempt, err.Error(), delay) }), ) // Scenario 1: Database connection with known recovery point // Test retrying a database operation that recovers after 3 failures db := &DatabaseClient{healthyAfterAttempt: 3} fmt.Println("Starting database operation...") err := retry.Execute(func() error { return db.Query() // Attempt database query }) if err != nil { fmt.Printf("Database operation failed after %d attempts: %v\n", retry.Attempts(), err) } else { fmt.Println("Database operation succeeded!") // Expect success after 4 attempts } // Scenario 2: External service with random failures // Test retrying an external service call with a 30% failure rate fmt.Println("\nStarting external service call...") var lastAttempts int // Track total attempts manually start := time.Now() // Measure duration // Using ExecuteReply to capture both result and error result, err := errors.ExecuteReply[string](retry, func() (string, error) { lastAttempts++ // Increment attempt counter if err := ExternalService(); err != nil { return "", err // Return error on failure } return "service response data", nil // Return success data }) duration := time.Since(start) // Calculate elapsed time if 
err != nil { fmt.Printf("Service call failed after %d attempts (%.2f sec): %v\n", lastAttempts, duration.Seconds(), err) } else { fmt.Printf("Service call succeeded after %d attempts (%.2f sec): %s\n", lastAttempts, duration.Seconds(), result) // Expect variable attempts } // Scenario 3: Context cancellation with more visibility // Test retrying an operation with a short timeout fmt.Println("\nStarting operation with timeout...") ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) // 500ms timeout defer cancel() // Ensure context cleanup // Transform retry configuration with context and increased visibility timeoutRetry := retry.Transform( errors.WithContext(ctx), // Apply timeout context errors.WithMaxAttempts(10), // Increase to 10 attempts errors.WithOnRetry(func(attempt int, err error) { // Log each retry attempt fmt.Printf("Timeout scenario attempt %d: %v\n", attempt, err) }), ) startTimeout := time.Now() // Measure timeout scenario duration err = timeoutRetry.Execute(func() error { time.Sleep(300 * time.Millisecond) // Simulate a long operation return errors.New("operation timed out") // Return consistent error }) if errors.Is(err, context.DeadlineExceeded) { fmt.Printf("Operation cancelled by timeout after %.2f sec: %v\n", time.Since(startTimeout).Seconds(), err) // Expect timeout cancellation } else if err != nil { fmt.Printf("Operation failed: %v\n", err) } else { fmt.Println("Operation succeeded (unexpected)") // Unlikely with 500ms timeout } } golang-github-olekukonko-errors-1.3.0/_examples/withstack.go000066400000000000000000000060051517267734700242640ustar00rootroot00000000000000// Package main demonstrates the use of WithStack from github.com/olekukonko/errors. // It showcases adding stack traces to errors using both package-level and method-style approaches, // comparing their application to standard and enhanced errors, and combining them in a real-world // scenario with additional context and error details. 
package main import ( "fmt" "time" "github.com/olekukonko/errors" "math/rand" ) // basicFunc simulates a simple function returning a standard error. // It represents a legacy or external function without enhanced error features. func basicFunc() error { return fmt.Errorf("basic error") // Return a basic fmt.Errorf error } // enhancedFunc simulates a function returning an enhanced *errors.Error. // It represents a function utilizing the errors package's custom error type. func enhancedFunc() *errors.Error { return errors.New("enhanced error") // Return a new *errors.Error } // main is the entry point, demonstrating WithStack usage in various contexts. // It tests package-level WithStack on standard errors, method-style WithStack on enhanced errors, // and a combined approach in a practical scenario. func main() { // 1. Package-level WithStack - works with ANY error type // Demonstrate adding a stack trace to a standard error err1 := basicFunc() enhanced1 := errors.WithStack(err1) // Convert and add stack trace to any error fmt.Println("Package-level WithStack:") fmt.Println(enhanced1.Stack()) // Print stack trace from standard error // 2. Method-style WithStack - only for *errors.Error // Show adding a stack trace to an enhanced error using method chaining err2 := enhancedFunc() enhanced2 := err2.WithStack() // Add stack trace to *errors.Error via method fmt.Println("\nMethod-style WithStack:") fmt.Println(enhanced2.Stack()) // Print stack trace from enhanced error // 3. Combined usage in real-world scenario // Test a mixed error type with both WithStack approaches and additional context result := processData() if result != nil { // Use package-level WithStack when error type is unknown stackErr := errors.WithStack(result) // Chain method-style enhancements on the resulting *errors.Error finalErr := stackErr. With("timestamp", time.Now()). 
// Add timestamp context WithCode(500) // Set HTTP-like status code fmt.Println("\nCombined Usage:") fmt.Println("Message:", finalErr.Error()) // Print full error message fmt.Println("Context:", finalErr.Context()) // Print context map fmt.Println("Stack:") for _, frame := range finalErr.Stack() { fmt.Println(frame) // Print each stack frame } } } // processData simulates a data processing function with variable error types. // It randomly returns either a standard error or an enhanced error with context. func processData() error { // Randomly choose between standard and enhanced error if rand.Intn(2) == 0 { return fmt.Errorf("database error") // Return standard error } return errors.New("validation error").With("field", "email") // Return enhanced error with context } golang-github-olekukonko-errors-1.3.0/base.go000066400000000000000000000124041517267734700212200ustar00rootroot00000000000000package errors import ( "bytes" "fmt" "regexp" "sync" ) // Constants defining default configuration and context keys. const ( ctxTimeout = "[error] timeout" // Context key marking timeout errors. ctxRetry = "[error] retry" // Context key marking retryable errors. contextSize = 4 // Initial size of fixed-size context array for small contexts. bufferSize = 256 // Initial buffer size for JSON marshaling. warmUpSize = 100 // Number of errors to pre-warm the pool for efficiency. stackDepth = 32 // Maximum stack trace depth to prevent excessive memory use. DefaultCode = 500 // Default HTTP status code for errors if not specified. ) // spaceRe is a precompiled regex for normalizing whitespace in error messages. var spaceRe = regexp.MustCompile(`\s+`) // jsonBufferPool manages reusable buffers for JSON marshaling to reduce allocations. var ( jsonBufferPool = sync.Pool{ New: func() interface{} { return bytes.NewBuffer(make([]byte, 0, bufferSize)) }, } ) // ErrorCategory is a string type for categorizing errors (e.g., "network", "validation"). 
type ErrorCategory string

// ErrorOpts provides options for customizing error creation.
type ErrorOpts struct {
	SkipStack int // Number of stack frames to skip when capturing the stack trace.
}

// Config defines the global configuration for the errors package, controlling
// stack depth, context size, pooling, and frame filtering. It is applied with
// Configure. A zero StackDepth or ContextSize keeps the current value. The
// boolean fields are always copied as given, so a Config that leaves
// FilterInternal false turns frame filtering off even though init enables it.
type Config struct {
	StackDepth     int  // Maximum stack trace depth; 0 keeps the current value (32 unless previously changed).
	ContextSize    int  // Initial context map size; 0 keeps the current value (4 unless previously changed).
	DisablePooling bool // If true, disables object pooling for errors.
	FilterInternal bool // If true, filters internal package frames from stack traces.
	AutoFree       bool // If true, automatically returns errors to pool when GC collects them.
}

// cachedConfig holds the current configuration, updated only by Configure().
// Protected by configMu for thread-safety.
type cachedConfig struct {
	stackDepth     int
	contextSize    int
	disablePooling bool
	filterInternal bool
	autoFree       bool
}

var (
	// currentConfig stores the active configuration, read frequently and updated rarely.
	currentConfig cachedConfig

	// configMu protects updates to currentConfig for thread-safety.
	configMu sync.RWMutex

	// errorPool manages reusable Error instances to reduce allocations.
	errorPool = NewErrorPool()

	// stackPool manages reusable stack trace slices for efficiency.
	// New returns slices whose length (not only capacity) equals the configured
	// stack depth.
	// NOTE(review): New reads currentConfig.stackDepth without holding configMu,
	// so a concurrent Configure would be a data race. Confirm callers only
	// configure at startup.
	stackPool = sync.Pool{
		New: func() interface{} {
			return make([]uintptr, currentConfig.stackDepth)
		},
	}

	// emptyError is a pre-allocated empty error for lightweight reuse.
	emptyError = &Error{
		smallContext: [contextSize]contextItem{},
		msg:          "",
		name:         "",
		template:     "",
		cause:        nil,
	}
)

// contextItem holds a single key-value pair in the smallContext array.
type contextItem struct {
	key   string
	value interface{}
}

// init sets up the package with default configuration and pre-warms the error pool.
func init() { currentConfig = cachedConfig{ stackDepth: stackDepth, contextSize: contextSize, disablePooling: false, filterInternal: true, autoFree: false, // opt-in; explicit Free() is the safe default } WarmPool(warmUpSize) // Pre-allocate errors for performance. } // Configure updates the global configuration for the errors package. // It is thread-safe and should be called early to avoid race conditions. // Changes apply to all subsequent error operations. // Example: // // errors.Configure(errors.Config{StackDepth: 16, DisablePooling: true}) func Configure(cfg Config) { configMu.Lock() defer configMu.Unlock() if cfg.StackDepth != 0 { currentConfig.stackDepth = cfg.StackDepth } if cfg.ContextSize != 0 { currentConfig.contextSize = cfg.ContextSize } currentConfig.disablePooling = cfg.DisablePooling currentConfig.filterInternal = cfg.FilterInternal currentConfig.autoFree = cfg.AutoFree } // WarmPool pre-populates the error pool with count instances. // Improves performance by reducing initial allocations. // No-op if pooling is disabled. // Example: // // errors.WarmPool(1000) func WarmPool(count int) { if currentConfig.disablePooling { return } for i := 0; i < count; i++ { e := &Error{ smallContext: [contextSize]contextItem{}, stack: nil, } errorPool.Put(e) stackPool.Put(make([]uintptr, 0, currentConfig.stackDepth)) } } // WarmStackPool pre-populates the stack pool with count slices. // Improves performance for stack-intensive operations. // No-op if pooling is disabled. // Example: // // errors.WarmStackPool(500) func WarmStackPool(count int) { if currentConfig.disablePooling { return } for i := 0; i < count; i++ { stackPool.Put(make([]uintptr, 0, currentConfig.stackDepth)) } } // FmtErrorCheck safely formats a string using fmt.Sprintf, catching panics. // Returns the formatted string and any error encountered. // Internal use by Newf to validate format strings. 
// Example: // // result, err := FmtErrorCheck("value: %s", "test") func FmtErrorCheck(format string, args ...interface{}) (result string, err error) { defer func() { if r := recover(); r != nil { if e, ok := r.(error); ok { err = e } else { err = fmt.Errorf("panic during formatting: %v", r) } } }() result = fmt.Sprintf(format, args...) return result, nil } golang-github-olekukonko-errors-1.3.0/chain.go000066400000000000000000000506371517267734700214020ustar00rootroot00000000000000package errors import ( "context" "fmt" "log/slog" // Standard structured logging package "reflect" "strings" "sync" "time" ) // Chain executes functions sequentially with enhanced error handling. // Logging is optional and configured via a slog.Handler. type Chain struct { steps []chainStep // List of steps to execute errors []error // Accumulated errors during execution config chainConfig // Chain-wide configuration lastStep *chainStep // Pointer to the last added step for configuration logHandler slog.Handler // Optional logging handler (nil means no logging) cancel context.CancelFunc // Function to cancel the context runCtx context.Context // Active context for Run/RunAll; shared with StepCtx closures configMu sync.RWMutex // Protects chainConfig against concurrent Timeout() calls } // chainStep represents a single step in the chain. type chainStep struct { execute func() error // Function to execute for this step optional bool // If true, errors don't stop the chain config stepConfig // Step-specific configuration } // chainConfig holds chain-wide settings. type chainConfig struct { timeout time.Duration // Maximum duration for the entire chain maxErrors int // Maximum number of errors before stopping (-1 for unlimited) autoWrap bool // Whether to automatically wrap errors with additional context } // stepConfig holds configuration for an individual step. 
type stepConfig struct { context map[string]interface{} // Arbitrary key-value pairs for context category ErrorCategory // Category for error classification code int // Numeric error code retry *Retry // Retry policy for the step logOnFail bool // Whether to log errors automatically metricsLabel string // Label for metrics (not used in this code) logAttrs []slog.Attr // Additional attributes for logging } // ChainOption defines a function that configures a Chain. type ChainOption func(*Chain) // NewChain creates a new Chain with the given options. // Logging is disabled by default (logHandler is nil). func NewChain(opts ...ChainOption) *Chain { c := &Chain{ config: chainConfig{ autoWrap: true, // Enable error wrapping by default maxErrors: -1, // No limit on errors by default }, // logHandler is nil, meaning no logging unless explicitly configured } // Apply each configuration option for _, opt := range opts { opt(c) } return c } // ChainWithLogHandler sets a custom slog.Handler for logging. // If handler is nil, logging is effectively disabled. func ChainWithLogHandler(handler slog.Handler) ChainOption { return func(c *Chain) { c.logHandler = handler } } // ChainWithTimeout sets a timeout for the entire chain. func ChainWithTimeout(d time.Duration) ChainOption { return func(c *Chain) { c.config.timeout = d } } // ChainWithMaxErrors sets the maximum number of errors allowed. // A value <= 0 means no limit. func ChainWithMaxErrors(max int) ChainOption { return func(c *Chain) { if max <= 0 { c.config.maxErrors = -1 // No limit } else { c.config.maxErrors = max } } } // ChainWithAutoWrap enables or disables automatic error wrapping. func ChainWithAutoWrap(auto bool) ChainOption { return func(c *Chain) { c.config.autoWrap = auto } } // Step adds a new step to the chain with the provided function. // The function must return an error or nil. 
func (c *Chain) Step(fn func() error) *Chain { if fn == nil { // Panic to enforce valid input panic("Chain.Step: provided function cannot be nil") } // Create a new step with default configuration step := chainStep{execute: fn, config: stepConfig{}} c.steps = append(c.steps, step) // Update lastStep to point to the newly added step c.lastStep = &c.steps[len(c.steps)-1] return c } // StepCtx adds a context-aware step to the chain. The provided function // receives the chain's context (which carries any chain-level deadline/timeout), // so cancellation and timeouts propagate correctly into blocking operations such // as HTTP requests, database queries, or gRPC calls. // // StepCtx is the context-safe alternative to Step; existing Step calls are // unchanged and fully compatible. // // Example: // // chain := NewChain(ChainWithTimeout(5 * time.Second)). // StepCtx(func(ctx context.Context) error { // req, _ := http.NewRequestWithContext(ctx, "GET", url, nil) // _, err := http.DefaultClient.Do(req) // return err // }) func (c *Chain) StepCtx(fn func(ctx context.Context) error) *Chain { if fn == nil { panic("Chain.StepCtx: provided function cannot be nil") } // Wrap fn so it satisfies the internal func() error signature used by // executeStep. The context is captured at execution time via getContextAndCancel. // Close over c.runCtx — set by Run/RunAll to the chain-level context. // This ensures StepCtx steps share the same deadline as the chain, // rather than each getting a fresh full-duration context. wrapped := func() error { ctx := c.runCtx if ctx == nil { ctx = context.Background() } return fn(ctx) } step := chainStep{execute: wrapped, config: stepConfig{}} c.steps = append(c.steps, step) c.lastStep = &c.steps[len(c.steps)-1] return c } // Call adds a step by wrapping a function with arguments. // It uses reflection to validate and invoke the function. 
func (c *Chain) Call(fn interface{}, args ...interface{}) *Chain { // Wrap the function and arguments into an executable step wrappedFn, err := c.wrapCallable(fn, args...) if err != nil { // Panic on setup errors to catch them early panic(fmt.Sprintf("Chain.Call setup error: %v", err)) } // Add the wrapped function as a step step := chainStep{execute: wrappedFn, config: stepConfig{}} c.steps = append(c.steps, step) c.lastStep = &c.steps[len(c.steps)-1] return c } // Optional marks the last step as optional. // Optional steps don't stop the chain on error. func (c *Chain) Optional() *Chain { if c.lastStep == nil { // Panic if no step exists to mark as optional panic("Chain.Optional: must call Step() or Call() before Optional()") } c.lastStep.optional = true return c } // WithLog adds logging attributes to the last step. func (c *Chain) WithLog(attrs ...slog.Attr) *Chain { if c.lastStep == nil { // Panic if no step exists to configure panic("Chain.WithLog: must call Step() or Call() before WithLog()") } // Append attributes to the step's logging configuration c.lastStep.config.logAttrs = append(c.lastStep.config.logAttrs, attrs...) return c } // Timeout sets a timeout for the entire chain. // Thread-safe: protected by configMu. func (c *Chain) Timeout(d time.Duration) *Chain { c.configMu.Lock() c.config.timeout = d c.configMu.Unlock() return c } // MaxErrors sets the maximum number of errors allowed. func (c *Chain) MaxErrors(max int) *Chain { if max <= 0 { c.config.maxErrors = -1 // No limit } else { c.config.maxErrors = max } return c } // With adds a key-value pair to the last step's context. 
func (c *Chain) With(key string, value interface{}) *Chain { if c.lastStep == nil { // Panic if no step exists to configure panic("Chain.With: must call Step() or Call() before With()") } // Initialize context map if nil if c.lastStep.config.context == nil { c.lastStep.config.context = make(map[string]interface{}) } // Add the key-value pair c.lastStep.config.context[key] = value return c } // Tag sets an error category for the last step. func (c *Chain) Tag(category ErrorCategory) *Chain { if c.lastStep == nil { // Panic if no step exists to configure panic("Chain.Tag: must call Step() or Call() before Tag()") } c.lastStep.config.category = category return c } // Code sets a numeric error code for the last step. func (c *Chain) Code(code int) *Chain { if c.lastStep == nil { // Panic if no step exists to configure panic("Chain.Code: must call Step() or Call() before Code()") } c.lastStep.config.code = code return c } // Retry configures retry behavior for the last step. // Retry configures retry behavior for the last step. func (c *Chain) Retry(maxAttempts int, delay time.Duration, opts ...RetryOption) *Chain { if c.lastStep == nil { panic("Chain.Retry: must call Step() or Call() before Retry()") } if maxAttempts < 1 { maxAttempts = 1 } // Define default retry options retryOpts := []RetryOption{ WithMaxAttempts(maxAttempts), WithDelay(delay), WithRetryIf(func(err error) bool { return IsRetryable(err) }), } // Add logging for retry attempts if a handler is configured if c.logHandler != nil { step := c.lastStep retryOpts = append(retryOpts, WithOnRetry(func(attempt int, err error) { // Prepare logging attributes logAttrs := []slog.Attr{ slog.Int("attempt", attempt), slog.Int("max_attempts", maxAttempts), } // Enhance the error with step context enhancedErr := c.enhanceError(err, step) // Log the retry attempt c.logError(enhancedErr, fmt.Sprintf("Retrying step (attempt %d/%d)", attempt, maxAttempts), step.config, logAttrs...) 
})) } // Append any additional retry options retryOpts = append(retryOpts, opts...) // Create and assign the retry configuration c.lastStep.config.retry = NewRetry(retryOpts...) return c } // LogOnFail enables automatic logging of errors for the last step. func (c *Chain) LogOnFail() *Chain { if c.lastStep == nil { // Panic if no step exists to configure panic("Chain.LogOnFail: must call Step() or Call() before LogOnFail()") } c.lastStep.config.logOnFail = true return c } // Run executes the chain, stopping on the first non-optional error. // It returns the first error encountered or nil if all steps succeed. func (c *Chain) Run() error { // Create a context with timeout or cancellation ctx, cancel := c.getContextAndCancel() defer cancel() c.cancel = cancel c.runCtx = ctx // share deadline with StepCtx closures // Clear any previous errors c.errors = c.errors[:0] // Execute each step in sequence for i := range c.steps { step := &c.steps[i] // Check if the context has been canceled select { case <-ctx.Done(): err := ctx.Err() // Enhance the error with step context enhancedErr := c.enhanceError(err, step) c.errors = append(c.errors, enhancedErr) // Log the context error c.logError(enhancedErr, "Chain stopped due to context error before step", step.config) return enhancedErr default: } // Execute the step err := c.executeStep(ctx, step) if err != nil { // Enhance the error with step context enhancedErr := c.enhanceError(err, step) c.errors = append(c.errors, enhancedErr) // Log the error if required if step.config.logOnFail || !step.optional { logMsg := "Chain stopped due to error in step" if step.optional { logMsg = "Optional step failed" } c.logError(enhancedErr, logMsg, step.config) } // Stop execution if the step is not optional if !step.optional { return enhancedErr } } } // Return nil if all steps completed successfully return nil } // RunAll executes all steps, collecting errors without stopping. 
// It returns a MultiError containing all errors or nil if none occurred. func (c *Chain) RunAll() error { ctx, cancel := c.getContextAndCancel() defer cancel() c.cancel = cancel c.runCtx = ctx // share deadline with StepCtx closures c.errors = c.errors[:0] multi := NewMultiError() for i := range c.steps { step := &c.steps[i] select { case <-ctx.Done(): err := ctx.Err() enhancedErr := c.enhanceError(err, step) c.errors = append(c.errors, enhancedErr) multi.Add(enhancedErr) c.logError(enhancedErr, "Chain stopped due to context error before step (RunAll)", step.config) goto endRunAll default: } err := c.executeStep(ctx, step) if err != nil { enhancedErr := c.enhanceError(err, step) c.errors = append(c.errors, enhancedErr) multi.Add(enhancedErr) if step.config.logOnFail && c.logHandler != nil { c.logError(enhancedErr, "Step failed during RunAll", step.config) } if c.config.maxErrors > 0 && multi.Count() >= c.config.maxErrors { if c.logHandler != nil { // Create a logger to log the max errors condition logger := slog.New(c.logHandler) logger.LogAttrs( context.Background(), slog.LevelError, fmt.Sprintf("Stopping RunAll after reaching max errors (%d)", c.config.maxErrors), slog.Int("max_errors", c.config.maxErrors), ) } goto endRunAll } } } endRunAll: return multi.Single() } // Errors returns a copy of the collected errors. func (c *Chain) Errors() []error { if len(c.errors) == 0 { return nil } // Create a copy to prevent external modification errs := make([]error, len(c.errors)) copy(errs, c.errors) return errs } // Len returns the number of steps in the chain. func (c *Chain) Len() int { return len(c.steps) } // HasErrors checks if any errors were collected. func (c *Chain) HasErrors() bool { return len(c.errors) > 0 } // LastError returns the most recent error or nil if none exist. func (c *Chain) LastError() error { if len(c.errors) > 0 { return c.errors[len(c.errors)-1] } return nil } // Reset clears the chain's steps, errors, and context. 
func (c *Chain) Reset() { if c.cancel != nil { // Cancel any active context c.cancel() c.cancel = nil } // Clear steps and errors c.steps = c.steps[:0] c.errors = c.errors[:0] c.lastStep = nil } // Unwrap returns the collected errors (alias for Errors). func (c *Chain) Unwrap() []error { return c.errors } // getContextAndCancel creates a context based on the chain's timeout. // It returns a context and its cancellation function. func (c *Chain) getContextAndCancel() (context.Context, context.CancelFunc) { parentCtx := context.Background() c.configMu.RLock() timeout := c.config.timeout c.configMu.RUnlock() if timeout > 0 { return context.WithTimeout(parentCtx, timeout) } return context.WithCancel(parentCtx) } // logError logs an error with step-specific context and attributes. // It only logs if a handler is configured and the error is non-nil. func (c *Chain) logError(err error, msg string, config stepConfig, additionalAttrs ...slog.Attr) { // Skip logging if no handler is set or error is nil if c == nil || c.logHandler == nil || err == nil { return } // Create a logger on demand using the configured handler logger := slog.New(c.logHandler) // Initialize attributes with error and timestamp allAttrs := make([]slog.Attr, 0, 5+len(config.logAttrs)+len(additionalAttrs)) allAttrs = append(allAttrs, slog.Any("error", err)) allAttrs = append(allAttrs, slog.Time("timestamp", time.Now())) // Add step-specific metadata if config.category != "" { allAttrs = append(allAttrs, slog.String("category", string(config.category))) } if config.code != 0 { allAttrs = append(allAttrs, slog.Int("code", config.code)) } for k, v := range config.context { allAttrs = append(allAttrs, slog.Any(k, v)) } allAttrs = append(allAttrs, config.logAttrs...) allAttrs = append(allAttrs, additionalAttrs...) 
// Add stack trace and error name if the error is of type *Error if e, ok := err.(*Error); ok { if stack := e.Stack(); len(stack) > 0 { // Format stack trace, truncating if too long stackStr := "\n\t" + strings.Join(stack, "\n\t") if len(stackStr) > 1000 { stackStr = stackStr[:1000] + "..." } allAttrs = append(allAttrs, slog.String("stacktrace", stackStr)) } if name := e.Name(); name != "" { allAttrs = append(allAttrs, slog.String("error_name", name)) } } // Log the error at ERROR level with all attributes // Use a defer to catch any panics during logging defer func() { if r := recover(); r != nil { // Print to stdout to avoid infinite recursion fmt.Printf("ERROR: Recovered from panic during logging: %v\nAttributes: %v\n", r, allAttrs) } }() logger.LogAttrs(context.Background(), slog.LevelError, msg, allAttrs...) } // wrapCallable wraps a function and its arguments into an executable step. // It uses reflection to validate the function and arguments. func (c *Chain) wrapCallable(fn interface{}, args ...interface{}) (func() error, error) { val := reflect.ValueOf(fn) typ := val.Type() // Ensure the provided value is a function if typ.Kind() != reflect.Func { return nil, fmt.Errorf("provided 'fn' is not a function (got %T)", fn) } // Check if the number of arguments matches the function's signature if typ.NumIn() != len(args) { return nil, fmt.Errorf("function expects %d arguments, but %d were provided", typ.NumIn(), len(args)) } // Prepare argument values argVals := make([]reflect.Value, len(args)) errorType := reflect.TypeOf((*error)(nil)).Elem() for i, arg := range args { expectedType := typ.In(i) var providedVal reflect.Value if arg != nil { providedVal = reflect.ValueOf(arg) // Check if the argument type is assignable to the expected type if !providedVal.Type().AssignableTo(expectedType) { // Special case for error interfaces if expectedType.Kind() == reflect.Interface && expectedType.Implements(errorType) && providedVal.Type().Implements(errorType) { // Allow 
error interface } else { return nil, fmt.Errorf("argument %d type mismatch: expected %s, got %s", i, expectedType, providedVal.Type()) } } } else { // Handle nil arguments for nullable types switch expectedType.Kind() { case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer, reflect.Slice: providedVal = reflect.Zero(expectedType) default: return nil, fmt.Errorf("argument %d is nil, but expected non-nillable type %s", i, expectedType) } } argVals[i] = providedVal } // Validate the function's return type if typ.NumOut() > 1 || (typ.NumOut() == 1 && !typ.Out(0).Implements(errorType)) { return nil, fmt.Errorf("function must return either no values or a single error (got %d return values)", typ.NumOut()) } // Return a wrapped function that calls the original with the provided arguments return func() error { results := val.Call(argVals) if len(results) == 1 && results[0].Interface() != nil { return results[0].Interface().(error) } return nil }, nil } // executeStep runs a single step, applying retries if configured. // This version is synchronous and avoids the bugs caused by the previous goroutine-based implementation. func (c *Chain) executeStep(ctx context.Context, step *chainStep) error { // First, check if the context has already been canceled before starting the step. // This allows the chain to fail fast. select { case <-ctx.Done(): return ctx.Err() default: // Context is still active, proceed. } // If the step has retry logic configured... if step.config.retry != nil { // Create a new retry instance that is aware of the chain's context. // The retry executor will be responsible for checking ctx.Done() between attempts. retryExecutor := step.config.retry.Transform(WithContext(ctx)) // Execute the step's function directly. The retry mechanism will manage the loop, // delays, and context cancellation checks. We pass step.execute without any // extra goroutine wrappers. 
return retryExecutor.Execute(step.execute) } // For a simple, non-retrying step, execute the function directly and synchronously // in the current goroutine. This is the simplest, fastest, and most correct approach. // It ensures that database connections are used and returned to the pool sequentially, // preventing the deadlock issue. return step.execute() } // enhanceError wraps an error with additional context from the step. func (c *Chain) enhanceError(err error, step *chainStep) error { if err == nil || !c.config.autoWrap { // Return the error unchanged if nil or autoWrap is disabled return err } // Initialize the base error var baseError *Error if e, ok := err.(*Error); ok { // Copy existing *Error to preserve its properties baseError = e.Copy() } else { // Create a new *Error wrapping the original baseError = New(err.Error()).Wrap(err).WithStack() } if step != nil { // Add step-specific context to the error if step.config.category != "" && baseError.Category() == "" { baseError.WithCategory(step.config.category) } if step.config.code != 0 && baseError.Code() == 0 { baseError.WithCode(step.config.code) } for k, v := range step.config.context { baseError.With(k, v) } for _, attr := range step.config.logAttrs { baseError.With(attr.Key, attr.Value.Any()) } if step.config.retry != nil && !baseError.HasContextKey(ctxRetry) { // Mark the error as retryable if retries are configured baseError.WithRetryable() } } return baseError } golang-github-olekukonko-errors-1.3.0/chain_test.go000066400000000000000000000603051517267734700224320ustar00rootroot00000000000000package errors import ( "context" stderrs "errors" // Alias for standard errors package to avoid conflicts "fmt" "log/slog" // Structured logging package for testing log output "strings" "testing" // Standard Go testing package "time" ) // memoryLogHandler is a custom slog handler that captures log output in memory. // It’s used to verify logging behavior in tests without writing to external systems. 
type memoryLogHandler struct {
	attrs []slog.Attr     // handler-level attributes added via WithAttrs/WithGroup
	mu    strings.Builder // rendered log output (a buffer despite the name, not a mutex)
}

// NewMemoryLogHandler returns an empty handler with no attributes and an
// empty output buffer.
func NewMemoryLogHandler() *memoryLogHandler {
	return &memoryLogHandler{}
}

// Enabled accepts every level so tests can see all records.
func (h *memoryLogHandler) Enabled(context.Context, slog.Level) bool {
	return true
}

// Handle appends one line to the buffer in the form
//
//	level=<LEVEL> msg="<message>" key=value ...
//
// Handler-level attributes come first, then the record's attributes. A
// non-empty group renders as ` key={ a=1  b=2}`; an empty group is omitted.
func (h *memoryLogHandler) Handle(ctx context.Context, r slog.Record) error {
	fmt.Fprintf(&h.mu, "level=%s msg=%q", r.Level, r.Message)

	var emit func(slog.Attr)
	emit = func(a slog.Attr) {
		if a.Value.Kind() != slog.KindGroup {
			fmt.Fprintf(&h.mu, " %s=%v", a.Key, a.Value.Any())
			return
		}
		members := a.Value.Group()
		if len(members) == 0 {
			return
		}
		fmt.Fprintf(&h.mu, " %s={", a.Key)
		for i, m := range members {
			if i > 0 {
				h.mu.WriteString(" ")
			}
			emit(m)
		}
		h.mu.WriteString("}")
	}

	for _, a := range h.attrs {
		emit(a)
	}
	r.Attrs(func(a slog.Attr) bool {
		emit(a)
		return true
	})

	h.mu.WriteByte('\n')
	return nil
}

// WithAttrs creates a new handler with additional attributes.
// It preserves existing attributes and appends new ones.
func (h *memoryLogHandler) WithAttrs(attrs []slog.Attr) slog.Handler { newHandler := NewMemoryLogHandler() // Copy existing attributes to avoid modifying the original newHandler.attrs = append(make([]slog.Attr, 0, len(h.attrs)+len(attrs)), h.attrs...) newHandler.attrs = append(newHandler.attrs, attrs...) return newHandler } // WithGroup creates a new handler with a group attribute. // It adds a group to the attribute list for nested logging. func (h *memoryLogHandler) WithGroup(name string) slog.Handler { newHandler := NewMemoryLogHandler() // Copy existing attributes and add a new group newHandler.attrs = append(make([]slog.Attr, 0, len(h.attrs)+1), h.attrs...) newHandler.attrs = append(newHandler.attrs, slog.Group(name)) return newHandler } // GetOutput returns the accumulated log output as a string. func (h *memoryLogHandler) GetOutput() string { return h.mu.String() } // Reset clears the handler’s buffer and attributes. // It prepares the handler for a new test. func (h *memoryLogHandler) Reset() { h.mu.Reset() h.attrs = nil } // Define test errors for consistent use across tests. // These simulate various error scenarios. var ( errTest = stderrs.New("test error") // Generic test error errTemporary = stderrs.New("temporary error") // Error for retry scenarios errPermanent = stderrs.New("permanent error") // Non-retryable error errOptional = stderrs.New("optional error") // Error for optional steps errPayment = stderrs.New("payment failed") // Error for payment scenarios errStep1 = stderrs.New("error1") // First step error errStep2 = stderrs.New("error2") // Second step error errStep3 = stderrs.New("error3") // Third step error ) // TestChainExampleFromDocs tests the example usage from documentation. // It verifies retry behavior and logging for a payment processing function. 
func TestChainExampleFromDocs(t *testing.T) { // Initialize a memory log handler to capture logs logHandler := NewMemoryLogHandler() attempts := 0 // Track number of function executions // Define a payment processing function that fails twice before succeeding processPayment := func() error { attempts++ if attempts < 3 { // Return a retryable error for the first two attempts return New("payment failed").WithRetryable() } return nil // Succeed on the third attempt } // Create a chain with the log handler and a single step with retries c := NewChain(ChainWithLogHandler(logHandler)). Step(processPayment). Retry(3, 5*time.Millisecond) // Allow 3 total attempts with 5ms delay // Run the chain err := c.Run() // Verify no error was returned (should succeed after retries) if err != nil { t.Errorf("Expected nil error after retries, got %v", err) } // Check that exactly 3 attempts were made (initial + 2 retries) if attempts != 3 { t.Errorf("Expected 3 total attempts (initial + 2 retries), got %d", attempts) } // Get the captured log output logOutput := logHandler.GetOutput() // Log the output for debugging purposes t.Logf("Captured Log Output:\n%s", logOutput) // Verify retry log messages // Check for the first retry attempt log if !strings.Contains(logOutput, "Retrying step (attempt 1/3)") { t.Error("Missing first retry log message (expected '...attempt 1/3...')") } // Check for the second retry attempt log if !strings.Contains(logOutput, "Retrying step (attempt 2/3)") { t.Error("Missing second retry log message (expected '...attempt 2/3...')") } // Verify retry attributes in the logs if !strings.Contains(logOutput, "attempt=1 max_attempts=3") { t.Errorf("Log for first retry missing correct attributes (expected 'attempt=1 max_attempts=3')") } if !strings.Contains(logOutput, "attempt=2 max_attempts=3") { t.Errorf("Log for second retry missing correct attributes (expected 'attempt=2 max_attempts=3')") } } // TestChainBasicOperations tests basic chain functionality. 
// It covers empty chains, successful steps, failing steps, and optional steps. func TestChainBasicOperations(t *testing.T) { // Subtest: EmptyChain // Verifies that an empty chain runs without errors and has no steps or errors. t.Run("EmptyChain", func(t *testing.T) { c := NewChain() if err := c.Run(); err != nil { t.Errorf("Empty chain should not return error, got %v", err) } if c.Len() != 0 { t.Errorf("Empty chain should have length 0, got %d", c.Len()) } if c.HasErrors() { t.Error("Empty chain should not have errors") } }) // Subtest: SingleSuccessfulStep // Verifies that a single successful step executes and returns no error. t.Run("SingleSuccessfulStep", func(t *testing.T) { var executed bool c := NewChain().Step(func() error { executed = true; return nil }) if err := c.Run(); err != nil { t.Errorf("Single successful step should not return error, got %v", err) } if !executed { t.Error("Successful step was not executed") } }) // Subtest: SingleFailingStep // Verifies that a single failing step returns an enhanced error and records it. t.Run("SingleFailingStep", func(t *testing.T) { var executed bool c := NewChain().Step(func() error { executed = true; return errTest }) err := c.Run() if !executed { t.Error("Failing step was not executed") } // Check that the error is of the enhanced *Error type enhancedErr, ok := err.(*Error) if !ok { t.Fatalf("Expected error to be *errors.Error, got %T", err) } // Verify the error wraps the original errTest if !stderrs.Is(enhancedErr, errTest) { t.Errorf("Expected wrapped error to contain '%v', got '%v'", errTest, enhancedErr) } // Ensure the chain recorded the error if !c.HasErrors() { t.Error("Chain should have errors after failure") } }) // Subtest: MultipleStepsWithFailure // Verifies that execution stops after a non-optional failure and only prior steps run. t.Run("MultipleStepsWithFailure", func(t *testing.T) { var step1, step3 bool c := NewChain(). Step(func() error { step1 = true; return nil }). 
Step(func() error { return errTest }). Step(func() error { step3 = true; return nil }) err := c.Run() if !step1 { t.Error("Step 1 should have executed") } if step3 { t.Error("Step 3 should not have executed after failure") } // Verify the error is enhanced enhancedErr, ok := err.(*Error) if !ok { t.Fatalf("Expected error to be *errors.Error, got %T", err) } // Verify the error wraps errTest if !stderrs.Is(enhancedErr, errTest) { t.Errorf("Expected wrapped error '%v', got '%v'", errTest, enhancedErr) } }) // Subtest: OptionalStepFailure // Verifies that an optional failing step doesn’t stop execution and Run() returns nil. t.Run("OptionalStepFailure", func(t *testing.T) { var step1, step3 bool c := NewChain(). Step(func() error { step1 = true; return nil }). Step(func() error { return errOptional }).Optional(). Step(func() error { step3 = true; return nil }) err := c.Run() if !step1 { t.Error("Step 1 should have executed") } if !step3 { t.Error("Step 3 should have executed after optional failure") } if err != nil { t.Errorf("Run() should return nil when only optional steps fail, got %v", err) } if !c.HasErrors() { t.Error("Chain should have errors even if only optional failed") } }) // Subtest: OptionalStepSuccess // Verifies that all steps, including optional successful ones, execute correctly. t.Run("OptionalStepSuccess", func(t *testing.T) { var step1, step2, step3 bool c := NewChain(). Step(func() error { step1 = true; return nil }). Step(func() error { step2 = true; return nil }).Optional(). Step(func() error { step3 = true; return nil }) if err := c.Run(); err != nil { t.Errorf("Unexpected error: %v", err) } if !step1 || !step2 || !step3 { t.Error("All steps should have executed") } }) } // TestChainErrorEnhancement tests error wrapping and metadata enhancement. // It verifies auto-wrapping, disabling wrapping, and adding metadata. 
func TestChainErrorEnhancement(t *testing.T) {
	// Subtest: AutoWrapStandardErrors
	// A plain fmt error returned by a step must come back from Run as an
	// *Error that still matches the original via errors.Is and has a stack.
	t.Run("AutoWrapStandardErrors", func(t *testing.T) {
		stdErr := fmt.Errorf("standard error %d", 123)
		c := NewChain().Step(func() error { return stdErr })
		err := c.Run()
		// The chain's default auto-wrapping should produce an *Error.
		enhancedErr, ok := err.(*Error)
		if !ok {
			t.Fatalf("Expected error to be *errors.Error, got %T", err)
		}
		// Wrapping must keep the original reachable through errors.Is.
		if !stderrs.Is(enhancedErr, stdErr) {
			t.Errorf("Wrapped error should contain '%v', got '%v'", stdErr, enhancedErr)
		}
		// Non-*Error values are wrapped together with a stack trace.
		if len(enhancedErr.Stack()) == 0 {
			t.Error("Enhanced error should have a stack trace")
		}
	})
	// Subtest: DisableAutoWrap
	// With ChainWithAutoWrap(false), Run returns the step's error untouched.
	t.Run("DisableAutoWrap", func(t *testing.T) {
		stdErr := fmt.Errorf("standard error %d", 456)
		c := NewChain(ChainWithAutoWrap(false)).Step(func() error { return stdErr })
		err := c.Run()
		// No *Error wrapper may be introduced.
		if _, ok := err.(*Error); ok {
			t.Fatalf("Error should not be wrapped when ChainWithAutoWrap(false), got *errors.Error")
		}
		// The raw error must still be the one the step returned.
		if !stderrs.Is(err, stdErr) {
			t.Errorf("Expected raw error '%v', got '%v'", stdErr, err)
		}
	})
	// Subtest: ErrorMetadataViaEnhancement
	// Step metadata (With, Tag, Code, WithLog) must be copied onto the
	// enhanced error's context, category and code.
	t.Run("ErrorMetadataViaEnhancement", func(t *testing.T) {
		// Metadata attached to the failing step.
		category := ErrorCategory("database")
		code := 503
		key := "query_id"
		value := "xyz789"
		logKey := "trace_id"
		logValue := "trace-abc"
		// Chain with one failing step carrying all the metadata above.
		c := NewChain().
			Step(func() error { return errTest }).
			With(key, value).Tag(category).Code(code).
			WithLog(slog.String(logKey, logValue))
		err := c.Run()
		if err == nil {
			t.Fatal("Expected an error")
		}
		// The failure must come back as an *Error so its metadata can be read.
		enhancedErr, ok := err.(*Error)
		if !ok {
			t.Fatalf("Expected error to be *errors.Error, got %T", err)
		}
		// Both the With() pair and the WithLog() attribute end up in Context().
		contextMap := enhancedErr.Context()
		if val, ok := contextMap[key]; !ok || val != value {
			t.Errorf("Expected context['%s'] == %v, got %v", key, value, val)
		}
		if val, ok := contextMap[logKey]; !ok || val != logValue {
			t.Errorf("Expected context['%s'] == %v, got %v", logKey, logValue, val)
		}
		// Tag() sets the category.
		if enhancedErr.Category() != string(category) {
			t.Errorf("Expected category %q, got %q", category, enhancedErr.Category())
		}
		// Code() sets the numeric code.
		if enhancedErr.Code() != code {
			t.Errorf("Expected code %d, got %d", code, enhancedErr.Code())
		}
	})
}

// TestChainRetryLogic tests retry behavior for different scenarios.
// It verifies successful retries, failed retries, and context timeout interactions.
func TestChainRetryLogic(t *testing.T) {
	// Local errors that shadow the package-level sentinels of the same names.
	errTemporary := New("temporary error").WithRetryable() // Retryable error
	errPermanent := stderrs.New("permanent error")         // Non-retryable error
	// Subtest: RetrySuccessful
	// A step that fails twice with a retryable error and then succeeds must
	// make Run return nil after exactly three calls.
	t.Run("RetrySuccessful", func(t *testing.T) {
		attempts := 0
		logHandler := NewMemoryLogHandler()
		c := NewChain(ChainWithLogHandler(logHandler)).
			Step(func() error {
				attempts++
				t.Logf("RetrySuccessful: Attempt %d", attempts)
				if attempts < 3 {
					return errTemporary // Fails on the first two attempts
				}
				return nil // Succeeds on the third attempt
			}).
			Retry(3, 1*time.Millisecond) // Allow 3 attempts in total
		err := c.Run()
		if err != nil {
			t.Errorf("Expected success after retries, got %v", err)
		}
		// The initial call plus two retries.
		if attempts != 3 {
			t.Errorf("Expected 3 attempts (initial + 2 retries), got %d", attempts)
		}
	})
	// Subtest: RetryFailure
	// A custom WithRetryIf that always returns true forces retries of a
	// non-retryable error. All attempts run, and the final error is enhanced.
	t.Run("RetryFailure", func(t *testing.T) {
		attempts := 0
		logHandler := NewMemoryLogHandler()
		c := NewChain(ChainWithLogHandler(logHandler)).
			Step(func() error {
				attempts++
				t.Logf("RetryFailure: Attempt %d", attempts)
				return errPermanent // Always fails
			}).
			// Force retries even for non-retryable errors
			Retry(3, 1*time.Millisecond, WithRetryIf(func(error) bool { return true }))
		err := c.Run()
		if err == nil {
			t.Error("Expected failure after retries")
		}
		// Every allowed attempt must have been made.
		if attempts != 3 {
			t.Errorf("Expected 3 attempts (initial + 2 retries), got %d", attempts)
		}
		enhancedErr, ok := err.(*Error)
		if !ok {
			t.Fatalf("Expected enhanced *Error, got %T", err)
		}
		// The enhanced error must still wrap the step's original error.
		if !stderrs.Is(enhancedErr, errPermanent) {
			t.Errorf("Expected enhanced error wrapping '%v', got '%v'", errPermanent, enhancedErr)
		}
	})
	// Subtest: RetryRespectsContext
	// Timing-sensitive: the step sleeps 25ms against a 10ms chain timeout, so
	// the deadline expires during the first attempt and no retry may start.
	t.Run("RetryRespectsContext", func(t *testing.T) {
		attempts := 0
		logHandler := NewMemoryLogHandler()
		c := NewChain(ChainWithLogHandler(logHandler), ChainWithTimeout(10*time.Millisecond)).
			Step(func() error {
				attempts++
				t.Logf("RetryRespectsContext: Attempt %d starting sleep", attempts)
				// Sleep past the chain timeout so the context expires mid-attempt.
				time.Sleep(25 * time.Millisecond)
				// NOTE(review): executeStep runs the step synchronously, so this
				// line presumably does run; "should not happen" looks historical.
				t.Logf("RetryRespectsContext: Attempt %d finished sleep (should not happen)", attempts)
				return errPermanent
			}).
			// Force retries so the timeout, not the retry predicate, ends the run
			Retry(2, 5*time.Millisecond, WithRetryIf(func(error) bool { return true }))
		err := c.Run()
		if err == nil {
			t.Fatal("Expected an error due to timeout")
		}
		// The returned error must wrap the context deadline error.
		if !stderrs.Is(err, context.DeadlineExceeded) {
			t.Errorf("Expected error wrapping context.DeadlineExceeded, got %v (type %T)", err, err)
		}
		// Only the first attempt may have run before the deadline.
		if attempts != 1 {
			t.Errorf("Expected exactly 1 attempt before timeout, got %d", attempts)
		}
	})
}

// TestChainContext tests context-related behavior, specifically timeouts.
// It verifies that timeouts stop execution as expected.
func TestChainContext(t *testing.T) {
	// Subtest: TimeoutStopsExecution
	// Timing-sensitive: step 1 ignores the context and sleeps 50ms against a
	// 20ms timeout. Run's per-step context check must then block step 2.
	t.Run("TimeoutStopsExecution", func(t *testing.T) {
		var step1Started, step2Executed bool
		c := NewChain(ChainWithTimeout(20 * time.Millisecond)).
			Step(func() error {
				step1Started = true
				// Sleep past the chain timeout
				time.Sleep(50 * time.Millisecond)
				return nil
			}).
			Step(func() error {
				step2Executed = true
				return nil
			})
		err := c.Run()
		if err == nil {
			t.Fatal("Expected an error")
		}
		// The error is the (enhanced) context deadline error.
		if !stderrs.Is(err, context.DeadlineExceeded) {
			t.Errorf("Expected context.DeadlineExceeded wrapped, got %v", err)
		}
		if !step1Started {
			t.Error("Step 1 should have started execution")
		}
		if step2Executed {
			t.Error("Step 2 should not have executed after timeout")
		}
	})
}

// TestChainLogging tests logging behavior for failing steps.
// It verifies log messages and attributes for optional and non-optional steps.
func TestChainLogging(t *testing.T) {
	// One handler is shared by the subtests, which call Reset first.
	logHandler := NewMemoryLogHandler()
	// Subtest: LogOnFail_NonOptional
	// A failing non-optional step is logged with its error text, category,
	// code and context, under the "Chain stopped" message.
	t.Run("LogOnFail_NonOptional", func(t *testing.T) {
		logHandler.Reset()
		category := ErrorCategory("test_cat")
		c := NewChain(ChainWithLogHandler(logHandler)).
			Step(func() error { return errTest }).LogOnFail().
			With("key", "value").Tag(category).Code(500)
		err := c.Run()
		if err == nil {
			t.Fatal("Expected error")
		}
		logOutput := logHandler.GetOutput()
		// Error text (the check only looks for the substring "test error").
		if !strings.Contains(logOutput, "test error") {
			t.Errorf("Log missing 'error=test error' attribute. Got: %s", logOutput)
		}
		// Category from Tag().
		if !strings.Contains(logOutput, "category=test_cat") {
			t.Errorf("Log missing 'category=test_cat' attribute. Got: %s", logOutput)
		}
		// Code from Code().
		if !strings.Contains(logOutput, "code=500") {
			t.Errorf("Log missing 'code=500' attribute. Got: %s", logOutput)
		}
		// Context pair from With().
		if !strings.Contains(logOutput, "key=value") {
			t.Errorf("Log missing 'key=value' attribute. Got: %s", logOutput)
		}
		// Message used by Run for non-optional failures.
		if !strings.Contains(logOutput, "Chain stopped due to error in step") {
			t.Errorf("Log missing correct message. Got: %s", logOutput)
		}
	})
	// Subtest: LogOnFail_Optional
	// An optional failing step with LogOnFail is logged as "Optional step
	// failed", and Run still succeeds.
	t.Run("LogOnFail_Optional", func(t *testing.T) {
		logHandler.Reset()
		category := ErrorCategory("opt_cat")
		c := NewChain(ChainWithLogHandler(logHandler)).
			Step(func() error { return errOptional }).Optional().LogOnFail().
			With("optKey", "optValue").Tag(category)
		err := c.Run()
		if err != nil {
			t.Fatalf("Run should succeed when only optional fails, got: %v", err)
		}
		logOutput := logHandler.GetOutput()
		// Message used by Run for optional failures.
		if !strings.Contains(logOutput, "Optional step failed") {
			t.Errorf("Log should contain 'Optional step failed' message: %s", logOutput)
		}
		// The error attribute as rendered by memoryLogHandler.
		if !strings.Contains(logOutput, "error=optional error") {
			t.Errorf("Log should contain 'error=optional error': %s", logOutput)
		}
		// Category from Tag().
		if !strings.Contains(logOutput, "category=opt_cat") {
			t.Errorf("Log missing 'category=opt_cat': %s", logOutput)
		}
		// Context pair from With().
		if !strings.Contains(logOutput, "optKey=optValue") {
			t.Errorf("Log missing 'optKey=optValue': %s", logOutput)
		}
	})
	// Subtest: NoLogOnFail_Optional
	// Without LogOnFail, an optional failure produces no log output at all.
	t.Run("NoLogOnFail_Optional", func(t *testing.T) {
		logHandler.Reset()
		c := NewChain(ChainWithLogHandler(logHandler)).
			Step(func() error { return errOptional }).Optional()
		err := c.Run()
		if err != nil {
			t.Fatalf("Run should succeed when only optional fails, got: %v", err)
		}
		logOutput := logHandler.GetOutput()
		if logOutput != "" {
			t.Errorf("Expected no log output without LogOnFail, got: %s", logOutput)
		}
	})
}

// TestChainRunAll tests the RunAll method.
// It verifies error collection and max error limits.
func TestChainRunAll(t *testing.T) {
	// Subtest: CollectAllErrors
	// RunAll runs every step and returns a *MultiError holding both failures.
	t.Run("CollectAllErrors", func(t *testing.T) {
		var step2Executed bool
		c := NewChain().
			Step(func() error { return errStep1 }).
			Step(func() error { step2Executed = true; return nil }).Optional().
			Step(func() error { return errStep2 })
		err := c.RunAll()
		if !step2Executed {
			t.Error("Optional successful step should have executed in RunAll")
		}
		// Two failures are expected to come back as a *MultiError.
		multiErr, ok := err.(*MultiError)
		if !ok {
			t.Fatalf("Expected *MultiError, got %T", err)
		}
		// Exactly the two failing steps are recorded.
		if len(multiErr.Errors()) != 2 {
			t.Errorf("Expected 2 errors collected in RunAll, got %d", len(multiErr.Errors()))
		}
	})
	// Subtest: RunAllWithMaxErrors
	// With ChainWithMaxErrors(2), RunAll stops once two errors are collected,
	// so step 3 never runs.
	t.Run("RunAllWithMaxErrors", func(t *testing.T) {
		var step3Executed bool
		c := NewChain(ChainWithMaxErrors(2)).
			Step(func() error { return errStep1 }).
			Step(func() error { return errStep2 }).
			Step(func() error { step3Executed = true; return errStep3 })
		err := c.RunAll()
		if step3Executed {
			t.Error("Step 3 should not have executed after MaxErrors limit")
		}
		multiErr, ok := err.(*MultiError)
		if !ok {
			t.Fatalf("Expected MultiError, got %T", err)
		}
		// The limit caps the collection at two errors.
		if len(multiErr.Errors()) != 2 {
			t.Errorf("Expected exactly 2 errors due to max limit, got %d", len(multiErr.Errors()))
		}
	})
}

// TestChainReset tests the Reset method.
// After a failing run, Reset must leave no steps, no errors and no lastStep.
func TestChainReset(t *testing.T) {
	// A chain with one failing step, a timeout and step metadata.
	c := NewChain(ChainWithTimeout(1*time.Second)).
		Step(func() error { return errTest }).With("key", "value")
	_ = c.Run() // the error is irrelevant; only the post-Reset state is checked
	c.Reset()
	// Every piece of per-chain state must be cleared.
	if c.Len() != 0 {
		t.Errorf("Reset chain should have 0 steps, got %d", c.Len())
	}
	if c.HasErrors() {
		t.Errorf("Reset chain should have 0 errors, got %v", c.Errors())
	}
	if c.lastStep != nil {
		t.Error("Reset chain should have nil lastStep")
	}
}

// TestChainReflectionCall tests the Call method with reflection.
// It verifies that functions with arguments are handled correctly.
func TestChainReflectionCall(t *testing.T) { // Subtest: CallWithArgsFailure // Verifies that a function with arguments returns an enhanced error. t.Run("CallWithArgsFailure", func(t *testing.T) { internalErr := fmt.Errorf("failure with %d", 10) fn := func(a int) error { return internalErr } c := NewChain().Call(fn, 10) err := c.Run() if err == nil { t.Fatal("Expected error from Call") } // Verify the error is enhanced enhancedErr, ok := err.(*Error) if !ok { t.Fatalf("Expected wrapped *errors.Error, got %T", err) } // Check that it wraps the original error if !stderrs.Is(enhancedErr, internalErr) { t.Errorf("Expected enhanced error to wrap '%v', got '%v'", internalErr, enhancedErr) } }) } // TestChainErrorInspection tests error inspection methods. // It verifies LastError and Errors after execution. func TestChainErrorInspection(t *testing.T) { // Create a chain with two failing steps c := NewChain(). Step(func() error { return errStep1 }). Step(func() error { return errStep2 }) _ = c.RunAll() // Verify the last error lastErr := c.LastError() if lastErr == nil { t.Fatal("LastError should not be nil after RunAll") } if !stderrs.Is(lastErr, errStep2) { t.Errorf("LastError should wrap %v, got %v", errStep2, lastErr) } // Verify the number of collected errors if len(c.Errors()) != 2 { t.Errorf("Expected 2 errors collected, got %d", len(c.Errors())) } } golang-github-olekukonko-errors-1.3.0/chan.go000066400000000000000000000170131517267734700212200ustar00rootroot00000000000000// Channel-based error utilities and streaming error collection. // All functions compose with the standard (chan T, chan error) idiom. package errors import ( "context" "fmt" "sync" "sync/atomic" ) // ErrLimitReached is returned by Collect when n errors have been gathered // before the channel closed or the context was done. // Callers can use errors.Is(err, ErrLimitReached) to distinguish this case. 
var ErrLimitReached = Const("limit_reached", "error collection limit reached")

// Drain consumes ch until it is closed, skipping nil values, and combines
// every non-nil error it saw. It returns nil when no non-nil error arrived;
// otherwise it returns the collected errors via (*MultiError).Single.
// NOTE(review): Single may hand back the lone error itself (not a
// *MultiError) when exactly one error was collected — confirm before
// type-asserting the result.
//
// Drain blocks until ch is closed; a nil channel blocks forever.
//
// Example:
//
//	results, errs := processItems(ctx, items)
//	if err := errors.Drain(errs); err != nil {
//		log.Println(err)
//	}
func Drain(ch <-chan error) error {
	collected := NewMultiError()
	for {
		err, open := <-ch
		if !open {
			return collected.Single()
		}
		if err != nil {
			collected.Add(err)
		}
	}
}

// First blocks until it receives a non-nil error from ch and returns it;
// nil values are skipped. ctx is used only to stop waiting — First never
// cancels anything itself, so cancelling sibling work after First returns
// is the caller's job.
//
// It returns nil if ch is closed without delivering a non-nil error, and
// ctx.Err() if ctx is cancelled or its deadline passes first. When an error
// and ctx cancellation are ready at the same moment, either may win.
//
// Example:
//
//	ctx, cancel := context.WithCancel(context.Background())
//	defer cancel()
//	results, errs := processItems(ctx, items)
//	if err := errors.First(ctx, errs); err != nil {
//		cancel() // caller decides to stop siblings
//		log.Println(err)
//	}
func First(ctx context.Context, ch <-chan error) error {
	done := ctx.Done()
	for {
		select {
		case <-done:
			return ctx.Err()
		case err, open := <-ch:
			switch {
			case !open:
				return nil
			case err != nil:
				return err
			}
		}
	}
}

// Fan merges multiple error channels into a single output channel that closes
// when all inputs have closed or ctx is done.
//
// Callers MUST either drain the returned channel to completion OR cancel ctx —
// failing to do so leaks the internal goroutines. The select in each forwarder
// respects ctx.Done() so cancellation is always safe.
// // Example: // // errs1, errs2 := stage1(ctx), stage2(ctx) // for err := range errors.Fan(ctx, errs1, errs2) { // if err != nil { log.Println(err) } // } func Fan(ctx context.Context, chans ...<-chan error) <-chan error { bufSize := len(chans) if bufSize < 1 { bufSize = 1 } out := make(chan error, bufSize) var wg sync.WaitGroup wg.Add(len(chans)) for _, ch := range chans { ch := ch go func() { defer wg.Done() for { select { case err, ok := <-ch: if !ok { return } select { case out <- err: case <-ctx.Done(): return } case <-ctx.Done(): return } } }() } go func() { wg.Wait() close(out) }() return out } // Collect reads up to n non-nil errors from ch (or until ch closes or ctx is // done) and returns them as a *MultiError. // // If the limit n is reached before the channel closes, the returned error // wraps ErrLimitReached as its cause so callers can distinguish the two cases: // // err := errors.Collect(ctx, errs, 10) // if errors.Is(err, errors.ErrLimitReached) { // // stopped early — more errors may exist // } func Collect(ctx context.Context, ch <-chan error, n int) error { m := NewMultiError(WithLimit(n)) for { select { case err, ok := <-ch: if !ok { return m.Single() } if err != nil { m.Add(err) if m.Count() >= n { // Wrap in an *Error so errors.Is traversal finds ErrLimitReached. e := New(fmt.Sprintf("collected %d errors (limit reached)", n)) e.cause = ErrLimitReached if inner := m.Single(); inner != nil { return New(e.Error()).Wrap(inner).Wrap(ErrLimitReached) } return e } } case <-ctx.Done(): return m.Single() } } } // Stream — concurrent item processing with progressive error collection // Stream processes a slice of items concurrently, collecting errors as they // occur without stopping execution. Use Wait() to block until all items are // done, or Each() to process errors as they arrive. 
//
// Only one consumer is allowed per Stream: call either Each or Wait, once.
//
// Example — collect all:
//
//	s := errors.NewStream(ctx, items, process, 8)
//	if err := s.Wait(); err != nil {
//		log.Println(err)
//	}
//
// Example — process as they arrive:
//
//	s := errors.NewStream(ctx, items, process, 8)
//	s.Each(func(err error) { log.Println(err) })
type Stream[T any] struct {
	ch        chan error    // non-nil results of fn; closed once all workers exit
	done      chan struct{} // closed after ch, when run has fully finished
	stopCh    chan struct{} // closed by Stop to halt feeding and sending
	closeOnce sync.Once     // guards close(ch)
	stopOnce  sync.Once     // guards close(stopCh)
	// consumed guards Each/Wait — only one consumer is permitted.
	// 0 = available, 1 = consumed. Enforced with atomic CAS.
	consumed atomic.Int32
}

// NewStream creates a Stream that applies fn to every item in items using
// up to workers concurrent goroutines, and starts processing immediately.
//
// workers <= 0 (or omitted) defaults to len(items); a value larger than
// len(items) is capped at len(items) because extra goroutines would have no
// work. At least one worker is always started. Only workers[0] is used.
//
// Respects ctx: in-flight calls to fn complete, but no new item starts once
// ctx is done (or Stop has been called).
//
// Example:
//
//	s := errors.NewStream(ctx, urls, func(url string) error {
//		return fetch(url)
//	}, 8)
func NewStream[T any](ctx context.Context, items []T, fn func(T) error, workers ...int) *Stream[T] {
	w := len(items)
	if len(workers) > 0 && workers[0] > 0 && workers[0] < w {
		w = workers[0]
	}
	if w < 1 {
		w = 1
	}
	s := &Stream[T]{
		ch:     make(chan error, w),
		done:   make(chan struct{}),
		stopCh: make(chan struct{}),
	}
	go s.run(ctx, items, fn, w)
	return s
}

// run feeds items to a fixed pool of workers, forwards each non-nil result of
// fn to s.ch, and closes s.ch and then s.done once every worker has exited.
func (s *Stream[T]) run(ctx context.Context, items []T, fn func(T) error, workers int) {
	defer func() {
		s.closeOnce.Do(func() { close(s.ch) })
		close(s.done)
	}()

	// halted reports, without blocking, whether Stop was called or ctx is done.
	// It gives cancellation priority: a bare select picks randomly among ready
	// cases and could otherwise keep handing out items after cancellation.
	halted := func() bool {
		select {
		case <-s.stopCh:
			return true
		case <-ctx.Done():
			return true
		default:
			return false
		}
	}

	work := make(chan T, workers)
	var wg sync.WaitGroup
	wg.Add(workers)
	for i := 0; i < workers; i++ {
		go func() {
			defer wg.Done()
			for item := range work {
				// Items may still be buffered in work after cancellation;
				// they must not start. The feeder below never blocks forever
				// on a full buffer because its select also watches the same
				// stop/ctx channels that made us return.
				if halted() {
					return
				}
				if err := fn(item); err != nil {
					select {
					case s.ch <- err:
					case <-s.stopCh:
						return
					case <-ctx.Done():
						return
					}
				}
			}
		}()
	}

feed:
	for _, item := range items {
		if halted() {
			break
		}
		select {
		case work <- item:
		case <-s.stopCh:
			break feed
		case <-ctx.Done():
			break feed
		}
	}
	close(work)
	wg.Wait()
}

// acquireConsumer atomically marks the stream as consumed.
// Panics if called more than once — Each and Wait are mutually exclusive.
func (s *Stream[T]) acquireConsumer(name string) { if !s.consumed.CompareAndSwap(0, 1) { panic(fmt.Sprintf("errors.Stream: %s called on an already-consumed Stream; Each and Wait are mutually exclusive", name)) } } // Each calls fn for every error produced by the stream, in the order they // arrive. Blocks until all items have been processed. // // Panics if called after Wait (or a second call to Each). func (s *Stream[T]) Each(fn func(error)) { s.acquireConsumer("Each") for err := range s.ch { fn(err) } <-s.done } // Wait blocks until all items have been processed and returns a *MultiError // containing every error collected, or nil if all items succeeded. // // Panics if called after Each (or a second call to Wait). func (s *Stream[T]) Wait() error { s.acquireConsumer("Wait") m := NewMultiError() for err := range s.ch { m.Add(err) } <-s.done return m.Single() } // Stop signals the stream to stop processing new items and drains any // buffered errors in the background. Safe to call multiple times. // After Stop, Wait and Each will still return promptly but may not see // all errors. func (s *Stream[T]) Stop() { s.stopOnce.Do(func() { close(s.stopCh) // Drain the error channel in the background so run() goroutines // are not blocked trying to send, preventing a goroutine leak. 
go func() { for range s.ch { } }() }) } golang-github-olekukonko-errors-1.3.0/chan_test.go000066400000000000000000000216461517267734700222660ustar00rootroot00000000000000package errors import ( "context" "fmt" "strings" "sync/atomic" "testing" "time" ) // Drain func TestDrainAllNil(t *testing.T) { ch := make(chan error, 3) ch <- nil ch <- nil ch <- nil close(ch) if err := Drain(ch); err != nil { t.Errorf("Drain(all nil) = %v, want nil", err) } } func TestDrainCollectsErrors(t *testing.T) { ch := make(chan error, 3) ch <- New("one") ch <- nil ch <- New("two") close(ch) err := Drain(ch) if err == nil { t.Fatal("Drain() = nil, want errors") } multi, ok := err.(*MultiError) if !ok { t.Fatalf("Drain() type = %T, want *MultiError", err) } if multi.Count() != 2 { t.Errorf("Drain() count = %d, want 2", multi.Count()) } } func TestDrainEmpty(t *testing.T) { ch := make(chan error) close(ch) if err := Drain(ch); err != nil { t.Errorf("Drain(empty) = %v, want nil", err) } } func TestDrainSingleError(t *testing.T) { ch := make(chan error, 1) ch <- New("only error") close(ch) err := Drain(ch) if err == nil { t.Fatal("expected error, got nil") } if !strings.Contains(err.Error(), "only error") { t.Errorf("unexpected message: %q", err.Error()) } } // First func TestFirstReturnsFirstError(t *testing.T) { ch := make(chan error, 3) ch <- nil ch <- New("first real error") ch <- New("second error") close(ch) err := First(context.Background(), ch) if err == nil { t.Fatal("First() = nil, want error") } if !strings.Contains(err.Error(), "first real error") { t.Errorf("First() = %q, want 'first real error'", err.Error()) } } func TestFirstChannelClosedNoError(t *testing.T) { ch := make(chan error, 2) ch <- nil ch <- nil close(ch) if err := First(context.Background(), ch); err != nil { t.Errorf("First(no errors) = %v, want nil", err) } } func TestFirstContextCancelledReturnsCtxErr(t *testing.T) { ch := make(chan error) // never sends ctx, cancel := context.WithCancel(context.Background()) 
cancel() err := First(ctx, ch) if err != context.Canceled { t.Errorf("First(cancelled) = %v, want context.Canceled", err) } } func TestFirstCallerOwnsCancel(t *testing.T) { // Verify the documented pattern: First returns, caller cancels siblings. ctx, cancel := context.WithCancel(context.Background()) defer cancel() ch := make(chan error, 2) ch <- New("fail") ch <- New("also fail") close(ch) err := First(ctx, ch) if err == nil { t.Fatal("expected error") } // Caller now calls cancel() — this is the correct usage cancel() // ctx should now be done select { case <-ctx.Done(): default: t.Error("ctx should be cancelled after caller calls cancel()") } } func TestFirstContextDeadline(t *testing.T) { ch := make(chan error) // never sends ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) defer cancel() err := First(ctx, ch) if err != context.DeadlineExceeded { t.Errorf("First(deadline) = %v, want DeadlineExceeded", err) } } // Collect func TestCollectUpToN(t *testing.T) { ch := make(chan error, 10) for i := 0; i < 10; i++ { ch <- fmt.Errorf("error %d", i) } close(ch) err := Collect(context.Background(), ch, 3) if err == nil { t.Fatal("Collect() = nil, want errors") } if !Is(err, ErrLimitReached) { t.Errorf("Collect(limit) should wrap ErrLimitReached, got: %v", err) } } func TestCollectFewerThanN(t *testing.T) { ch := make(chan error, 3) ch <- New("a") ch <- New("b") close(ch) err := Collect(context.Background(), ch, 10) if err == nil { t.Fatal("expected errors") } // Did not hit limit — should NOT wrap ErrLimitReached if Is(err, ErrLimitReached) { t.Error("Collect(under limit) should not wrap ErrLimitReached") } } func TestCollectAllNil(t *testing.T) { ch := make(chan error, 3) ch <- nil ch <- nil close(ch) if err := Collect(context.Background(), ch, 5); err != nil { t.Errorf("Collect(all nil) = %v, want nil", err) } } func TestCollectContextDone(t *testing.T) { ch := make(chan error) // blocks forever ctx, cancel := 
context.WithCancel(context.Background()) cancel() err := Collect(ctx, ch, 10) if err != nil { t.Errorf("Collect(cancelled) = %v, want nil", err) } } // Fan func TestFanMergesChannels(t *testing.T) { ch1 := make(chan error, 2) ch2 := make(chan error, 2) ch1 <- New("ch1-a") ch1 <- New("ch1-b") close(ch1) ch2 <- New("ch2-a") close(ch2) var collected []error for err := range Fan(context.Background(), ch1, ch2) { if err != nil { collected = append(collected, err) } } if len(collected) != 3 { t.Errorf("Fan() collected %d errors, want 3", len(collected)) } } func TestFanEmpty(t *testing.T) { ch1 := make(chan error) ch2 := make(chan error) close(ch1) close(ch2) var count int for range Fan(context.Background(), ch1, ch2) { count++ } if count != 0 { t.Errorf("Fan(empty inputs) received %d items, want 0", count) } } func TestFanNoInputs(t *testing.T) { merged := Fan(context.Background()) select { case _, ok := <-merged: if ok { t.Error("Fan() with no inputs sent a value before closing") } case <-time.After(100 * time.Millisecond): t.Error("Fan() with no inputs did not close promptly") } } func TestFanContextCancellation(t *testing.T) { ch := make(chan error) // never closes ctx, cancel := context.WithCancel(context.Background()) merged := Fan(ctx, ch) cancel() select { case <-merged: // closed or received — either is fine case <-time.After(200 * time.Millisecond): t.Error("Fan() did not close after ctx cancellation") } } // Stream func TestStreamAllSucceed(t *testing.T) { items := []int{1, 2, 3, 4, 5} s := NewStream(context.Background(), items, func(n int) error { return nil }) if err := s.Wait(); err != nil { t.Errorf("Stream(all succeed) = %v, want nil", err) } } func TestStreamCollectsAllErrors(t *testing.T) { items := []int{1, 2, 3, 4, 5} s := NewStream(context.Background(), items, func(n int) error { if n%2 == 0 { return fmt.Errorf("error %d", n) } return nil }) err := s.Wait() if err == nil { t.Fatal("Stream() = nil, want errors") } multi, ok := err.(*MultiError) if !ok 
{ t.Fatalf("Stream() type = %T, want *MultiError", err) } if multi.Count() != 2 { t.Errorf("Stream() count = %d, want 2", multi.Count()) } } func TestStreamEach(t *testing.T) { items := []string{"a", "b", "c"} s := NewStream(context.Background(), items, func(item string) error { if item == "b" { return New("b failed") } return nil }) var count int s.Each(func(err error) { count++ }) if count != 1 { t.Errorf("Stream.Each() called fn %d times, want 1", count) } } func TestStreamDoubleConsumePanics(t *testing.T) { s := NewStream(context.Background(), []int{1}, func(n int) error { return nil }) _ = s.Wait() defer func() { if r := recover(); r == nil { t.Error("second Wait() should panic") } }() _ = s.Wait() } func TestStreamEachThenWaitPanics(t *testing.T) { items := []int{1, 2} s := NewStream(context.Background(), items, func(n int) error { return nil }) s.Each(func(err error) {}) defer func() { if r := recover(); r == nil { t.Error("Wait() after Each() should panic") } }() _ = s.Wait() } func TestStreamWorkerConcurrency(t *testing.T) { var concurrent atomic.Int32 var maxConcurrent atomic.Int32 items := make([]int, 20) s := NewStream(context.Background(), items, func(n int) error { c := concurrent.Add(1) for { cur := maxConcurrent.Load() if c <= cur || maxConcurrent.CompareAndSwap(cur, c) { break } } time.Sleep(5 * time.Millisecond) concurrent.Add(-1) return nil }, 4) _ = s.Wait() if maxConcurrent.Load() > 4 { t.Errorf("concurrency exceeded workers: max=%d, want<=4", maxConcurrent.Load()) } } func TestStreamContextCancellation(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() items := make([]int, 100) var processed atomic.Int32 s := NewStream(ctx, items, func(n int) error { processed.Add(1) if processed.Load() == 5 { cancel() } time.Sleep(time.Millisecond) return nil }, 2) _ = s.Wait() if processed.Load() == 100 { t.Error("Stream did not respect context cancellation") } } func TestStreamStop(t *testing.T) { items := make([]int, 
100) var processed atomic.Int32 s := NewStream(context.Background(), items, func(n int) error { processed.Add(1) time.Sleep(time.Millisecond) return nil }, 2) time.Sleep(10 * time.Millisecond) s.Stop() // After Stop, goroutines should not leak — channel is drained by Stop. // Give it time to settle. time.Sleep(20 * time.Millisecond) if processed.Load() == 100 { t.Error("Stream.Stop() did not stop processing early") } } func TestStreamEmpty(t *testing.T) { s := NewStream(context.Background(), []string{}, func(s string) error { return New("should not be called") }) if err := s.Wait(); err != nil { t.Errorf("Stream(empty items) = %v, want nil", err) } } func TestStreamDefaultWorkers(t *testing.T) { s := NewStream(context.Background(), []int{1, 2, 3}, func(n int) error { return nil }) if err := s.Wait(); err != nil { t.Errorf("Stream(default workers) = %v, want nil", err) } } golang-github-olekukonko-errors-1.3.0/errmgr/000077500000000000000000000000001517267734700212545ustar00rootroot00000000000000golang-github-olekukonko-errors-1.3.0/errmgr/common.go000066400000000000000000000210451517267734700230750ustar00rootroot00000000000000// Package errmgr provides common error definitions and categories for use across applications. // These predefined errors are designed for consistency in error handling and can be used // directly as immutable instances or copied for customization using Copy(). package errmgr import ( "github.com/olekukonko/errors" ) // Common error categories used for organizing errors across different domains. 
const (
	// Typed as errors.ErrorCategory so they can be passed to WithCategory and Categorized.
	CategoryAuth       errors.ErrorCategory = "auth"       // Authentication-related errors (e.g., login failures)
	CategoryBusiness   errors.ErrorCategory = "business"   // Business logic errors (e.g., rule violations)
	CategoryDatabase   errors.ErrorCategory = "database"   // Database-related errors (e.g., connection issues)
	CategoryIO         errors.ErrorCategory = "io"         // Input/Output-related errors (e.g., file operations)
	CategoryNetwork    errors.ErrorCategory = "network"    // Network-related errors (e.g., timeouts, unreachable hosts)
	CategorySystem     errors.ErrorCategory = "system"     // System-level errors (e.g., resource exhaustion)
	CategoryUser       errors.ErrorCategory = "user"       // User-related errors (e.g., invalid input, permissions)
	CategoryValidation errors.ErrorCategory = "validation" // Validation-related errors (e.g., invalid input formats)
)

// Common HTTP status codes used for error responses, aligned with REST API conventions.
// They are untyped integer constants, passed to WithCode or Coded.
const (
	CodeBadRequest         = 400 // HTTP 400 Bad Request (client error, invalid input)
	CodeUnauthorized       = 401 // HTTP 401 Unauthorized (authentication required)
	CodeForbidden          = 403 // HTTP 403 Forbidden (access denied)
	CodeNotFound           = 404 // HTTP 404 Not Found (resource not found)
	CodeMethodNotAllowed   = 405 // HTTP 405 Method Not Allowed (unsupported method)
	CodeConflict           = 409 // HTTP 409 Conflict (resource conflict)
	CodeUnprocessable      = 422 // HTTP 422 Unprocessable Entity (semantic errors in request)
	CodeTooManyRequests    = 429 // HTTP 429 Too Many Requests (rate limiting)
	CodeInternalError      = 500 // HTTP 500 Internal Server Error (server failure)
	CodeNotImplemented     = 501 // HTTP 501 Not Implemented (feature not supported)
	CodeServiceUnavailable = 503 // HTTP 503 Service Unavailable (temporary unavailability)
)

// Generic Predefined Errors (Static)
// These are shared package-level *errors.Error values built once during package
// initialization. Treat them as read-only: use them directly, or call Copy() to
// obtain an independent instance before customizing one.
// Errors requiring specific properties like WithRetryable() or WithTimeout() are defined here.
var (
	ErrInvalidArg         = errors.New("invalid argument").WithCode(CodeBadRequest)
	ErrNotFound           = errors.New("not found").WithCode(CodeNotFound)
	ErrPermission         = errors.New("permission denied").WithCode(CodeForbidden)
	ErrTimeout            = errors.New("operation timed out").WithTimeout()
	ErrUnknown            = errors.New("unknown error").WithCode(CodeInternalError)
	ErrDBConnRetryable    = errors.New("database connection failed").WithCategory(CategoryDatabase).WithRetryable()
	ErrNetworkRetryable   = errors.New("network failure").WithCategory(CategoryNetwork).WithRetryable()
	ErrNetworkTimedOut    = errors.New("network timeout").WithCategory(CategoryNetwork).WithTimeout().WithRetryable()
	ErrServiceRetryable   = errors.New("service unavailable").WithCode(CodeServiceUnavailable).WithRetryable()
	ErrRateLimitRetryable = errors.New("rate limit exceeded").WithCode(CodeTooManyRequests).WithRetryable()
)

// Authentication Errors (Templated)
// These are Coded templates: each error carries the HTTP status code but no category.
// Pass one argument per %s, e.g., ErrAuthFailed("user@example.com", "invalid password").
var (
	ErrAuthFailed   = Coded("ErrAuthFailed", "authentication failed for %s: %s", CodeUnauthorized)
	ErrInvalidToken = Coded("ErrInvalidToken", "invalid authentication token: %s", CodeUnauthorized)
	ErrMissingCreds = Coded("ErrMissingCreds", "missing credentials: %s", CodeBadRequest)
	ErrTokenExpired = Coded("ErrTokenExpired", "authentication token expired: %s", CodeUnauthorized)
)

// Business Logic Errors (Templated)
// Categorized templates: each error carries CategoryBusiness but no HTTP code.
// Each template has a single %s, so pass exactly one argument; extra arguments
// are rendered by fmt as "%!(EXTRA ...)".
// Example: ErrInsufficientFunds("account123: balance too low").
var (
	ErrBusinessRule      = Categorized(CategoryBusiness, "ErrBusinessRule", "business rule violation: %s")
	ErrInsufficientFunds = Categorized(CategoryBusiness, "ErrInsufficientFunds", "insufficient funds: %s")
)

// Database Errors (Templated)
// Each template has a single %s, so pass exactly one argument.
// Example: ErrDBConnection("mysql: host unreachable").
var (
	ErrDBConnection = Categorized(CategoryDatabase, "ErrDBConnection", "database connection failed: %s")
	// ErrDBConstraint is Coded rather than Categorized, so it carries CodeConflict
	// but not CategoryDatabase. NOTE(review): confirm this asymmetry is intended.
	ErrDBConstraint = Coded("ErrDBConstraint", "database constraint violation: %s", CodeConflict)
	ErrDBQuery      = Categorized(CategoryDatabase, "ErrDBQuery", "database query failed: %s")
	ErrDBTimeout    = Categorized(CategoryDatabase, "ErrDBTimeout", "database operation timed out: %s")
)

// IO Errors (Templated)
// ErrFileNotFound is Coded (CodeNotFound, no category); the others carry CategoryIO.
// Example: ErrFileNotFound("/path/to/file").
var (
	ErrFileNotFound = Coded("ErrFileNotFound", "file (%s) not found", CodeNotFound)
	ErrIORead       = Categorized(CategoryIO, "ErrIORead", "I/O read error: %s")
	ErrIOWrite      = Categorized(CategoryIO, "ErrIOWrite", "I/O write error: %s")
)

// Network Errors (Templated)
// Each template has a single %s, so pass exactly one argument.
// Example: ErrNetworkTimeout("http://example.com: no response").
var (
	ErrNetworkConnRefused = Categorized(CategoryNetwork, "ErrNetworkConnRefused", "connection refused: %s")
	ErrNetworkTimeout     = Categorized(CategoryNetwork, "ErrNetworkTimeout", "network timeout: %s")
	ErrNetworkUnreachable = Categorized(CategoryNetwork, "ErrNetworkUnreachable", "network unreachable: %s")
)

// System Errors (Templated)
// Coded templates (HTTP code, no category); each has a single %s.
// Example: ErrResourceExhausted("memory").
var (
	ErrConfigInvalid     = Coded("ErrConfigInvalid", "invalid configuration: %s", CodeInternalError)
	ErrResourceExhausted = Coded("ErrResourceExhausted", "resource exhausted: %s", CodeServiceUnavailable)
	ErrSystemFailure     = Coded("ErrSystemFailure", "system failure: %s", CodeInternalError)
	ErrSystemUnhealthy   = Coded("ErrSystemUnhealthy", "system unhealthy: %s", CodeServiceUnavailable)
)

// User Errors (Templated)
// Each template takes two arguments: the user and a detail.
// Example: ErrUserNotFound("user123", "not in database").
var (
	ErrUserLocked     = Coded("ErrUserLocked", "user %s is locked: %s", CodeForbidden)
	ErrUserNotFound   = Coded("ErrUserNotFound", "user %s not found: %s", CodeNotFound)
	ErrUserPermission = Coded("ErrUserPermission", "user %s lacks permission: %s", CodeForbidden)
	ErrUserSuspended  = Coded("ErrUserSuspended", "user %s is suspended: %s", CodeForbidden)
)

// Validation Errors (Templated)
// Each template has a single %s, so pass exactly one argument.
// Example: ErrValidationFailed("email: invalid email format").
var (
	ErrInvalidFormat    = Coded("ErrInvalidFormat", "invalid format: %s", CodeBadRequest)
	ErrValidationFailed = Coded("ErrValidationFailed", "validation failed: %s", CodeBadRequest)
)

// Additional REST API Errors (Templated)
// Each template has a single %s, so pass exactly one argument.
// Example: ErrMethodNotAllowed("POST").
var (
	ErrConflict           = Coded("ErrConflict", "conflict occurred: %s", CodeConflict)
	ErrMethodNotAllowed   = Coded("ErrMethodNotAllowed", "method %s not allowed", CodeMethodNotAllowed)
	ErrNotImplemented     = Coded("ErrNotImplemented", "%s not implemented", CodeNotImplemented)
	ErrRateLimitExceeded  = Coded("ErrRateLimitExceeded", "rate limit exceeded: %s", CodeTooManyRequests)
	ErrServiceUnavailable = Coded("ErrServiceUnavailable", "service (%s) unavailable", CodeServiceUnavailable)
	ErrUnprocessable      = Coded("ErrUnprocessable", "unprocessable entity: %s", CodeUnprocessable)
)

// Additional Domain-Specific Errors (Templated)
// The Define-based templates carry neither an HTTP code nor a category.
// Each template has a single %s, so pass exactly one argument.
// Example: ErrSerialization("json: invalid data").
var (
	ErrDeserialization      = Define("ErrDeserialization", "deserialization error: %s")
	ErrExternalService      = Define("ErrExternalService", "external service (%s) error")
	ErrSerialization        = Define("ErrSerialization", "serialization error: %s")
	ErrUnsupportedOperation = Coded("ErrUnsupportedOperation", "unsupported operation %s", CodeNotImplemented)
)

// Predefined Templates with Categories (Templated)
// These are convenience wrappers with categories applied; use like AuthFailed("user", "reason").
// AuthFailed takes two arguments; the other templates take one.
var ( AuthFailed = Categorized(CategoryAuth, "AuthFailed", "authentication failed for %s: %s") BusinessError = Categorized(CategoryBusiness, "BusinessError", "business error: %s") DBError = Categorized(CategoryDatabase, "DBError", "database error: %s") IOError = Categorized(CategoryIO, "IOError", "I/O error: %s") NetworkError = Categorized(CategoryNetwork, "NetworkError", "network failure: %s") SystemError = Categorized(CategorySystem, "SystemError", "system error: %s") UserError = Categorized(CategoryUser, "UserError", "user error: %s") ValidationError = Categorized(CategoryValidation, "ValidationError", "validation error: %s") ) golang-github-olekukonko-errors-1.3.0/errmgr/common_test.go000066400000000000000000000121671517267734700241410ustar00rootroot00000000000000package errmgr import ( "github.com/olekukonko/errors" "testing" ) func TestStaticErrors(t *testing.T) { tests := []struct { err *errors.Error name string expected string code int retry bool timeout bool }{ {ErrInvalidArg, "ErrInvalidArg", "invalid argument", CodeBadRequest, false, false}, {ErrNotFound, "ErrNotFound", "not found", CodeNotFound, false, false}, {ErrPermission, "ErrPermission", "permission denied", CodeForbidden, false, false}, {ErrTimeout, "ErrTimeout", "operation timed out", 0, false, true}, {ErrUnknown, "ErrUnknown", "unknown error", CodeInternalError, false, false}, {ErrDBConnRetryable, "ErrDBConnRetryable", "database connection failed", 0, true, false}, {ErrNetworkRetryable, "ErrNetworkRetryable", "network failure", 0, true, false}, {ErrNetworkTimedOut, "ErrNetworkTimedOut", "network timeout", 0, true, true}, {ErrServiceRetryable, "ErrServiceRetryable", "service unavailable", CodeServiceUnavailable, true, false}, {ErrRateLimitRetryable, "ErrRateLimitRetryable", "rate limit exceeded", CodeTooManyRequests, true, false}, } for _, tt := range tests { t.Run(tt.expected, func(t *testing.T) { if tt.err.Error() != tt.expected { t.Errorf("Expected message %q, got %q", tt.expected, 
tt.err.Error()) } if tt.err.Code() != tt.code { t.Errorf("Expected code %d, got %d", tt.code, tt.err.Code()) } ctx := tt.err.Context() if tt.retry && (ctx == nil || !ctx["[error] retry"].(bool)) { t.Errorf("Expected retryable error, got context %v", ctx) } if tt.timeout && (ctx == nil || !ctx["[error] timeout"].(bool)) { t.Errorf("Expected timeout error, got context %v", ctx) } }) } } func TestTemplatedErrors(t *testing.T) { tests := []struct { errFunc func(...interface{}) *errors.Error name string args []interface{} expected string code int category errors.ErrorCategory }{ {ErrAuthFailed, "ErrAuthFailed", []interface{}{"user", "pass"}, "authentication failed for user: pass", CodeUnauthorized, ""}, {ErrDBConnection, "ErrDBConnection", []interface{}{"mysql"}, "database connection failed: mysql", 0, CategoryDatabase}, {ErrNetworkTimeout, "ErrNetworkTimeout", []interface{}{"host"}, "network timeout: host", 0, CategoryNetwork}, {ErrFileNotFound, "ErrFileNotFound", []interface{}{"file.txt"}, "file (file.txt) not found", CodeNotFound, ""}, {ErrValidationFailed, "ErrValidationFailed", []interface{}{"email"}, "validation failed: email", CodeBadRequest, ""}, {ErrRateLimitExceeded, "ErrRateLimitExceeded", []interface{}{"user123"}, "rate limit exceeded: user123", CodeTooManyRequests, ""}, {ErrUserNotFound, "ErrUserNotFound", []interface{}{"user123", "not in db"}, "user user123 not found: not in db", CodeNotFound, ""}, {ErrMethodNotAllowed, "ErrMethodNotAllowed", []interface{}{"POST"}, "method POST not allowed", CodeMethodNotAllowed, ""}, {ErrUnprocessable, "ErrUnprocessable", []interface{}{"data"}, "unprocessable entity: data", CodeUnprocessable, ""}, {ErrBusinessRule, "ErrBusinessRule", []interface{}{"rule1"}, "business rule violation: rule1", 0, CategoryBusiness}, {ErrIORead, "ErrIORead", []interface{}{"disk"}, "I/O read error: disk", 0, CategoryIO}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { err := tt.errFunc(tt.args...) 
if err.Error() != tt.expected { t.Errorf("Expected message %q, got %q", tt.expected, err.Error()) } if err.Code() != tt.code { t.Errorf("Expected code %d, got %d", tt.code, err.Code()) } if tt.category != "" { if cat := errors.Category(err); cat != string(tt.category) { t.Errorf("Expected category %q, got %q", tt.category, cat) } } err.Free() }) } } func TestCategorizedTemplates(t *testing.T) { tests := []struct { errFunc func(...interface{}) *errors.Error name string args []interface{} expected string category errors.ErrorCategory code int }{ {AuthFailed, "AuthFailed", []interface{}{"user", "reason"}, "authentication failed for user: reason", CategoryAuth, 0}, {BusinessError, "BusinessError", []interface{}{"rule"}, "business error: rule", CategoryBusiness, 0}, {DBError, "DBError", []interface{}{"query"}, "database error: query", CategoryDatabase, 0}, {IOError, "IOError", []interface{}{"disk"}, "I/O error: disk", CategoryIO, 0}, {NetworkError, "NetworkError", []interface{}{"host"}, "network failure: host", CategoryNetwork, 0}, {SystemError, "SystemError", []interface{}{"crash"}, "system error: crash", CategorySystem, 0}, {UserError, "UserError", []interface{}{"input"}, "user error: input", CategoryUser, 0}, {ValidationError, "ValidationError", []interface{}{"format"}, "validation error: format", CategoryValidation, 0}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { err := tt.errFunc(tt.args...) 
if err.Error() != tt.expected { t.Errorf("Expected message %q, got %q", tt.expected, err.Error()) } if err.Code() != tt.code { t.Errorf("Expected code %d, got %d", tt.code, err.Code()) } if cat := errors.Category(err); cat != string(tt.category) { t.Errorf("Expected category %q, got %q", tt.category, cat) } err.Free() }) } } golang-github-olekukonko-errors-1.3.0/errmgr/errmgr.go000066400000000000000000000202101517267734700230740ustar00rootroot00000000000000// Package errmgr provides functionality for managing error templates, counts, thresholds, // and alerts in a thread-safe manner, building on the core errors package. package errmgr import ( "fmt" "github.com/olekukonko/errors" "strings" "sync" "sync/atomic" ) // Config holds configuration for the errmgr package. type Config struct { DisableMetrics bool // Disables counting and tracking if true } // cachedConfig holds the current configuration, updated only on Configure(). type cachedConfig struct { disableErrMgr bool } var ( currentConfig cachedConfig configMu sync.RWMutex registry = errorRegistry{counts: shardedCounter{}} codes = codeRegistry{m: make(map[string]int)} ) func init() { currentConfig = cachedConfig{disableErrMgr: false} } // errorRegistry holds registered errors and their metadata. type errorRegistry struct { templates sync.Map // map[string]string: Error templates funcs sync.Map // map[string]func(...interface{}) *errors.Error: Custom error functions counts shardedCounter // Sharded counter for error occurrences thresholds sync.Map // map[string]uint64: Alert thresholds alerts sync.Map // map[string]*alertChannel: Alert channels mu sync.RWMutex // Protects alerts map } // codeRegistry manages error codes with explicit locking. type codeRegistry struct { m map[string]int mu sync.RWMutex } // shardedCounter provides a low-contention counter for error occurrences. 
type shardedCounter struct {
	counts sync.Map // name (string) -> *uint64 occurrence counter
}

// Categorized registers a templated error via Define and returns a constructor
// that stamps every error it builds with the given category.
func Categorized(category errors.ErrorCategory, name, template string) func(...interface{}) *errors.Error {
	build := Define(name, template)
	return func(args ...interface{}) *errors.Error {
		e := build(args...)
		return e.WithCategory(category)
	}
}

// CloseMonitor closes and unregisters the alert channel registered for name.
// Closing happens at most once; later alerts for this name find no channel and
// are ignored. Safe for concurrent use.
func CloseMonitor(name string) {
	registry.mu.Lock()
	defer registry.mu.Unlock()

	entry, found := registry.alerts.Load(name)
	if !found {
		return
	}
	channel := entry.(*alertChannel)
	channel.mu.Lock()
	if !channel.closed {
		close(channel.ch)
		channel.closed = true
	}
	channel.mu.Unlock()
	registry.alerts.Delete(name)
}

// Coded records code for name in the code registry, registers the template via
// Define, and returns a constructor whose errors all carry that status code.
func Coded(name, template string, code int) func(...interface{}) *errors.Error {
	codes.mu.Lock()
	codes.m[name] = code
	codes.mu.Unlock()

	build := Define(name, template)
	return func(args ...interface{}) *errors.Error {
		return build(args...).WithCode(code)
	}
}

// Configure replaces the package configuration under the config write lock;
// operations that run afterwards observe the new settings.
func Configure(cfg Config) {
	next := cachedConfig{disableErrMgr: cfg.DisableMetrics}
	configMu.Lock()
	defer configMu.Unlock()
	currentConfig = next
}

// Copy returns an independent copy of a predefined static error so the
// original is never mutated. Templated errors should instead be invoked
// directly with their arguments.
func Copy(err *errors.Error) *errors.Error {
	clone := err.Copy()
	return clone
}

// Define creates a templated error that formats a message with provided arguments.
// The error is tracked in the registry if error management is enabled.
func Define(name, template string) func(...interface{}) *errors.Error {
	registry.templates.Store(name, template)

	// Configure writes currentConfig under configMu's write lock, so the
	// metrics switch must be read under the read lock to avoid a data race.
	configMu.RLock()
	disabled := currentConfig.disableErrMgr
	configMu.RUnlock()
	if !disabled {
		registry.counts.RegisterName(name)
	}

	return func(args ...interface{}) *errors.Error {
		var buf strings.Builder
		buf.Grow(len(template) + len(args)*10)
		fmt.Fprintf(&buf, template, args...)
		err := errors.New(buf.String()).WithName(name).WithTemplate(template)

		// Re-read the switch on every call so Configure also affects
		// templates that were defined earlier.
		configMu.RLock()
		off := currentConfig.disableErrMgr
		configMu.RUnlock()
		if !off {
			registry.counts.Inc(name)
		}
		return err
	}
}

// GetThreshold returns the current threshold for an error name, if set.
// Returns 0 and false if no threshold is defined.
func GetThreshold(name string) (uint64, bool) {
	if thresh, ok := registry.thresholds.Load(name); ok {
		return thresh.(uint64), true
	}
	return 0, false
}

// Inc increments the counter for name and returns the updated count. There is
// a single counter per name, so the result equals Value(name) at that instant.
//
// When a threshold is set for name and the new count has reached it (>=), an
// alert error is offered to the monitor channel registered for name, if any.
// The send never blocks: if the channel has no room and no waiting receiver,
// the alert is dropped.
func (c *shardedCounter) Inc(name string) uint64 {
	countPtr, _ := c.counts.LoadOrStore(name, new(uint64))
	// Use the value returned by AddUint64 rather than re-loading the counter,
	// so concurrent increments each report their own count.
	newCount := atomic.AddUint64(countPtr.(*uint64), 1)

	thresh, ok := registry.thresholds.Load(name)
	if !ok || newCount < thresh.(uint64) {
		return newCount
	}
	entry, ok := registry.alerts.Load(name)
	if !ok {
		return newCount
	}

	ac := entry.(*alertChannel)
	ac.mu.Lock()
	defer ac.mu.Unlock()
	if ac.closed {
		return newCount
	}
	// Building the alert costs newCount Increment calls. Skip that work when a
	// buffered channel is already full, since the send would be dropped anyway.
	// Unbuffered channels are still attempted because a receiver may be waiting.
	if cap(ac.ch) > 0 && len(ac.ch) == cap(ac.ch) {
		return newCount
	}

	alert := errors.New(fmt.Sprintf("%s count exceeded threshold: %d", name, newCount)).
		WithName(name)
	for i := uint64(0); i < newCount; i++ {
		_ = alert.Increment()
	}
	select {
	case ac.ch <- alert:
	default:
		// Nobody can take the alert. It never escaped this function, so hand
		// it back to the pool instead of leaking the pooled instance.
		alert.Free()
	}
	return newCount
}

// ListNames returns all registered error names in the counter.
// Thread-safe; returns an empty slice if no names are registered.
func (c *shardedCounter) ListNames() []string {
	var names []string
	c.counts.Range(func(key, _ interface{}) bool {
		names = append(names, key.(string))
		return true
	})
	return names
}

// Metrics returns a snapshot of error counts for monitoring systems.
// Returns nil if error management is disabled or no counts exist.
// Names that are registered but currently have a zero count are omitted.
func Metrics() map[string]uint64 {
	// Read the switch under configMu; Configure writes it under the write lock.
	configMu.RLock()
	disabled := currentConfig.disableErrMgr
	configMu.RUnlock()
	if disabled {
		return nil
	}

	counts := make(map[string]uint64)
	registry.counts.counts.Range(func(key, value interface{}) bool {
		// value is the *uint64 stored by RegisterName/Inc; load it directly
		// instead of looking the name up a second time.
		if count := atomic.LoadUint64(value.(*uint64)); count > 0 {
			counts[key.(string)] = count
		}
		return true
	})
	if len(counts) == 0 {
		return nil
	}
	return counts
}

// RegisterName ensures a counter exists for the name without incrementing it.
// Thread-safe; useful for pre-registering error names.
func (c *shardedCounter) RegisterName(name string) {
	c.counts.LoadOrStore(name, new(uint64))
}

// RemoveThreshold removes the threshold for a specific error name.
// Thread-safe; no effect if no threshold exists.
func RemoveThreshold(name string) {
	registry.thresholds.Delete(name)
}

// Reset clears all counters and removes their registrations.
// Has no effect if error management is disabled.
func Reset() {
	configMu.RLock()
	disabled := currentConfig.disableErrMgr
	configMu.RUnlock()
	if disabled {
		return
	}
	registry.counts.counts.Range(func(key, _ interface{}) bool {
		// Zero first so any goroutine still holding the old pointer sees 0.
		registry.counts.Reset(key.(string))
		registry.counts.counts.Delete(key)
		return true
	})
}

// ResetCounter resets the occurrence counter for a specific error type.
// Has no effect if error management is disabled or the name isn’t registered.
func ResetCounter(name string) {
	configMu.RLock()
	disabled := currentConfig.disableErrMgr
	configMu.RUnlock()
	if !disabled {
		registry.counts.Reset(name)
	}
}

// Reset zeroes the counter for name; the name stays registered.
// Thread-safe; no effect if the name isn’t registered.
func (c *shardedCounter) Reset(name string) {
	if countPtr, ok := c.counts.Load(name); ok {
		atomic.StoreUint64(countPtr.(*uint64), 0)
	}
}

// SetThreshold sets a count threshold for an error name; an alert is sent once the count reaches it.
// Alerts are sent to the Monitor channel if one exists for the name.
func SetThreshold(name string, threshold uint64) { registry.thresholds.Store(name, threshold) } // Tracked registers a custom error function and tracks its occurrences in the registry. // The returned function increments the error count each time it is called. func Tracked(name string, fn func(...interface{}) *errors.Error) func(...interface{}) *errors.Error { registry.funcs.Store(name, fn) if !currentConfig.disableErrMgr { registry.counts.RegisterName(name) } return func(args ...interface{}) *errors.Error { if !currentConfig.disableErrMgr { registry.counts.Inc(name) } return fn(args...) } } // Value returns the total count for a specific name across all shards. // Thread-safe; returns 0 if the name isn’t registered. func (c *shardedCounter) Value(name string) uint64 { if countPtr, ok := c.counts.Load(name); ok { return atomic.LoadUint64(countPtr.(*uint64)) } return 0 } golang-github-olekukonko-errors-1.3.0/errmgr/errmgr_benchmark_test.go000066400000000000000000000033411517267734700261530ustar00rootroot00000000000000package errmgr import ( "fmt" "github.com/olekukonko/errors" "testing" ) // BenchmarkTemplateError measures the performance of creating templated errors. func BenchmarkTemplateError(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { err := ErrDBConnection(fmt.Sprintf("connection failed %d", i)) err.Free() } } // BenchmarkCodedError measures the performance of creating coded errors. func BenchmarkCodedError(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { err := ErrValidationFailed(fmt.Sprintf("field %d", i)) err.Free() } } // BenchmarkCategorizedError measures the performance of creating categorized errors. func BenchmarkCategorizedError(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { err := NetworkError(fmt.Sprintf("host %d", i)) err.Free() } } // BenchmarkCallableError measures the performance of creating custom callable errors. 
func BenchmarkCallableError(b *testing.B) { fn := Tracked("custom", func(args ...interface{}) *errors.Error { return errors.New(fmt.Sprintf("custom %v", args[0])) }) b.ResetTimer() for i := 0; i < b.N; i++ { err := fn(i) err.Free() } } // BenchmarkMetrics measures the performance of retrieving error metrics. func BenchmarkMetrics(b *testing.B) { for i := 0; i < 100; i++ { err := ErrDBConnection(fmt.Sprintf("test %d", i)) err.Free() } b.ResetTimer() for i := 0; i < b.N; i++ { _ = Metrics() } } func BenchmarkMonitorWithClosedChannel(b *testing.B) { Reset() SetThreshold("BenchError", 1) // Create and close monitor to test closed channel case monitor := NewMonitor("BenchError") monitor.Close() errFunc := Define("BenchError", "bench test %d") b.ResetTimer() for i := 0; i < b.N; i++ { err := errFunc(i) err.Free() } } golang-github-olekukonko-errors-1.3.0/errmgr/errmgr_test.go000066400000000000000000000055671517267734700241550ustar00rootroot00000000000000package errmgr import ( "fmt" "github.com/olekukonko/errors" "testing" ) func TestMain(m *testing.M) { errors.Configure(errors.Config{ StackDepth: 32, ContextSize: 2, DisablePooling: false, FilterInternal: true, }) Configure(Config{DisableMetrics: false}) errors.WarmPool(10) errors.WarmStackPool(10) m.Run() } func TestDefine(t *testing.T) { ResetCounter("test_tmpl") tmpl := Define("test_tmpl", "test error: %s") err := tmpl("detail") defer err.Free() if err.Error() != "test error: detail" { t.Errorf("Define() error = %v, want %v", err.Error(), "test error: detail") } if err.Name() != "test_tmpl" { t.Errorf("Define() name = %v, want %v", err.Name(), "test_tmpl") } if Metrics()["test_tmpl"] != 1 { t.Errorf("Metrics()[test_tmpl] = %d, want 1", Metrics()["test_tmpl"]) } } func TestCallable(t *testing.T) { ResetCounter("test_call") fn := Tracked("test_call", func(args ...interface{}) *errors.Error { return errors.Named("test_call").Msgf("called with %v", args[0]) }) err := fn("arg1") defer err.Free() if err.Error() != "called 
with arg1" { t.Errorf("Callable() error = %v, want %v", err.Error(), "called with arg1") } if Metrics()["test_call"] != 1 { t.Errorf("Metrics()[test_call] = %d, want 1", Metrics()["test_call"]) } } func TestCoded(t *testing.T) { ResetCounter("test_coded") tmpl := Coded("test_coded", "coded error: %s", 400) err := tmpl("reason") defer err.Free() if err.Error() != "coded error: reason" { t.Errorf("Coded() error = %v, want %v", err.Error(), "coded error: reason") } if err.Code() != 400 { t.Errorf("Coded() code = %d, want 400", err.Code()) } if Metrics()["test_coded"] != 1 { t.Errorf("Metrics()[test_coded] = %d, want 1", Metrics()["test_coded"]) } } func TestMetrics(t *testing.T) { Reset() ResetCounter("metric1") ResetCounter("metric2") tmpl1 := Define("metric1", "metric one: %s") tmpl2 := Define("metric2", "metric two: %s") for i := 0; i < 3; i++ { err := tmpl1(fmt.Sprintf("test%d", i)) err.Free() } for i := 0; i < 2; i++ { err := tmpl2(fmt.Sprintf("test%d", i)) err.Free() } metrics := Metrics() if len(metrics) != 2 { t.Errorf("Metrics() len = %d, want 2", len(metrics)) } if metrics["metric1"] != 3 { t.Errorf("Metrics()[metric1] = %d, want 3", metrics["metric1"]) } if metrics["metric2"] != 2 { t.Errorf("Metrics()[metric2] = %d, want 2", metrics["metric2"]) } } func TestCountReset(t *testing.T) { name := "test_reset" ResetCounter(name) tmpl := Define(name, "reset test") for i := 0; i < 5; i++ { err := tmpl("test") err.Free() } err := tmpl("before reset") defer err.Free() if Metrics()[name] != 6 { t.Errorf("Metrics()[%s] before reset = %d, want 6", name, Metrics()[name]) } ResetCounter(name) err2 := tmpl("after reset") defer err2.Free() if Metrics()[name] != 1 { t.Errorf("Metrics()[%s] after reset = %d, want 1", name, Metrics()[name]) } } golang-github-olekukonko-errors-1.3.0/errmgr/monitor.go000066400000000000000000000055341517267734700233010ustar00rootroot00000000000000// Package errmgr provides error monitoring functionality. 
package errmgr import ( "github.com/olekukonko/errors" "sync" ) const ( monitorSize = 10 ) // alertChannel wraps a channel with synchronization for safe closure. // Used internally by Monitor to manage alert delivery. type alertChannel struct { ch chan *errors.Error closed bool mu sync.Mutex } // Monitor represents an error monitoring channel for a specific error name. // It receives alerts when the error count exceeds a configured threshold set via SetThreshold. type Monitor struct { name string ac *alertChannel } // Alerts returns the channel for receiving error alerts. // Alerts are sent when the error count exceeds the threshold set by SetThreshold. // Returns nil if the monitor has been closed. func (m *Monitor) Alerts() <-chan *errors.Error { m.ac.mu.Lock() defer m.ac.mu.Unlock() if m.ac.closed { return nil } return m.ac.ch } // Close shuts down the monitor channel and removes it from the registry. // Thread-safe and idempotent; subsequent calls have no effect. func (m *Monitor) Close() { registry.mu.Lock() defer registry.mu.Unlock() if existing, ok := registry.alerts.Load(m.name); ok { if ac, ok := existing.(*alertChannel); ok && ac == m.ac { ac.mu.Lock() if !ac.closed { close(ac.ch) ac.closed = true } ac.mu.Unlock() registry.alerts.Delete(m.name) } } } // IsClosed reports whether the monitor’s channel has been closed. // Thread-safe; useful for checking monitor status before use. func (m *Monitor) IsClosed() bool { m.ac.mu.Lock() defer m.ac.mu.Unlock() return m.ac.closed } // NewMonitor creates a new Monitor for the given error name with a default buffer of 10. // Reuses an existing channel if one is already registered; thread-safe. // Use NewMonitorBuffered for a custom buffer size. 
func NewMonitor(name string) *Monitor { registry.mu.Lock() defer registry.mu.Unlock() if existing, ok := registry.alerts.Load(name); ok { return &Monitor{name: name, ac: existing.(*alertChannel)} } ac := &alertChannel{ ch: make(chan *errors.Error, monitorSize), closed: false, } registry.alerts.Store(name, ac) return &Monitor{name: name, ac: ac} } // NewMonitorBuffered creates a new Monitor for the given error name with a specified buffer size. // Reuses an existing channel if one is already registered; thread-safe. // Buffer must be non-negative (0 means unbuffered); use NewMonitor for the default buffer of 10. func NewMonitorBuffered(name string, buffer int) *Monitor { if buffer < 0 { buffer = 0 } registry.mu.Lock() defer registry.mu.Unlock() if existing, ok := registry.alerts.Load(name); ok { return &Monitor{name: name, ac: existing.(*alertChannel)} } ac := &alertChannel{ ch: make(chan *errors.Error, buffer), closed: false, } registry.alerts.Store(name, ac) return &Monitor{name: name, ac: ac} } golang-github-olekukonko-errors-1.3.0/errmgr/monitor_test.go000066400000000000000000000123261517267734700243350ustar00rootroot00000000000000package errmgr import ( "strings" "sync" "testing" "time" ) func TestMonitorAlerts(t *testing.T) { Reset() monitor := NewMonitor("TestError") SetThreshold("TestError", 2) defer monitor.Close() errFunc := Define("TestError", "test error %d") for i := 0; i < 3; i++ { err := errFunc(i) if err.Name() != "TestError" { t.Errorf("Expected error name 'TestError', got %q", err.Name()) } err.Free() } select { case alert := <-monitor.Alerts(): if alert == nil { t.Fatal("Received nil alert after threshold exceeded") } if alert.Name() != "TestError" { t.Errorf("Expected alert name 'TestError', got %q", alert.Name()) } if alert.Count() < 2 { t.Errorf("Expected alert count >= 2, got %d", alert.Count()) } if !strings.Contains(alert.Error(), "threshold") { t.Errorf("Expected threshold message in alert, got %q", alert.Error()) } case <-time.After(100 * 
time.Millisecond): t.Error("No alert received within 100ms timeout") } } func TestMonitorBuffered(t *testing.T) { Reset() monitor := NewMonitorBuffered("BufferedError", 2) // Buffer size 2 SetThreshold("BufferedError", 1) defer monitor.Close() errFunc := Define("BufferedError", "buffered error %d") var wg sync.WaitGroup wg.Add(1) // Single goroutine go func() { defer wg.Done() for i := 0; i < 4; i++ { // Generate 4 errors err := errFunc(i) t.Logf("Generated error %d, count now %d", i, registry.counts.Value("BufferedError")) err.Free() time.Sleep(10 * time.Millisecond) // Slow down to fill buffer } }() // Wait for all errors to be generated wg.Wait() // Check metrics to confirm all 4 errors were counted counts := Metrics() if count, ok := counts["BufferedError"]; !ok || count != 4 { t.Errorf("Expected count 4 for BufferedError, got %v", counts) } // Consume alerts (expect up to 2 due to buffer size) received := 0 timeout := time.After(200 * time.Millisecond) for received < 2 { // Expect at least 2 alerts select { case alert := <-monitor.Alerts(): if alert == nil { t.Fatal("Received nil alert") } received++ t.Logf("Received alert %d: %s", received, alert.Error()) if alert.Name() != "BufferedError" { t.Errorf("Expected alert name 'BufferedError', got %q", alert.Name()) } case <-timeout: t.Logf("Timeout waiting for alerts; received %d", received) break // Allow partial success if buffer limited alerts } } } func TestMonitorChannelCloseRace(t *testing.T) { Reset() SetThreshold("RaceError", 1) // Create and immediately close monitor to simulate quick close monitor := NewMonitor("RaceError") monitor.Close() // Ensure no panic when sending to closed channel errFunc := Define("RaceError", "race test %d") for i := 0; i < 3; i++ { err := errFunc(i) err.Free() } // Create new monitor and verify it works newMonitor := NewMonitor("RaceError") defer newMonitor.Close() err := errFunc(42) err.Free() select { case alert := <-newMonitor.Alerts(): if alert == nil { t.Fatal("Received 
nil alert after reopening monitor") } if alert.Name() != "RaceError" { t.Errorf("Expected alert name 'RaceError', got %q", alert.Name()) } if alert.Count() < 1 { t.Errorf("Expected alert count >= 1, got %d", alert.Count()) } case <-time.After(100 * time.Millisecond): t.Error("No alert received within 100ms timeout") } if !monitor.IsClosed() { t.Error("Original monitor should be closed") } if newMonitor.IsClosed() { t.Error("New monitor should not be closed yet") } } func TestMonitorIsClosed(t *testing.T) { Reset() monitor := NewMonitor("CloseTest") if monitor.IsClosed() { t.Error("New monitor should not be closed") } monitor.Close() if !monitor.IsClosed() { t.Error("Monitor should be closed after Close()") } if ch := monitor.Alerts(); ch != nil { t.Error("Alerts should return nil after closure") } } func TestMonitorMultipleInstances(t *testing.T) { Reset() monitor1 := NewMonitor("MultiTest") monitor2 := NewMonitor("MultiTest") // Shares the same channel SetThreshold("MultiTest", 1) defer monitor1.Close() errFunc := Define("MultiTest", "multi test %d") err := errFunc(1) err.Free() // Consume from monitor1, expect monitor2 to see no alerts (single channel) select { case alert1 := <-monitor1.Alerts(): if alert1 == nil { t.Fatal("Received nil alert from monitor1") } if alert1.Name() != "MultiTest" { t.Errorf("Expected alert name 'MultiTest', got %q", alert1.Name()) } case <-time.After(100 * time.Millisecond): t.Error("No alert received from monitor1 within timeout") } // Verify monitor2 doesn't receive the same alert (already consumed) select { case alert2 := <-monitor2.Alerts(): t.Errorf("Unexpected alert from monitor2: %v (channel should be drained)", alert2) case <-time.After(50 * time.Millisecond): // Expected: no alert since monitor1 consumed it } // Generate another error to ensure both monitors share the same channel err = errFunc(2) err.Free() select { case alert2 := <-monitor2.Alerts(): if alert2 == nil { t.Fatal("Received nil alert from monitor2") } if 
alert2.Name() != "MultiTest" { t.Errorf("Expected alert name 'MultiTest', got %q", alert2.Name()) } case <-time.After(100 * time.Millisecond): t.Error("No alert received from monitor2 within timeout") } } golang-github-olekukonko-errors-1.3.0/errors.go000066400000000000000000001146501517267734700216300ustar00rootroot00000000000000// Package errors provides a robust error handling library with support for // error wrapping, stack traces, context storage, and retry mechanisms. It extends // the standard library's error interface with features like HTTP-like status codes, // error categorization, and JSON serialization, while maintaining compatibility // with `errors.Is`, `errors.As`, and `errors.Unwrap`. The package is thread-safe // and optimized with object pooling for performance. package errors import ( "bytes" "encoding/json" "errors" "fmt" "log/slog" "runtime" "strings" "sync" "sync/atomic" ) // Error is a custom error type with enhanced features: message, name, stack trace, // context, cause, and metadata like code and category. It is thread-safe and // supports pooling for performance. type Error struct { // Fields used in atomic operations. Place them at the beginning of the // struct to ensure proper alignment across all architectures. count uint64 // Occurrence count for tracking frequency. // Primary fields (frequently accessed). msg string // The error message displayed by Error(). name string // The error name or type (e.g., "AuthError"). stack []uintptr // Stack trace as program counters. // Secondary metadata. template string // Fallback message template if msg is empty. category string // Error category (e.g., "network"). code int32 // HTTP-like status code (e.g., 400, 500). smallCount int32 // Number of items in smallContext. // Context and chaining. context map[string]interface{} // Key-value pairs for additional context. cause error // Wrapped underlying error for chaining. callback func() // Optional callback invoked by Error(). 
smallContext [contextSize]contextItem // Fixed-size array for small contexts. // Synchronization. mu sync.RWMutex // Protects mutable fields (context, smallContext). // Internal flags. formatWrapped bool // True if created by Newf with %w verb. } // newError creates a new Error instance, reusing from the pool if enabled. // Initializes smallContext and sets stack to nil. // Internal use; prefer New, Named, or Trace for public API. func newError() *Error { if currentConfig.disablePooling { return &Error{ smallContext: [contextSize]contextItem{}, stack: nil, } } return errorPool.Get() } // Empty returns a new empty error with no message, name, or stack trace. // Useful for incrementally building errors or as a neutral base. // Example: // // err := errors.Empty().With("key", "value").WithCode(400) func Empty() *Error { return newError() } // Named creates an error with the specified name and captures a stack trace. // The name doubles as the error message if no message is set. // Use for errors where type identification and stack context are important. // Example: // // err := errors.Named("AuthError").WithCode(401) func Named(name string) *Error { e := newError() e.name = name return e.WithStack() } // New creates a lightweight error with the given message and no stack trace. // Optimized for performance; use Trace() for stack traces. // Returns a shared empty error for empty messages to reduce allocations. // Example: // // err := errors.New("invalid input") func New(text string) *Error { if text == "" { return emptyError.Copy() // Avoid modifying shared instance. } err := newError() err.msg = text return err } // Newf creates a formatted error, supporting the %w verb for wrapping errors. // If the format contains exactly one %w verb with a non-nil error argument, // the error is wrapped as the cause. The final error message string generated // by Error() will be compatible with the output of fmt.Errorf for the same inputs. 
// Does not capture a stack trace by default. // Example: // // cause := errors.New("db error") // err := errors.Newf("query failed: %w", cause) // // err.Error() will match fmt.Errorf("query failed: %w", cause).Error() // // errors.Unwrap(err) == cause func Newf(f any, args ...interface{}) *Error { var format string switch v := f.(type) { case string: format = v case fmt.Stringer: format = v.String() default: panic("Newf: format must be a string or fmt.Stringer") } err := newError() var wCount int var wArgPos = -1 var wArg error var validationErrorMsg string argPos := 0 runes := []rune(format) i := 0 parsingOk := true var fmtVerbs []struct { isW bool spec string // The full verb specifier or literal segment argIdx int // Index in the original 'args' slice, -1 for literals/%% } // Parse format string to identify verbs and literals. for i < len(runes) && parsingOk { segmentStart := i if runes[i] == '%' { if i+1 >= len(runes) { parsingOk = false validationErrorMsg = "ends with %" break } if runes[i+1] == '%' { fmtVerbs = append(fmtVerbs, struct { isW bool spec string argIdx int }{isW: false, spec: "%%", argIdx: -1}) i += 2 continue } i++ // Move past '%' // Parse flags, width, precision (simplified loop) for i < len(runes) && strings.ContainsRune("+- #0", runes[i]) { i++ } for i < len(runes) && ((runes[i] >= '0' && runes[i] <= '9') || runes[i] == '.') { i++ } if i >= len(runes) { parsingOk = false validationErrorMsg = "ends mid-specifier" break } verb := runes[i] specifierEndIndex := i + 1 fullSpec := string(runes[segmentStart:specifierEndIndex]) // Check if the verb consumes an argument currentVerbConsumesArg := strings.ContainsRune("vTtbcdoqxXUeEfFgGspw", verb) currentArgIdx := -1 isWVerb := false if verb == 'w' { isWVerb = true wCount++ if wCount == 1 { wArgPos = argPos // Record position of the error argument } else { parsingOk = false validationErrorMsg = "multiple %w" break } } if currentVerbConsumesArg { if argPos >= len(args) { parsingOk = false if isWVerb { 
// More specific message for missing %w arg validationErrorMsg = "missing %w argument" } else { validationErrorMsg = fmt.Sprintf("missing argument for %s", string(verb)) } break } currentArgIdx = argPos if isWVerb { cause, ok := args[argPos].(error) if !ok || cause == nil { parsingOk = false validationErrorMsg = "bad %w argument type" break } wArg = cause // Store the actual error argument } argPos++ // Consume the argument position } fmtVerbs = append(fmtVerbs, struct { isW bool spec string argIdx int }{isW: isWVerb, spec: fullSpec, argIdx: currentArgIdx}) i = specifierEndIndex // Move past the verb character } else { // Handle literal segment literalStart := i for i < len(runes) && runes[i] != '%' { i++ } fmtVerbs = append(fmtVerbs, struct { isW bool spec string argIdx int }{isW: false, spec: string(runes[literalStart:i]), argIdx: -1}) } } // Check for too many arguments after parsing if parsingOk && argPos < len(args) { parsingOk = false validationErrorMsg = fmt.Sprintf("too many arguments for format %q", format) } // Handle format validation errors. 
if !parsingOk { switch validationErrorMsg { case "multiple %w": err.msg = fmt.Sprintf("errors.Newf: format %q has multiple %%w verbs", format) case "missing %w argument": err.msg = fmt.Sprintf("errors.Newf: format %q has %%w but not enough arguments", format) case "bad %w argument type": argValStr := "()" if wArgPos >= 0 && wArgPos < len(args) && args[wArgPos] != nil { argValStr = fmt.Sprintf("(%T)", args[wArgPos]) } else if wArgPos >= len(args) { argValStr = "(missing)" // Should be caught by "missing %w argument" case } err.msg = fmt.Sprintf("errors.Newf: argument %d for %%w is not a non-nil error %s", wArgPos, argValStr) case "ends with %": err.msg = fmt.Sprintf("errors.Newf: format %q ends with %%", format) case "ends mid-specifier": err.msg = fmt.Sprintf("errors.Newf: format %q ends during verb specifier", format) default: // Includes "too many arguments" and other potential fmt issues err.msg = fmt.Sprintf("errors.Newf: error in format %q: %s", format, validationErrorMsg) } err.cause = nil // Ensure no cause is set on format error err.formatWrapped = false return err } // Start: Processing Valid Format String if wCount == 1 && wArg != nil { // Handle %w: Simulate for Sprintf and pre-format err.cause = wArg // Set the cause for unwrapping err.formatWrapped = true // Signal that msg is the final formatted string var finalFormat strings.Builder var finalArgs []interface{} causeStr := wArg.Error() // Get the string representation of the cause // Rebuild format string and argument list for Sprintf for _, verb := range fmtVerbs { if verb.isW { // Replace the %w verb specifier (e.g., "%w", "%+w") with "%s" finalFormat.WriteString("%s") // Add the cause's *string* to the arguments list for the new %s finalArgs = append(finalArgs, causeStr) } else { // Keep the original literal segment or non-%w verb specifier finalFormat.WriteString(verb.spec) if verb.argIdx != -1 { // Add the original argument for this non-%w verb/literal finalArgs = append(finalArgs, 
args[verb.argIdx]) } } } // Format using the *modified* format string and arguments list result, fmtErr := FmtErrorCheck(finalFormat.String(), finalArgs...) if fmtErr != nil { // Handle potential errors during the final formatting step // This is unlikely if parsing passed, but possible with complex verbs/args err.msg = fmt.Sprintf("errors.Newf: formatting error during %%w simulation for format %q: %v", format, fmtErr) err.cause = nil // Don't keep the cause if final formatting failed err.formatWrapped = false } else { // Store the final, fully formatted string, matching fmt.Errorf output err.msg = result } // End %w Simulation } else { // No %w or wArg is nil: Format directly (original logic) result, fmtErr := FmtErrorCheck(format, args...) if fmtErr != nil { err.msg = fmt.Sprintf("errors.Newf: formatting error for format %q: %v", format, fmtErr) err.cause = nil err.formatWrapped = false } else { err.msg = result err.formatWrapped = false // Ensure false if no %w was involved } } // End: Processing Valid Format String return err } // Errorf is an alias for Newf, providing a familiar interface compatible with // fmt.Errorf. It creates a formatted error without capturing a stack trace. // See Newf for full details on formatting, including %w support for error wrapping. // // Example: // // err := errors.Errorf("failed: %w", errors.New("cause")) // // err.Error() == "failed: cause" func Errorf(format string, args ...interface{}) *Error { return Newf(format, args...) } // Std creates a standard error using errors.New for compatibility. // Does not capture stack traces or add context. // Example: // // err := errors.Std("simple error") func Std(text string) error { return errors.New(text) } // Stdf creates a formatted standard error using fmt.Errorf for compatibility. // Supports %w for wrapping; does not capture stack traces. // Example: // // err := errors.Stdf("failed: %w", cause) func Stdf(format string, a ...interface{}) error { return fmt.Errorf(format, a...) 
} // Trace creates an error with the given message and captures a stack trace. // Use when debugging context is needed; for performance, prefer New(). // Example: // // err := errors.Trace("operation failed") func Trace(text string) *Error { e := New(text) return e.WithStack() } // Tracef creates a formatted error with a stack trace. // Supports %w for wrapping errors. // Example: // // err := errors.Tracef("query %s failed: %w", query, cause) func Tracef(format string, args ...interface{}) *Error { e := Newf(format, args...) return e.WithStack() } // As attempts to assign the error or one in its chain to the target interface. // Supports *Error and standard error types, traversing the cause chain. // Returns true if successful. // Example: // // var target *Error // if errors.As(err, &target) { // fmt.Println(target.Name()) // } func (e *Error) As(target interface{}) bool { if e == nil { return false } // Handle **Error target (i.e. caller passed &myErrPtr where myErrPtr is *Error). // Traverse the chain and return the first *Error that has a name; if none has a // name, return the first *Error in the chain. This satisfies both: // - TestErrorAs: wraps Named("target") -> finds it by name // - TestErrorFullChain: finds Named("AuthError") deep in the chain if targetPtr, ok := target.(**Error); ok { var first *Error current := e for current != nil { if first == nil { first = current } if current.name != "" { *targetPtr = current return true } if next, ok := current.cause.(*Error); ok { current = next } else if current.cause != nil { return errors.As(current.cause, target) } else { break } } if first != nil { *targetPtr = first return true } return false } // Handle *error target. 
if targetErr, ok := target.(*error); ok { innermost := error(e) current := error(e) for current != nil { if err, ok := current.(*Error); ok && err.cause != nil { current = err.cause innermost = current } else { break } } *targetErr = innermost return true } // Delegate to cause for other types. if e.cause != nil { return errors.As(e.cause, target) } return false } // Callback sets a function to be called when Error() is invoked. // Useful for logging or side effects on error access. // Example: // // err := errors.New("test").Callback(func() { log.Println("error accessed") }) func (e *Error) Callback(fn func()) *Error { e.callback = fn return e } // Category returns the error’s category, if set. // Example: // // if err.Category() == "network" { // handleNetworkError(err) // } func (e *Error) Category() string { return e.category } // Code returns the error’s HTTP-like status code, if set. // Returns 0 if no code is set. // Example: // // if err.Code() == 404 { // renderNotFound() // } func (e *Error) Code() int { return int(e.code) } // Context returns the error’s context as a map, merging smallContext and map-based context. // Thread-safe; lazily initializes the map if needed. // Example: // // ctx := err.Context() // if userID, ok := ctx["user_id"]; ok { // fmt.Println(userID) // } func (e *Error) Context() map[string]interface{} { e.mu.RLock() defer e.mu.RUnlock() if e.smallCount > 0 && e.context == nil { e.context = make(map[string]interface{}, e.smallCount) for i := int32(0); i < e.smallCount; i++ { e.context[e.smallContext[i].key] = e.smallContext[i].value } } return e.context } // Copy creates a deep copy of the error, preserving all fields except stack freshness. // The new error can be modified independently. 
// Example: // // newErr := err.Copy().With("new_key", "value") func (e *Error) Copy() *Error { if e == emptyError { return &Error{ smallContext: [contextSize]contextItem{}, } } newErr := newError() newErr.msg = e.msg newErr.name = e.name newErr.template = e.template newErr.cause = e.cause newErr.code = e.code newErr.category = e.category newErr.count = e.count newErr.callback = e.callback // was silently dropped by Copy newErr.formatWrapped = e.formatWrapped // was silently dropped by Copy if e.smallCount > 0 { newErr.smallCount = e.smallCount for i := int32(0); i < e.smallCount; i++ { newErr.smallContext[i] = e.smallContext[i] } } else if e.context != nil { newErr.context = make(map[string]interface{}, len(e.context)) for k, v := range e.context { newErr.context[k] = v } } if e.stack != nil && len(e.stack) > 0 { if newErr.stack == nil { newErr.stack = stackPool.Get().([]uintptr) } newErr.stack = append(newErr.stack[:0], e.stack...) } return newErr } // Count returns the number of times the error has been incremented. // Useful for tracking error frequency. // Example: // // fmt.Printf("Error occurred %d times", err.Count()) func (e *Error) Count() uint64 { return e.count } // Err returns the error as an error interface. // Useful for type assertions or interface compatibility. // Example: // // var stdErr error = err.Err() func (e *Error) Err() error { return e } // Error returns the string representation of the error. // If the error was created using Newf/Errorf with the %w verb, it returns the // pre-formatted string compatible with fmt.Errorf. // Otherwise, it combines the message, template, or name with the cause's error // string, separated by ": ". Invokes any set callback. func (e *Error) Error() string { if e.callback != nil { e.callback() } // If created by Newf/Errorf with %w, msg already contains the final string. 
if e.formatWrapped { return e.msg // Return the pre-formatted fmt.Errorf-compatible string } // Original logic for errors not created via Newf("%w", ...) // or errors created via New/Named and then Wrap() called. var buf strings.Builder // Append primary message part (msg, template, or name) if e.msg != "" { buf.WriteString(e.msg) } else if e.template != "" { buf.WriteString(e.template) } else if e.name != "" { buf.WriteString(e.name) } // Append cause if it exists (only relevant if not formatWrapped) if e.cause != nil { if buf.Len() > 0 { // Add separator only if there was a prefix message/name/template buf.WriteString(": ") } buf.WriteString(e.cause.Error()) } else if buf.Len() == 0 { // Handle case where msg/template/name are empty AND cause is nil // Could return a specific string like "[empty error]" or just "" return "" // Return empty string for a truly empty error } return buf.String() } // FastStack returns a lightweight stack trace with file and line numbers only. // Omits function names for performance; skips internal frames if configured. // Returns nil if no stack trace exists. // Example: // // for _, frame := range err.FastStack() { // fmt.Println(frame) // e.g., "main.go:42" // } func (e *Error) FastStack() []string { // Same len-vs-nil reasoning as Stack(). if len(e.stack) == 0 { return nil } configMu.RLock() filter := currentConfig.filterInternal configMu.RUnlock() pcs := e.stack frames := make([]string, 0, len(pcs)) for _, pc := range pcs { fn := runtime.FuncForPC(pc) if fn == nil { frames = append(frames, "unknown") continue } file, line := fn.FileLine(pc) if filter && isInternalFrame(runtime.Frame{File: file, Function: fn.Name()}) { continue } frames = append(frames, fmt.Sprintf("%s:%d", file, line)) } return frames } // Find searches the error chain for the first error where pred returns true. // Returns nil if no match is found or if pred is nil. 
// Example: // // err := err.Find(func(e error) bool { return strings.Contains(e.Error(), "timeout") }) func (e *Error) Find(pred func(error) bool) error { if e == nil || pred == nil { return nil } return Find(e, pred) } // Format returns a detailed, human-readable string representation of the error, // including message, code, context, stack, and cause. // Recursive for causes that are also *Error. // Example: // // fmt.Println(err.Format()) // // Output: // // Error: failed: cause // // Code: 500 // // Context: // // key: value // // Stack: // // 1. main.main main.go:42 func (e *Error) Format() string { var sb strings.Builder // Error message. sb.WriteString("Error: " + e.Error() + "\n") // Metadata. if e.code != 0 { sb.WriteString(fmt.Sprintf("Code: %d\n", e.code)) } // Context. if ctx := e.contextAtThisLevel(); len(ctx) > 0 { sb.WriteString("Context:\n") for k, v := range ctx { sb.WriteString(fmt.Sprintf("\t%s: %v\n", k, v)) } } // Stack trace. if e.stack != nil { sb.WriteString("Stack:\n") for i, frame := range e.Stack() { sb.WriteString(fmt.Sprintf("\t%d. %s\n", i+1, frame)) } } // Cause. if e.cause != nil { sb.WriteString("Caused by: ") if causeErr, ok := e.cause.(*Error); ok { sb.WriteString(causeErr.Format()) } else { sb.WriteString("Error: " + e.cause.Error() + "\n") } sb.WriteString("\n") } return sb.String() } // contextAtThisLevel returns context specific to this error, excluding inherited context. // Internal use by Format to isolate context per error level. func (e *Error) contextAtThisLevel() map[string]interface{} { if e.context == nil && e.smallCount == 0 { return nil } ctx := make(map[string]interface{}) // Add smallContext items. for i := 0; i < int(e.smallCount); i++ { ctx[e.smallContext[i].key] = e.smallContext[i].value } // Add map context items. if e.context != nil { for k, v := range e.context { ctx[k] = v } } return ctx } // Free resets the error and returns it to the pool if pooling is enabled. 
// Safe to call multiple times; no-op if pooling is disabled. // Call after use to return the error to the pool and prevent memory leaks. // Use defer err.Free() at the call site that created the error. // Example: // // defer err.Free() func (e *Error) Free() { if currentConfig.disablePooling { return } // Disarm any pending auto-cleanup (finalizer or runtime.AddCleanup) before // manually returning to the pool. Without this, GC could return the same // *Error a second time after Free() has already done so — double-put. errorPool.clearCleanup(e) e.Reset() if e.stack != nil { stackPool.Put(e.stack[:cap(e.stack)]) e.stack = nil } errorPool.Put(e) } // Has checks if the error contains meaningful content (message, template, name, or cause). // Returns false for nil or empty errors. // Example: // // if !err.Has() { // return nil // } func (e *Error) Has() bool { return e != nil && (e.msg != "" || e.template != "" || e.name != "" || e.cause != nil) } // HasContextKey checks if the specified key exists in the error’s context. // Thread-safe; checks both smallContext and map-based context. // Example: // // if err.HasContextKey("user_id") { // fmt.Println(err.Context()["user_id"]) // } func (e *Error) HasContextKey(key string) bool { e.mu.RLock() defer e.mu.RUnlock() if e.smallCount > 0 { for i := int32(0); i < e.smallCount; i++ { if e.smallContext[i].key == key { return true } } } if e.context != nil { _, exists := e.context[key] return exists } return false } // Increment atomically increases the error’s count by 1 and returns the error. // Useful for tracking repeated occurrences. // Example: // // err := err.Increment() func (e *Error) Increment() *Error { atomic.AddUint64(&e.count, 1) return e } // Is checks if the error matches the target by pointer, name, or cause chain. // Compatible with errors.Is; also matches by string for standard errors. // Returns true if the error or its cause matches the target. 
// Example: // // if errors.Is(err, errors.New("target")) { // handleTargetError() // } func (e *Error) Is(target error) bool { if e == nil || target == nil { return e == target } if e == target { return true } if e.name != "" { if te, ok := target.(*Error); ok && te.name != "" && e.name == te.name { return true } } // String-equality fallback: matches any error whose message equals this // error's message. This is intentional — it allows matching errors created // by fmt.Errorf or errors.New with the same text — but it deviates from // stdlib errors.Is which uses pointer/sentinel identity. // IMPORTANT: two distinct errors with identical messages will match each other. // For strict identity matching use errors.Const() to create named sentinels. if stdErr, ok := target.(error); ok && e.Error() == stdErr.Error() { return true } if e.cause != nil { return errors.Is(e.cause, target) } return false } // IsEmpty checks if the error lacks meaningful content (no message, name, template, or cause). // Returns true for nil or fully empty errors. // Example: // // if err.IsEmpty() { // return nil // } func (e *Error) IsEmpty() bool { if e == nil { return true } return e.msg == "" && e.template == "" && e.name == "" && e.cause == nil } // IsNull checks if the error is nil, empty, or contains only SQL NULL values in its context or cause. // Useful for handling database-related errors. // Example: // // if err.IsNull() { // return nil // } func (e *Error) IsNull() bool { if e == nil || e == emptyError { return true } // If no context or cause, and no content, it’s not null. if e.smallCount == 0 && e.context == nil && e.cause == nil { return false } // Check cause first. if e.cause != nil { var isNull bool if ce, ok := e.cause.(*Error); ok { isNull = ce.IsNull() } else { isNull = sqlNull(e.cause) } if isNull { return true } } // Check small context. 
if e.smallCount > 0 { allNull := true for i := 0; i < int(e.smallCount); i++ { isNull := sqlNull(e.smallContext[i].value) if !isNull { allNull = false break } } if !allNull { return false } } // Check regular context. if e.context != nil { allNull := true for _, v := range e.context { isNull := sqlNull(v) if !isNull { allNull = false break } } if !allNull { return false } } // Null if context exists and is all null. return e.smallCount > 0 || e.context != nil } // MarshalJSON serializes the error to JSON, including name, message, context, cause, stack, and code. // Causes are recursively serialized if they implement json.Marshaler or are *Error. // Example: // // data, _ := json.Marshal(err) // fmt.Println(string(data)) func (e *Error) MarshalJSON() ([]byte, error) { // Get buffer from pool. Do NOT defer-return it — we must copy the result // out of buf's backing array and return the buf to the pool BEFORE we return // the copied slice. If we defer the Put, another goroutine can Get the same // buf and overwrite its backing array while the caller is still reading our // returned slice (the race the detector flags). buf := jsonBufferPool.Get().(*bytes.Buffer) buf.Reset() // Create new encoder. enc := json.NewEncoder(buf) enc.SetEscapeHTML(false) // Prepare JSON structure. je := struct { Name string `json:"name,omitempty"` Message string `json:"message,omitempty"` Context map[string]interface{} `json:"context,omitempty"` Cause interface{} `json:"cause,omitempty"` Stack []string `json:"stack,omitempty"` Code int `json:"code,omitempty"` }{ Name: e.name, Message: e.msg, Code: e.Code(), } // Add context. if ctx := e.Context(); len(ctx) > 0 { je.Context = ctx } // Add stack. if e.stack != nil { je.Stack = e.Stack() } // Add cause. if e.cause != nil { switch c := e.cause.(type) { case *Error: je.Cause = c case json.Marshaler: je.Cause = c default: je.Cause = c.Error() } } // Encode JSON. 
if err := enc.Encode(je); err != nil { return nil, err } // Copy bytes out of buf before returning buf to the pool. // buf.Bytes() is a slice into buf's internal array — if we put buf back first // and another goroutine resets it, they share the same backing memory. raw := buf.Bytes() if len(raw) > 0 && raw[len(raw)-1] == '\n' { raw = raw[:len(raw)-1] } result := make([]byte, len(raw)) copy(result, raw) jsonBufferPool.Put(buf) return result, nil } // Msgf sets the error’s message using a formatted string and returns the error. // Overwrites any existing message. // Example: // // err := err.Msgf("user %s not found", username) func (e *Error) Msgf(format string, args ...interface{}) *Error { e.msg = fmt.Sprintf(format, args...) return e } // Name returns the error’s name, if set. // Example: // // if err.Name() == "AuthError" { // handleAuthError() // } func (e *Error) Name() string { return e.name } // Reset clears all fields of the error, preparing it for reuse in the pool. // Internal use by Free; does not release stack to stackPool. // Example: // // err.Reset() // Clear all fields. func (e *Error) Reset() { e.msg = "" e.name = "" e.template = "" e.category = "" e.code = 0 e.count = 0 e.cause = nil e.callback = nil e.formatWrapped = false if e.context != nil { for k := range e.context { delete(e.context, k) } } e.smallCount = 0 if e.stack != nil { e.stack = e.stack[:0] } } // Stack returns a detailed stack trace with function names, files, and line numbers. // Filters internal frames if configured; returns nil if no stack exists. // Example: // // for _, frame := range err.Stack() { // fmt.Println(frame) // e.g., "main.main main.go:42" // } func (e *Error) Stack() []string { // Use len check not nil: a recycled error has stack reset to stack[:0] // (non-nil, zero length). Calling CallersFrames on an empty slice returns // no frames, making Stack() silently return [] instead of nil. 
if len(e.stack) == 0 { return nil } frames := runtime.CallersFrames(e.stack) var trace []string for { frame, more := frames.Next() if frame == (runtime.Frame{}) { break } if currentConfig.filterInternal && isInternalFrame(frame) { continue } trace = append(trace, fmt.Sprintf("%s %s:%d", frame.Function, frame.File, frame.Line)) if !more { break } } return trace } // Trace ensures the error has a stack trace, capturing it if absent. // Returns the error for chaining. // Example: // // err := errors.New("failed").Trace() func (e *Error) Trace() *Error { // Check len rather than nil for the same reason as WithStack. if len(e.stack) == 0 { // skip=1: trimmed = skip+1 = 2, removes captureStack + Trace() itself. e.stack = captureStack(1) } return e } // Transform applies transformations to a copy of the error and returns the new error. // The original error is unchanged; nil-safe. // Example: // // newErr := err.Transform(func(e *Error) { e.With("key", "value") }) func (e *Error) Transform(fn func(*Error)) *Error { if e == nil || fn == nil { return e } newErr := e.Copy() fn(newErr) return newErr } // Unwrap returns the underlying cause of the error, if any. // Compatible with errors.Unwrap for chain traversal. // Example: // // cause := errors.Unwrap(err) func (e *Error) Unwrap() error { return e.cause } // UnwrapAll returns a slice of all errors in the chain, starting with this error. // Each error is isolated to prevent modifications affecting others. // Example: // // chain := err.UnwrapAll() // for _, e := range chain { // fmt.Println(e.Error()) // } func (e *Error) UnwrapAll() []error { if e == nil { return nil } // Return the original nodes directly. Each *Error in the chain already owns // its own context map and message — there is no bleeding between nodes. 
// Returning originals (rather than copies) ensures: // e.Error() returns only that node's own msg/name (cause is on the // NEXT node, not duplicated here — Error() appends e.cause.Error() which // is exactly the next node's contribution). // // Wait — Error() DOES append cause. So chain[0].Error() includes the full // chain. The test wants chain[0].Error() == "outer" (msg only). // return snapshot wrappers that expose only the node's own message. var chain []error current := error(e) for current != nil { if err, ok := current.(*Error); ok { // Wrap in a msgOnlyError so Error() returns only this node's own // message without appending the cause chain. Unwrap() still returns // the original *Error so standard chain traversal continues to work. chain = append(chain, &msgOnlyError{err: err}) } else { chain = append(chain, current) } if unwrapper, ok := current.(interface{ Unwrap() error }); ok { current = unwrapper.Unwrap() } else { break } } return chain } // Walk traverses the error chain, applying fn to each error. // Stops if fn is nil or the chain ends. // Example: // // err.Walk(func(e error) { fmt.Println(e.Error()) }) func (e *Error) Walk(fn func(error)) { if e == nil || fn == nil { return } current := error(e) for current != nil { fn(current) if unwrappable, ok := current.(interface{ Unwrap() error }); ok { current = unwrappable.Unwrap() } else { break } } } // With adds key-value pairs to the error's context and returns the error. // Uses a fixed-size array (smallContext) for up to contextSize items, then switches // to a map. Thread-safe. Accepts variadic key-value pairs. // Example: // // err := err.With("key1", value1, "key2", value2) func (e *Error) With(keyValues ...interface{}) *Error { if len(keyValues) == 0 { return e } // Validate that we have an even number of arguments if len(keyValues)%2 != 0 { keyValues = append(keyValues, "(MISSING)") } // Acquire the lock once up-front. 
The previous "optimistic read then lock" // pattern read e.smallCount and e.context without holding the lock, which // the race detector correctly flagged as a data race when two goroutines // call With() on the same *Error concurrently. e.mu.Lock() defer e.mu.Unlock() // Fast path: all pairs fit in the fixed-size smallContext array. if e.smallCount < contextSize && e.context == nil { remainingSlots := contextSize - int(e.smallCount) if len(keyValues)/2 <= remainingSlots { for i := 0; i < len(keyValues); i += 2 { key, ok := keyValues[i].(string) if !ok { key = fmt.Sprintf("%v", keyValues[i]) } e.smallContext[e.smallCount] = contextItem{key, keyValues[i+1]} e.smallCount++ } return e } } // Slow path: too many pairs or already using map context. // Initialize map context if needed if e.context == nil { e.context = make(map[string]interface{}, max(currentConfig.contextSize, len(keyValues)/2+int(e.smallCount))) // Migrate existing smallContext items for i := int32(0); i < e.smallCount; i++ { e.context[e.smallContext[i].key] = e.smallContext[i].value } // Reset smallCount since we've moved to map context e.smallCount = 0 } // Add all pairs to map context for i := 0; i < len(keyValues); i += 2 { key, ok := keyValues[i].(string) if !ok { key = fmt.Sprintf("%v", keyValues[i]) } e.context[key] = keyValues[i+1] } return e } // Helper function to get maximum of two integers func max(a, b int) int { if a > b { return a } return b } // WithCategory sets the error’s category and returns the error. // Example: // // err := err.WithCategory("validation") func (e *Error) WithCategory(category ErrorCategory) *Error { e.category = string(category) return e } // WithCode sets an HTTP-like status code and returns the error. // Example: // // err := err.WithCode(400) func (e *Error) WithCode(code int) *Error { e.code = int32(code) return e } // WithName sets the error’s name and returns the error. 
// Example: // // err := err.WithName("AuthError") func (e *Error) WithName(name string) *Error { e.name = name return e } // WithRetryable marks the error as retryable in its context and returns the error. // Example: // // err := err.WithRetryable() func (e *Error) WithRetryable() *Error { return e.With(ctxRetry, true) } // WithStack captures a stack trace if none exists and returns the error. // Skips one frame (caller of WithStack). // Example: // // err := errors.New("failed").WithStack() func (e *Error) WithStack() *Error { // Check len rather than nil: a pooled error has stack reset to stack[:0] // (non-nil but empty). The nil check would skip capture for recycled errors. if len(e.stack) == 0 { e.stack = captureStack(1) } return e } // WithTemplate sets a message template and returns the error. // Used as a fallback if the message is empty. // Example: // // err := err.WithTemplate("operation failed") func (e *Error) WithTemplate(template string) *Error { e.template = template return e } // WithTimeout marks the error as a timeout error in its context and returns the error. // Example: // // err := err.WithTimeout() func (e *Error) WithTimeout() *Error { return e.With(ctxTimeout, true) } // Wrap associates a cause error with this error, creating a chain. // Returns the error unchanged if cause is nil. // Example: // // err := errors.New("failed").Wrap(errors.New("cause")) func (e *Error) Wrap(cause error) *Error { if cause == nil { return e } e.cause = cause return e } // Wrapf wraps a cause error with formatted message and returns the error. // If cause is nil, returns the error unchanged. // Example: // // err := errors.New("base").Wrapf(io.EOF, "read failed: %s", "file.txt") func (e *Error) Wrapf(cause error, format string, args ...interface{}) *Error { e.msg = fmt.Sprintf(format, args...) if cause != nil { e.cause = cause } return e } // WrapNotNil wraps a cause error only if it is non-nil and returns the error. 
// Example: // // err := err.WrapNotNil(maybeError) func (e *Error) WrapNotNil(cause error) *Error { if cause != nil { e.cause = cause } return e } // LogValue implements slog.LogValuer so *Error can be passed directly to // any slog logging call and will be rendered as a structured group containing // message, name, code, category, and context fields. // // Example: // // slog.Error("request failed", "err", err) // // => err.message="...", err.name="AuthError", err.code=401, ... func (e *Error) LogValue() slog.Value { if e == nil { return slog.StringValue("") } attrs := make([]slog.Attr, 0, 6) if e.msg != "" { attrs = append(attrs, slog.String("message", e.msg)) } if e.name != "" { attrs = append(attrs, slog.String("name", e.name)) } if e.code != 0 { attrs = append(attrs, slog.Int("code", int(e.code))) } if e.category != "" { attrs = append(attrs, slog.String("category", e.category)) } if ctx := e.contextAtThisLevel(); len(ctx) > 0 { ctxAttrs := make([]slog.Attr, 0, len(ctx)) for k, v := range ctx { ctxAttrs = append(ctxAttrs, slog.Any(k, v)) } attrs = append(attrs, slog.Attr{Key: "context", Value: slog.GroupValue(ctxAttrs...)}) } if e.cause != nil { attrs = append(attrs, slog.String("cause", e.cause.Error())) } return slog.GroupValue(attrs...) } // msgOnlyError wraps a single *Error and returns only its own message from // Error(), without appending the cause chain. Used by UnwrapAll so each // element in the returned slice represents exactly one chain node. type msgOnlyError struct { err *Error } func (m *msgOnlyError) Error() string { if m.err.msg != "" { return m.err.msg } if m.err.name != "" { return m.err.name } if m.err.template != "" { return m.err.template } return "" } // Unwrap returns the underlying *Error so errors.Is/As and chain traversal work. func (m *msgOnlyError) Unwrap() error { return m.err.cause } // Convenience accessors so callers can still reach *Error fields after UnwrapAll. 
// Name returns the wrapped node's name via (*Error).Name.
func (m *msgOnlyError) Name() string { return m.err.Name() }

// Code returns the wrapped node's status code via (*Error).Code.
func (m *msgOnlyError) Code() int { return m.err.Code() }

// Context returns the wrapped node's context via (*Error).Context.
func (m *msgOnlyError) Context() map[string]interface{} { return m.err.Context() }

// Stack returns the wrapped node's stack trace via (*Error).Stack.
func (m *msgOnlyError) Stack() []string { return m.err.Stack() }
golang-github-olekukonko-errors-1.3.0/errors_benchmark_test.go000066400000000000000000000205111517267734700246710ustar00rootroot00000000000000package errors

import (
	"encoding/json"
	"errors"
	"fmt"
	"runtime"
	"testing"
)

// Basic Error Creation Benchmarks
// These benchmarks measure the performance of creating basic errors with and without
// pooling, compared to standard library equivalents for baseline reference.

// BenchmarkBasic_New measures the creation and pooling of a new error.
func BenchmarkBasic_New(b *testing.B) {
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		err := New("test error") // Create and pool a new error
		err.Free()
	}
}

// BenchmarkBasic_NewNoFree measures error creation without pooling.
// Errors are never passed to Free here.
func BenchmarkBasic_NewNoFree(b *testing.B) {
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		_ = New("test error") // Create error without returning to pool
	}
}

// BenchmarkBasic_StdlibComparison measures standard library error creation as a baseline.
func BenchmarkBasic_StdlibComparison(b *testing.B) {
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		_ = errors.New("test error") // Baseline using standard library errors.New
	}
}

// BenchmarkBasic_StdErrorComparison measures the package's Std wrapper for errors.New.
func BenchmarkBasic_StdErrorComparison(b *testing.B) {
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		_ = Std("test error") // Baseline using package’s Std wrapper for errors.New
	}
}

// BenchmarkBasic_StdfComparison measures the package's Stdf wrapper for fmt.Errorf.
func BenchmarkBasic_StdfComparison(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { _ = Stdf("test error %d", i) // Baseline using package’s Stdf wrapper for fmt.Errorf } } // Stack Trace Benchmarks // These benchmarks evaluate the performance of stack trace operations, including // capturing and generating stack traces for error instances. // BenchmarkStack_WithStack measures adding a stack trace to an error. func BenchmarkStack_WithStack(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { err := New("test").WithStack() // Add stack trace to an error err.Free() } } // BenchmarkStack_Trace measures creating an error with a stack trace. func BenchmarkStack_Trace(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { err := Trace("test error") // Create error with stack trace err.Free() } } // BenchmarkStack_Capture measures generating a stack trace from an existing error. func BenchmarkStack_Capture(b *testing.B) { err := New("test") b.ResetTimer() for i := 0; i < b.N; i++ { _ = err.Stack() // Generate stack trace from existing error } err.Free() } // BenchmarkCaptureStack measures capturing a raw stack trace. func BenchmarkCaptureStack(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { stack := captureStack(0) // Capture raw stack trace if stack != nil { runtime.KeepAlive(stack) // Ensure stack isn’t optimized away } } } // Context Operation Benchmarks // These benchmarks assess the performance of adding context to errors, testing // small context (array-based), map-based, and reuse scenarios. // BenchmarkContext_Small measures adding context within the smallContext limit. func BenchmarkContext_Small(b *testing.B) { err := New("base") b.ResetTimer() for i := 0; i < b.N; i++ { _ = err.With("key", i).With("key2", i+1) // Add two key-value pairs within smallContext limit } err.Free() } // BenchmarkContext_Map measures adding context exceeding smallContext capacity. 
func BenchmarkContext_Map(b *testing.B) { err := New("base") b.ResetTimer() for i := 0; i < b.N; i++ { _ = err.With("k1", i).With("k2", i+1).With("k3", i+2) // Exceed smallContext, forcing map usage } err.Free() } // BenchmarkContext_Reuse measures adding to an existing context. func BenchmarkContext_Reuse(b *testing.B) { err := New("base").With("init", "value") b.ResetTimer() for i := 0; i < b.N; i++ { _ = err.With("key", i) // Add to existing context } err.Free() } // Error Wrapping Benchmarks // These benchmarks measure the cost of wrapping errors, both shallow and deep chains. // BenchmarkWrapping_Simple measures wrapping a single base error. func BenchmarkWrapping_Simple(b *testing.B) { base := New("base") b.ResetTimer() for i := 0; i < b.N; i++ { err := New("wrapper").Wrap(base) // Wrap a single base error err.Free() } base.Free() } // BenchmarkWrapping_Deep measures unwrapping a 10-level deep error chain. func BenchmarkWrapping_Deep(b *testing.B) { var err *Error for i := 0; i < 10; i++ { err = New("level").Wrap(err) // Build a 10-level deep error chain } b.ResetTimer() for i := 0; i < b.N; i++ { _ = err.Unwrap() // Unwrap the deep chain } err.Free() } // Type Assertion Benchmarks // These benchmarks evaluate the performance of type assertions (Is and As) on wrapped errors. // BenchmarkTypeAssertion_Is measures checking if an error matches a target. func BenchmarkTypeAssertion_Is(b *testing.B) { target := Named("target") err := New("wrapper").Wrap(target) b.ResetTimer() for i := 0; i < b.N; i++ { _ = Is(err, target) // Check if error matches target } target.Free() } // BenchmarkTypeAssertion_As measures extracting a target from an error chain. 
func BenchmarkTypeAssertion_As(b *testing.B) { err := New("wrapper").Wrap(Named("target")) var target *Error b.ResetTimer() for i := 0; i < b.N; i++ { _ = As(err, &target) // Extract target from error chain } if target != nil { target.Free() } } // Serialization Benchmarks // These benchmarks test JSON serialization performance with and without stack traces. // BenchmarkSerialization_JSON measures serializing an error with context to JSON. func BenchmarkSerialization_JSON(b *testing.B) { err := New("test").With("key", "value").With("num", 42) b.ResetTimer() for i := 0; i < b.N; i++ { _, _ = json.Marshal(err) // Serialize error with context } } // BenchmarkSerialization_JSONWithStack measures serializing an error with stack trace to JSON. func BenchmarkSerialization_JSONWithStack(b *testing.B) { err := Trace("test").With("key", "value") b.ResetTimer() for i := 0; i < b.N; i++ { _, _ = json.Marshal(err) // Serialize error with stack trace } } // Concurrency Benchmarks // These benchmarks assess performance under concurrent error creation and context modification. // BenchmarkConcurrency_Creation measures concurrent error creation and pooling. func BenchmarkConcurrency_Creation(b *testing.B) { b.RunParallel(func(pb *testing.PB) { for pb.Next() { err := New("parallel error") // Create errors concurrently err.Free() } }) } // BenchmarkConcurrency_Context measures concurrent context addition to a shared error. func BenchmarkConcurrency_Context(b *testing.B) { base := New("base") b.RunParallel(func(pb *testing.PB) { for pb.Next() { _ = base.With("key", "value") // Add context concurrently } }) base.Free() } // BenchmarkContext_Concurrent measures concurrent context addition with unique keys. 
func BenchmarkContext_Concurrent(b *testing.B) {
	err := New("base")
	b.RunParallel(func(pb *testing.PB) {
		// Each goroutine cycles through ten keys ("key0".."key9").
		i := 0
		for pb.Next() {
			err.With(fmt.Sprintf("key%d", i%10), i) // Add unique keys concurrently
			i++
		}
	})
	// NOTE(review): err is never freed here, unlike most other benchmarks in this file.
}

// Pool and Allocation Benchmarks
// These benchmarks evaluate pooling mechanisms and raw allocation costs.

// BenchmarkPoolGetPut measures the speed of pool get and put operations.
// The seed value is a zero-value &Error{} rather than one obtained from New.
func BenchmarkPoolGetPut(b *testing.B) {
	e := &Error{}
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		errorPool.Put(e)    // Return error to pool
		e = errorPool.Get() // Retrieve error from pool
	}
}

// BenchmarkPoolWarmup measures the cost of resetting and warming the error pool.
// NOTE(review): this reassigns the package-level errorPool and does not restore
// the previous pool, so code running after it uses the last pool created here.
func BenchmarkPoolWarmup(b *testing.B) {
	for i := 0; i < b.N; i++ {
		errorPool = NewErrorPool() // Recreate pool
		WarmPool(100)              // Pre-warm with 100 errors
	}
}

// BenchmarkStackAlloc measures the cost of allocating a stack slice.
func BenchmarkStackAlloc(b *testing.B) {
	for i := 0; i < b.N; i++ {
		_ = make([]uintptr, 0, currentConfig.stackDepth) // Allocate stack slice
	}
}

// Special Case Benchmarks
// These benchmarks test specialized error creation methods.

// BenchmarkSpecial_Named measures creating a named error with a stack trace.
func BenchmarkSpecial_Named(b *testing.B) {
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		err := Named("test_error") // Create named error with stack
		err.Free()
	}
}

// BenchmarkSpecial_Format measures creating a formatted error.
func BenchmarkSpecial_Format(b *testing.B) {
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		err := Errorf("formatted %s %d", "error", i) // Create formatted error
		err.Free()
	}
}
golang-github-olekukonko-errors-1.3.0/errors_test.go000066400000000000000000001313311517267734700226620ustar00rootroot00000000000000// Package errors provides a robust error handling library with support for
// error wrapping, stack traces, context storage, and retry mechanisms.
// This test file verifies the correctness of the error type and its methods, // ensuring proper behavior for creation, wrapping, inspection, and serialization. // Tests cover edge cases, standard library compatibility, and thread-safety. package errors import ( "context" "database/sql" "encoding/json" "errors" "fmt" "reflect" "strings" "sync" "testing" "time" ) // customError is a test-specific error type for verifying error wrapping and traversal. type customError struct { msg string cause error } func (e *customError) Error() string { return e.msg } func (e *customError) Cause() error { return e.cause } // TestErrorNew verifies that New creates an error with the specified message // and does not capture a stack trace, ensuring lightweight error creation. func TestErrorNew(t *testing.T) { err := New("test error") defer err.Free() if err.Error() != "test error" { t.Errorf("New() error message = %v, want %v", err.Error(), "test error") } if len(err.Stack()) != 0 { t.Errorf("New() should not capture stack trace, got %d frames", len(err.Stack())) } } // TestErrorNewf checks that Newf formats the error message correctly using // the provided format string and arguments, without capturing a stack trace. func TestErrorNewf(t *testing.T) { err := Newf("test %s %d", "error", 42) defer err.Free() want := "test error 42" if err.Error() != want { t.Errorf("Newf() error message = %v, want %v", err.Error(), want) } if len(err.Stack()) != 0 { t.Errorf("Newf() should not capture stack trace, got %d frames", len(err.Stack())) } } // TestErrorNamed ensures that Named creates a named error with the given name // and captures a stack trace for debugging purposes. 
func TestErrorNamed(t *testing.T) { err := Named("test_name") defer err.Free() if err.Error() != "test_name" { t.Errorf("Named() error message = %v, want %v", err.Error(), "test_name") } if len(err.Stack()) == 0 { t.Errorf("Named() should capture stack trace") } } // TestErrorMethods tests the core methods of the Error type, including context // addition, wrapping, message formatting, stack tracing, and metadata handling. func TestErrorMethods(t *testing.T) { err := New("base error") defer err.Free() // Test With for adding key-value context. err = err.With("key", "value") if err.Context()["key"] != "value" { t.Errorf("With() failed, context[key] = %v, want %v", err.Context()["key"], "value") } // Test Wrap for setting a cause. cause := New("cause error") defer cause.Free() err = err.Wrap(cause) if err.Unwrap() != cause { t.Errorf("Wrap() failed, unwrapped = %v, want %v", err.Unwrap(), cause) } // Test Msgf for updating the error message. err = err.Msgf("new message %d", 123) if err.Error() != "new message 123: cause error" { t.Errorf("Msgf() failed, error = %v, want %v", err.Error(), "new message 123: cause error") } // Test stack absence initially. stackLen := len(err.Stack()) if stackLen != 0 { t.Errorf("Initial stack length should be 0, got %d", stackLen) } // Test Trace for capturing a stack trace. err = err.Trace() if len(err.Stack()) == 0 { t.Errorf("Trace() should capture a stack trace, got no frames") } // Test WithCode for setting an HTTP status code. err = err.WithCode(400) if err.Code() != 400 { t.Errorf("WithCode() failed, code = %d, want 400", err.Code()) } // Test WithCategory for setting a category. err = err.WithCategory("test_category") if Category(err) != "test_category" { t.Errorf("WithCategory() failed, category = %v, want %v", Category(err), "test_category") } // Test Increment for counting occurrences. 
err = err.Increment() if err.Count() != 1 { t.Errorf("Increment() failed, count = %d, want 1", err.Count()) } } // TestErrorIs verifies that Is correctly identifies errors by name or through // wrapping, including compatibility with standard library errors. func TestErrorIs(t *testing.T) { err := Named("test_error") defer err.Free() err2 := Named("test_error") defer err2.Free() err3 := Named("other_error") defer err3.Free() // Test matching same-named errors. if !err.Is(err2) { t.Errorf("Is() failed, %v should match %v", err, err2) } // Test non-matching names. if err.Is(err3) { t.Errorf("Is() failed, %v should not match %v", err, err3) } // Test wrapped error matching. wrappedErr := Named("wrapper") defer wrappedErr.Free() cause := Named("cause_error") defer cause.Free() wrappedErr = wrappedErr.Wrap(cause) if !wrappedErr.Is(cause) { t.Errorf("Is() failed, wrapped error should match cause; wrappedErr = %+v, cause = %+v", wrappedErr, cause) } // Test wrapping standard library error. stdErr := errors.New("std error") wrappedErr = wrappedErr.Wrap(stdErr) if !wrappedErr.Is(stdErr) { t.Errorf("Is() failed, should match stdlib error") } } // TestErrorAs checks that As unwraps to the correct error type, supporting // both custom *Error and standard library errors. func TestErrorAs(t *testing.T) { err := New("base").Wrap(Named("target")) defer err.Free() var target *Error if !As(err, &target) { t.Errorf("As() failed, should unwrap to *Error") } if target.name != "target" { t.Errorf("As() unwrapped to wrong error, got %v, want %v", target.name, "target") } stdErr := errors.New("std error") err = New("wrapper").Wrap(stdErr) defer err.Free() var stdTarget error if !As(err, &stdTarget) { t.Errorf("As() failed, should unwrap to stdlib error") } if stdTarget != stdErr { t.Errorf("As() unwrapped to wrong error, got %v, want %v", stdTarget, stdErr) } } // TestErrorCount verifies that Count tracks per-instance error occurrences. 
func TestErrorCount(t *testing.T) { err := New("unnamed") defer err.Free() if err.Count() != 0 { t.Errorf("Count() on new error should be 0, got %d", err.Count()) } err = Named("test_count").Increment() if err.Count() != 1 { t.Errorf("Count() after Increment() should be 1, got %d", err.Count()) } } // TestErrorCode ensures that Code correctly sets and retrieves HTTP status codes. func TestErrorCode(t *testing.T) { err := New("unnamed") defer err.Free() if err.Code() != 0 { t.Errorf("Code() on new error should be 0, got %d", err.Code()) } err = Named("test_code").WithCode(400) if err.Code() != 400 { t.Errorf("Code() after WithCode(400) should be 400, got %d", err.Code()) } } // TestErrorMarshalJSON verifies that JSON serialization includes all expected // fields: message, context, cause, code, and stack (when present). func TestErrorMarshalJSON(t *testing.T) { // Test basic error with context, code, and cause. err := New("test"). With("key", "value"). WithCode(400). Wrap(Named("cause")) defer err.Free() data, e := json.Marshal(err) if e != nil { t.Fatalf("MarshalJSON() failed: %v", e) } want := map[string]interface{}{ "message": "test", "context": map[string]interface{}{"key": "value"}, "cause": map[string]interface{}{"name": "cause"}, "code": float64(400), } var got map[string]interface{} if err := json.Unmarshal(data, &got); err != nil { t.Fatalf("Unmarshal failed: %v", err) } if got["message"] != want["message"] { t.Errorf("MarshalJSON() message = %v, want %v", got["message"], want["message"]) } if !reflect.DeepEqual(got["context"], want["context"]) { t.Errorf("MarshalJSON() context = %v, want %v", got["context"], want["context"]) } if cause, ok := got["cause"].(map[string]interface{}); !ok || cause["name"] != "cause" { t.Errorf("MarshalJSON() cause = %v, want %v", got["cause"], want["cause"]) } if code, ok := got["code"].(float64); !ok || code != 400 { t.Errorf("MarshalJSON() code = %v, want %v", got["code"], 400) } // Test error with stack trace. 
t.Run("WithStack", func(t *testing.T) { err := New("test").WithStack().WithCode(500) defer err.Free() data, e := json.Marshal(err) if e != nil { t.Fatalf("MarshalJSON() failed: %v", e) } var got map[string]interface{} if err := json.Unmarshal(data, &got); err != nil { t.Fatalf("Unmarshal failed: %v", err) } if _, ok := got["stack"].([]interface{}); !ok || len(got["stack"].([]interface{})) == 0 { t.Error("MarshalJSON() should include non-empty stack") } if code, ok := got["code"].(float64); !ok || code != 500 { t.Errorf("MarshalJSON() code = %v, want 500", got["code"]) } }) } // TestErrorEdgeCases verifies behavior for unusual inputs, such as nil errors, // empty names, and standard library error wrapping. func TestErrorEdgeCases(t *testing.T) { // Test nil error handling. var nilErr *Error if nilErr.Is(nil) { t.Errorf("nil.Is(nil) should be false, got true") } if Is(nilErr, New("test")) { t.Errorf("Is(nil, non-nil) should be false") } // Test empty name mismatch. err := New("empty name") defer err.Free() if err.Is(Named("")) { t.Errorf("Error with empty name should not match unnamed error") } // Test wrapping standard library error. stdErr := errors.New("std error") customErr := New("custom").Wrap(stdErr) defer customErr.Free() if !Is(customErr, stdErr) { t.Errorf("Is() should match stdlib error through wrapping") } // Test As with nil error. var nilTarget *Error if As(nilErr, &nilTarget) { t.Errorf("As(nil, &nilTarget) should return false") } // Additional edge case: Wrapping nil error. t.Run("WrapNil", func(t *testing.T) { err := New("wrapper").Wrap(nil) defer err.Free() if err.Unwrap() != nil { t.Errorf("Wrap(nil) should set cause to nil, got %v", err.Unwrap()) } if err.Error() != "wrapper" { t.Errorf("Wrap(nil) should preserve message, got %v, want %v", err.Error(), "wrapper") } }) } // TestErrorRetryWithCallback verifies the retry mechanism, ensuring the callback // is invoked correctly and retries exhaust as expected for retryable errors. 
func TestErrorRetryWithCallback(t *testing.T) { // Test retry with multiple attempts. attempts := 0 retry := NewRetry( WithMaxAttempts(3), WithDelay(1*time.Millisecond), WithOnRetry(func(attempt int, err error) { attempts++ }), ) err := retry.Execute(func() error { return New("retry me").WithRetryable() }) if attempts != 3 { t.Errorf("Expected 3 retry attempts, got %d", attempts) } if err == nil { t.Error("Expected retry to exhaust with error, got nil") } // Test zero max attempts, expecting one initial attempt (not a retry). t.Run("ZeroAttempts", func(t *testing.T) { attempts := 0 retry := NewRetry( WithMaxAttempts(0), WithOnRetry(func(attempt int, err error) { attempts++ }), ) err := retry.Execute(func() error { return New("retry me").WithRetryable() }) // Expect one attempt, as Execute runs the function once before checking retries. if attempts != 1 { t.Errorf("Expected 1 attempt (initial execution), got %d", attempts) } if err == nil { t.Error("Expected error, got nil") } }) } // TestErrorStackPresence confirms stack trace behavior for New and Trace methods. func TestErrorStackPresence(t *testing.T) { // New should not capture stack. err := New("test") if len(err.Stack()) != 0 { t.Error("New() should not capture stack") } // Trace should capture stack. traced := Trace("test") if len(traced.Stack()) == 0 { t.Error("Trace() should capture stack") } } // TestErrorStackDepth ensures that stack traces respect the configured maximum depth. func TestErrorStackDepth(t *testing.T) { err := Trace("test") frames := err.Stack() if len(frames) > currentConfig.stackDepth { t.Errorf("Stack depth %d exceeds configured max %d", len(frames), currentConfig.stackDepth) } } // TestErrorTransform verifies Transform behavior for nil, non-*Error, and *Error inputs. func TestErrorTransform(t *testing.T) { // Test nil input. 
t.Run("NilError", func(t *testing.T) { result := Transform(nil, func(e *Error) {}) if result != nil { t.Error("Should return nil for nil input") } }) // Test standard library error. t.Run("NonErrorType", func(t *testing.T) { stdErr := errors.New("standard") transformed := Transform(stdErr, func(e *Error) {}) if transformed == nil { t.Error("Should not return nil for non-nil input") } if transformed.Error() != "standard" { t.Errorf("Should preserve original message, got %q, want %q", transformed.Error(), "standard") } if transformed == stdErr { t.Error("Should return a new *Error, not the original") } }) // Test transforming *Error. t.Run("TransformError", func(t *testing.T) { orig := New("original") defer orig.Free() transformed := Transform(orig, func(e *Error) { e.With("key", "value") }) defer transformed.Free() if transformed == orig { t.Error("Should return a copy, not the original") } if transformed.Error() != "original" { t.Errorf("Should preserve original message, got %q, want %q", transformed.Error(), "original") } if transformed.Context()["key"] != "value" { t.Error("Should apply transformations, context missing 'key'='value'") } }) } // TestErrorWalk ensures Walk traverses the error chain correctly, visiting all errors. func TestErrorWalk(t *testing.T) { err1 := &customError{msg: "first error", cause: nil} err2 := &customError{msg: "second error", cause: err1} err3 := &customError{msg: "third error", cause: err2} var errorsWalked []string Walk(err3, func(e error) { errorsWalked = append(errorsWalked, e.Error()) }) expected := []string{"third error", "second error", "first error"} if !reflect.DeepEqual(errorsWalked, expected) { t.Errorf("Walk() = %v; want %v", errorsWalked, expected) } } // TestErrorFind verifies Find locates the first error matching the predicate. 
func TestErrorFind(t *testing.T) { err1 := &customError{msg: "first error", cause: nil} err2 := &customError{msg: "second error", cause: err1} err3 := &customError{msg: "third error", cause: err2} // Find existing error. found := Find(err3, func(e error) bool { return e.Error() == "second error" }) if found == nil || found.Error() != "second error" { t.Errorf("Find() = %v; want 'second error'", found) } // Find non-existent error. found = Find(err3, func(e error) bool { return e.Error() == "non-existent error" }) if found != nil { t.Errorf("Find() = %v; want nil", found) } } // TestErrorTraceStackContent checks that Trace captures meaningful stack frames. func TestErrorTraceStackContent(t *testing.T) { err := Trace("test") defer err.Free() frames := err.Stack() if len(frames) == 0 { t.Fatal("Trace() should capture stack frames") } found := false for _, frame := range frames { if strings.Contains(frame, "testing.tRunner") { found = true break } } if !found { t.Errorf("Trace() stack does not contain testing.tRunner, got: %v", frames) } } // TestErrorWithStackContent ensures WithStack captures meaningful stack frames. func TestErrorWithStackContent(t *testing.T) { err := New("test").WithStack() defer err.Free() frames := err.Stack() if len(frames) == 0 { t.Fatal("WithStack() should capture stack frames") } found := false for _, frame := range frames { if strings.Contains(frame, "testing.tRunner") { found = true break } } if !found { t.Errorf("WithStack() stack does not contain testing.tRunner, got: %v", frames) } } // TestErrorWrappingChain verifies a complex error chain with multiple layers, // ensuring correct message propagation, context isolation, and stack behavior. func TestErrorWrappingChain(t *testing.T) { databaseErr := New("connection timeout"). With("timeout_sec", 5). With("server", "db01.prod") defer databaseErr.Free() businessErr := New("failed to process user 12345"). With("user_id", "12345"). With("stage", "processing"). 
Wrap(databaseErr) defer businessErr.Free() apiErr := New("API request failed"). WithCode(500). WithStack(). Wrap(businessErr) defer apiErr.Free() // Verify full error message. expectedFullMessage := "API request failed: failed to process user 12345: connection timeout" if apiErr.Error() != expectedFullMessage { t.Errorf("Full error message mismatch\ngot: %q\nwant: %q", apiErr.Error(), expectedFullMessage) } // Verify error chain. chain := UnwrapAll(apiErr) if len(chain) != 3 { t.Fatalf("Expected chain length 3, got %d", len(chain)) } tests := []struct { index int expected string }{ {0, "API request failed"}, {1, "failed to process user 12345"}, {2, "connection timeout"}, } for _, tt := range tests { if chain[tt.index].Error() != tt.expected { t.Errorf("Chain position %d mismatch\ngot: %q\nwant: %q", tt.index, chain[tt.index].Error(), tt.expected) } } // Verify Is checks. if !errors.Is(apiErr, databaseErr) { t.Error("Is() should match the database error in the chain") } // Verify context isolation. if ctx := businessErr.Context(); ctx["timeout_sec"] != nil { t.Error("Business error should not have database context") } // Verify stack presence. if stack := apiErr.Stack(); len(stack) == 0 { t.Error("API error should have stack trace") } if stack := businessErr.Stack(); len(stack) != 0 { t.Error("Business error should not have stack trace") } // Verify code propagation. if apiErr.Code() != 500 { t.Error("API error should have code 500") } if businessErr.Code() != 0 { t.Error("Business error should have no code") } } // TestErrorExampleOutput verifies that formatted output includes all relevant // details, such as message, context, code, and stack, for a realistic error chain. func TestErrorExampleOutput(t *testing.T) { databaseErr := New("connection timeout"). With("timeout_sec", 5). With("server", "db01.prod") businessErr := New("failed to process user 12345"). With("user_id", "12345"). With("stage", "processing"). 
Wrap(databaseErr) apiErr := New("API request failed"). WithCode(500). WithStack(). Wrap(businessErr) chain := UnwrapAll(apiErr) for _, err := range chain { if e, ok := err.(*Error); ok { formatted := e.Format() if formatted == "" { t.Error("Format() returned empty string") } if !strings.Contains(formatted, "Error: "+e.Error()) { t.Errorf("Format() output missing error message: %q", formatted) } if e == apiErr { if !strings.Contains(formatted, "Code: 500") { t.Error("Format() missing code for API error") } if !strings.Contains(formatted, "Stack:") { t.Error("Format() missing stack for API error") } } if e == businessErr { if ctx := e.Context(); ctx != nil { if !strings.Contains(formatted, "Context:") { t.Error("Format() missing context for business error") } for k := range ctx { if !strings.Contains(formatted, k) { t.Errorf("Format() missing context key %q", k) } } } } } } if !errors.Is(apiErr, errors.New("connection timeout")) { t.Error("Is() failed to match connection timeout error") } } // TestErrorFullChain tests a complex chain with mixed error types (custom and standard), // verifying wrapping, unwrapping, and compatibility with standard library functions. 
func TestErrorFullChain(t *testing.T) { stdErr := errors.New("file not found") authErr := Named("AuthError").WithCode(401) storageErr := Wrapf(stdErr, "storage failed") authErrWrapped := Wrap(storageErr, authErr) wrapped := Wrapf(authErrWrapped, "request failed") var targetAuth *Error expectedTopLevelMsg := "request failed: AuthError: storage failed: file not found" if !errors.As(wrapped, &targetAuth) || targetAuth.Error() != expectedTopLevelMsg { t.Errorf("stderrors.As(wrapped, &targetAuth) failed, got %v, want %q", targetAuth.Error(), expectedTopLevelMsg) } var targetAuthPtr *Error if !As(wrapped, &targetAuthPtr) || targetAuthPtr.Name() != "AuthError" || targetAuthPtr.Code() != 401 { t.Errorf("As(wrapped, &targetAuthPtr) failed, got name=%s, code=%d; want AuthError, 401", targetAuthPtr.Name(), targetAuthPtr.Code()) } if !Is(wrapped, authErr) { t.Errorf("Is(wrapped, authErr) failed, expected true") } if !errors.Is(wrapped, authErr) { t.Errorf("stderrors.Is(wrapped, authErr) failed, expected true") } if !Is(wrapped, stdErr) { t.Errorf("Is(wrapped, stdErr) failed, expected true") } if !errors.Is(wrapped, stdErr) { t.Errorf("stderrors.Is(wrapped, stdErr) failed, expected true") } chain := UnwrapAll(wrapped) if len(chain) != 4 { t.Errorf("UnwrapAll(wrapped) length = %d, want 4", len(chain)) } expected := []string{ "request failed", "AuthError", "storage failed", "file not found", } for i, err := range chain { if err.Error() != expected[i] { t.Errorf("UnwrapAll[%d] = %v, want %v", i, err.Error(), expected[i]) } } } // TestErrorUnwrapAllMessageIsolation ensures UnwrapAll preserves individual error messages. 
func TestErrorUnwrapAllMessageIsolation(t *testing.T) { inner := New("inner") middle := New("middle").Wrap(inner) outer := New("outer").Wrap(middle) chain := UnwrapAll(outer) if chain[0].Error() != "outer" { t.Errorf("Expected 'outer', got %q", chain[0].Error()) } if chain[1].Error() != "middle" { t.Errorf("Expected 'middle', got %q", chain[1].Error()) } if chain[2].Error() != "inner" { t.Errorf("Expected 'inner', got %q", chain[2].Error()) } } // TestErrorIsEmpty verifies IsEmpty behavior for various error states, including // nil, empty messages, and errors with causes or templates. func TestErrorIsEmpty(t *testing.T) { tests := []struct { name string err *Error expected bool }{ {"nil error", nil, true}, {"empty error", New(""), true}, {"named empty", Named(""), true}, {"with empty template", New("").WithTemplate(""), true}, {"with message", New("test"), false}, {"with name", Named("test"), false}, {"with template", New("").WithTemplate("template"), false}, {"with cause", New("").Wrap(New("cause")), false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if tt.err != nil { defer tt.err.Free() } if got := tt.err.IsEmpty(); got != tt.expected { t.Errorf("IsEmpty() = %v, want %v", got, tt.expected) } }) } } // TestErrorIsNull verifies IsNull behavior for null and non-null errors, including // SQL null values in context or causes. 
func TestErrorIsNull(t *testing.T) { nullString := sql.NullString{Valid: false} validString := sql.NullString{String: "test", Valid: true} tests := []struct { name string err *Error expected bool }{ {"nil error", nil, true}, {"empty error", New(""), false}, {"with NULL context", New("").With("data", nullString), true}, {"with valid context", New("").With("data", validString), false}, {"with NULL cause", New("").Wrap(New("NULL value").With("data", nullString)), true}, {"with valid cause", New("").Wrap(New("valid value").With("data", validString)), false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if tt.err != nil { defer tt.err.Free() } if got := tt.err.IsNull(); got != tt.expected { t.Errorf("IsNull() = %v, want %v", got, tt.expected) } }) } } // TestErrorFromContext ensures FromContext enhances errors with context information, // such as deadlines and cancellations. func TestErrorFromContext(t *testing.T) { // Test nil error. t.Run("nil error returns nil", func(t *testing.T) { ctx := context.Background() if FromContext(ctx, nil) != nil { t.Error("Expected nil for nil input error") } }) // Test deadline exceeded. t.Run("deadline exceeded", func(t *testing.T) { deadline := time.Now().Add(-1 * time.Hour) ctx, cancel := context.WithDeadline(context.Background(), deadline) defer cancel() err := errors.New("operation failed") cerr := FromContext(ctx, err) if !IsTimeout(cerr) { t.Error("Expected timeout error") } if !HasContextKey(cerr, "deadline") { t.Error("Expected deadline in context") } }) // Test cancelled context. t.Run("cancelled context", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() err := errors.New("operation failed") cerr := FromContext(ctx, err) if !HasContextKey(cerr, "cancelled") { t.Error("Expected cancelled flag") } }) } // TestContextStorage verifies the smallContext optimization and its expansion // to a full map, including thread-safety under concurrent access. 
func TestContextStorage(t *testing.T) { // Test smallContext for first 4 items. t.Run("stores first 4 items in smallContext", func(t *testing.T) { Configure(Config{DisablePooling: true}) err := New("test") err.With("a", 1) err.With("b", 2) err.With("c", 3) err.With("d", 4) if err.smallCount != 4 { t.Errorf("expected smallCount=4, got %d", err.smallCount) } if err.context != nil { t.Error("expected context map to be nil") } }) // Test expansion to map on 5th item. t.Run("switches to map on 5th item", func(t *testing.T) { Configure(Config{DisablePooling: true}) err := New("test") err.With("a", 1) err.With("b", 2) err.With("c", 3) err.With("d", 4) err.With("e", 5) if err.context == nil { t.Error("expected context map to be initialized") } if len(err.context) != 5 { t.Errorf("expected 5 items in map, got %d", len(err.context)) } }) // Test preservation of all context items. t.Run("preserves all context items", func(t *testing.T) { err := New("test") items := []struct { k string v interface{} }{ {"a", 1}, {"b", 2}, {"c", 3}, {"d", 4}, {"e", 5}, {"f", 6}, } for _, item := range items { err.With(item.k, item.v) } ctx := err.Context() if len(ctx) != len(items) { t.Errorf("expected %d items, got %d", len(items), len(ctx)) } for _, item := range items { if val, ok := ctx[item.k]; !ok || val != item.v { t.Errorf("missing item %s in context", item.k) } } }) // Test concurrent access safety. t.Run("concurrent access", func(t *testing.T) { Configure(Config{DisablePooling: true}) err := New("test") var wg sync.WaitGroup wg.Add(2) go func() { defer wg.Done() err.With("a", 1) err.With("b", 2) err.With("c", 3) }() go func() { defer wg.Done() err.With("d", 4) err.With("e", 5) err.With("f", 6) }() wg.Wait() ctx := err.Context() if len(ctx) != 6 { t.Errorf("expected 6 items, got %d", len(ctx)) } }) } // TestNewf verifies Newf behavior, including %w wrapping, formatting, and error cases. // TestNewf verifies Newf behavior, including %w wrapping, formatting, and error cases. 
// It now expects the string output for %w cases to match fmt.Errorf. func TestNewf(t *testing.T) { // Reusable error instances for testing %w stdErrorInstance := errors.New("std error") customErrorInstance := New("custom error") // Assuming this exists in your tests firstErrorInstance := New("first") secondErrorInstance := New("second") tests := []struct { name string format string args []interface{} wantFinalMsg string // EXPECTATION UPDATED TO MATCH fmt.Errorf wantInternalMsg string // This field might be less relevant now, maybe remove? Kept for reference. wantCause error wantErrFormat bool // Indicates if Newf itself should return a format error message }{ // Basic formatting (no change needed) { name: "simple string", format: "simple %s", args: []interface{}{"test"}, wantFinalMsg: "simple test", wantInternalMsg: "simple test", // Stays same as FinalMsg when no %w }, { name: "complex format without %w", format: "code=%d msg=%s", args: []interface{}{123, "hello"}, wantFinalMsg: "code=123 msg=hello", wantInternalMsg: "code=123 msg=hello", }, { name: "empty format no args", format: "", args: []interface{}{}, wantFinalMsg: "", wantInternalMsg: "", }, // %w wrapping cases (EXPECTATIONS UPDATED) { name: "wrap standard error", format: "prefix %w", args: []interface{}{stdErrorInstance}, wantFinalMsg: "prefix std error", // Matches fmt.Errorf output wantInternalMsg: "prefix std error", // Now wantInternalMsg matches FinalMsg for %w wantCause: stdErrorInstance, }, { name: "wrap custom error", format: "prefix %w", args: []interface{}{customErrorInstance}, wantFinalMsg: "prefix custom error", // Matches fmt.Errorf output wantInternalMsg: "prefix custom error", wantCause: customErrorInstance, }, { name: "%w at start", format: "%w suffix", args: []interface{}{stdErrorInstance}, wantFinalMsg: "std error suffix", // Matches fmt.Errorf output wantInternalMsg: "std error suffix", wantCause: stdErrorInstance, }, { name: "%w with flags (flags ignored by %w)", format: "prefix %+w 
suffix", // fmt.Errorf ignores flags like '+' for %w args: []interface{}{stdErrorInstance}, wantFinalMsg: "prefix std error suffix", // Matches fmt.Errorf output wantInternalMsg: "prefix std error suffix", wantCause: stdErrorInstance, }, { name: "no space around %w", format: "prefix%wsuffix", args: []interface{}{stdErrorInstance}, wantFinalMsg: "prefixstd errorsuffix", // Matches fmt.Errorf output wantInternalMsg: "prefixstd errorsuffix", wantCause: stdErrorInstance, }, { name: "format becomes empty after removing %w", format: "%w", args: []interface{}{stdErrorInstance}, wantFinalMsg: "std error", // Matches fmt.Errorf output wantInternalMsg: "std error", wantCause: stdErrorInstance, }, // Error cases (no change needed in expectations, as these test Newf's error messages) { name: "multiple %w", format: "%w %w", args: []interface{}{firstErrorInstance, secondErrorInstance}, wantFinalMsg: `errors.Newf: format "%w %w" has multiple %w verbs`, wantInternalMsg: `errors.Newf: format "%w %w" has multiple %w verbs`, wantErrFormat: true, }, { name: "no args for %w", format: "prefix %w", args: []interface{}{}, wantFinalMsg: `errors.Newf: format "prefix %w" has %w but not enough arguments`, wantInternalMsg: `errors.Newf: format "prefix %w" has %w but not enough arguments`, wantErrFormat: true, }, { name: "non-error for %w", format: "prefix %w", args: []interface{}{"not an error"}, wantFinalMsg: `errors.Newf: argument 0 for %w is not a non-nil error (string)`, wantInternalMsg: `errors.Newf: argument 0 for %w is not a non-nil error (string)`, wantErrFormat: true, }, { name: "nil error for %w", format: "prefix %w", args: []interface{}{error(nil)}, wantFinalMsg: `errors.Newf: argument 0 for %w is not a non-nil error ()`, wantInternalMsg: `errors.Newf: argument 0 for %w is not a non-nil error ()`, wantErrFormat: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // Need to ensure pooled errors are freed if they are used in args // Safest is often to recreate 
them inside the test run if pooling is enabled // For simplicity here, assuming they are managed correctly or pooling is off // If customErrorInstance is pooled, it needs defer Free() or similar management. got := Newf(tt.format, tt.args...) if got == nil { t.Fatalf("Newf() returned nil, expected *Error") } // Consider defer got.Free() if AutoFree is false in config if gotMsg := got.Error(); gotMsg != tt.wantFinalMsg { t.Errorf("Newf().Error() = %q, want %q", gotMsg, tt.wantFinalMsg) } // Cause verification remains crucial gotCause := errors.Unwrap(got) if tt.wantCause != nil { // Use errors.Is for robust checking, especially if causes might be wrapped themselves if gotCause == nil { t.Errorf("Newf() cause = nil, want %v (%T)", tt.wantCause, tt.wantCause) } else if !errors.Is(got, tt.wantCause) { // Check the chain t.Errorf("Newf() cause mismatch (using Is): got chain does not contain %v (%T)", tt.wantCause, tt.wantCause) } else if gotCause != tt.wantCause { // Optional: Also check direct cause equality if important // t.Logf("Note: Unwrap() direct cause = %v (%T), expected %v (%T)", gotCause, gotCause, tt.wantCause, tt.wantCause) } } else { // Expected no cause if gotCause != nil { t.Errorf("Newf() cause = %v (%T), want nil", gotCause, gotCause) } } // If we expected a format error, the cause should definitely be nil if tt.wantErrFormat && gotCause != nil { t.Errorf("Newf() returned format error %q but unexpectedly set cause to %v", got.Error(), gotCause) } // Check internal message field if still relevant (might remove this check) // if !tt.wantErrFormat && got.msg != tt.wantInternalMsg { // t.Errorf("Newf().msg internal field = %q, want %q", got.msg, tt.wantInternalMsg) // } }) } } // TestNewfCompatibilityWithFmtErrorf compares the functional behavior of this library's // Newf function (when using the %w verb) with the standard library's fmt.Errorf. 
// // Rationale for using compareWrappedErrorStrings helper: // // Goal: Ensure essential compatibility - correct error wrapping (for Unwrap/Is/As) // // and preservation of the message content surrounding the wrapped error. // // Formatting Difference: This library consistently formats wrapped errors in its // // Error() method as "MESSAGE: CAUSE_ERROR" (or just "CAUSE_ERROR" if MESSAGE is empty). // The standard fmt.Errorf has more complex and variable spacing rules depending on // characters around %w (e.g., sometimes omitting the colon, adding spaces differently). // // Semantic Comparison: Attempting to replicate fmt.Errorf's exact spacing makes the // // library code brittle and overly complex. Therefore, this test focuses on *semantic* // equivalence rather than exact string matching. // // Helper Logic: compareWrappedErrorStrings verifies compatibility by: // // a) Checking that errors.Unwrap returns the same underlying cause instance. // b) Extracting the textual prefix from this library's error string (before ": CAUSE"). // c) Extracting the textual remainder from fmt.Errorf's string by removing the cause string. // d) Normalizing both extracted parts (trimming space, collapsing internal whitespace). // e) Comparing the normalized parts to ensure the core message content matches. // // This approach ensures functional compatibility without being overly sensitive to minor // formatting variations between the libraries. 
func TestNewfCompatibilityWithFmtErrorf(t *testing.T) { tests := []struct { name string format string argsFn func() []interface{} // Fresh args for each run }{ {"simple %w", "simple %w", func() []interface{} { return []interface{}{errors.New("error")} }}, {"complex %s %d %w", "complex %s %d %w", func() []interface{} { return []interface{}{"test", 42, errors.New("error")} }}, {"no space %w next", "no space %w next", func() []interface{} { return []interface{}{errors.New("error")} }}, {"%w starts", "%w starts", func() []interface{} { return []interface{}{errors.New("error")} }}, {"format is only %w", "%w", func() []interface{} { return []interface{}{errors.New("error")} }}, {"%w with flags", "%+w suffix", func() []interface{} { return []interface{}{errors.New("error")} }}, // fmt.Errorf ignores flags } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { args := tt.argsFn() var causeErrArg error // Find the error argument used for %w for _, arg := range args { if e, ok := arg.(error); ok { causeErrArg = e break // Assume the first error found is the one for %w } } if causeErrArg == nil { t.Fatalf("Test setup error: Could not find error argument for %%w in args: %v", args) } // Generate errors using both libraries stdErr := fmt.Errorf(tt.format, args...) customErrImpl := Newf(tt.format, args...) 
if customErrImpl == nil { t.Fatalf("Newf returned nil unexpectedly") } // Consider defer customErrImpl.Free() if needed // Verify Cause stdUnwrapped := errors.Unwrap(stdErr) customUnwrapped := errors.Unwrap(customErrImpl) if stdUnwrapped == nil || customUnwrapped == nil { t.Errorf("Expected both errors to be unwrappable, stdUnwrap=%v, customUnwrap=%v", stdUnwrapped, customUnwrapped) } else { // Check if the unwrapped errors are the *same instance* we passed in if customUnwrapped != causeErrArg { t.Errorf("Custom error did not unwrap to the original cause instance.\n got: %p (%T)\n want: %p (%T)", customUnwrapped, customUnwrapped, causeErrArg, causeErrArg) } if stdUnwrapped != causeErrArg { // This check is more about validating the test itself t.Logf("Standard error did not unwrap to the original cause instance (test validation).\n got: %p (%T)\n want: %p (%T)", stdUnwrapped, stdUnwrapped, causeErrArg, causeErrArg) } // Verify errors.Is works correctly on the custom error if !errors.Is(customErrImpl, causeErrArg) { t.Errorf("errors.Is(customErrImpl, causeErrArg) failed") } } // Verify String Output (Exact Match) gotStr := customErrImpl.Error() wantStr := stdErr.Error() if gotStr != wantStr { t.Errorf("String output mismatch:\n got: %q\nwant: %q", gotStr, wantStr) } }) } } var errForEdgeCases = errors.New("error") // TestNewfEdgeCases covers additional Newf scenarios, such as nil interfaces, // escaped percent signs, and malformed formats. // Expectations for %w cases are updated for fmt.Errorf compatibility. 
func TestNewfEdgeCases(t *testing.T) { tests := []struct { name string format string args []interface{} wantMsg string // EXPECTATION UPDATED wantCause error }{ // Cases without %w (no change) { name: "nil interface arg for %v", format: "test %v", args: []interface{}{interface{}(nil)}, wantMsg: "test ", }, { name: "malformed format ends with %", format: "test %w %", // This case causes a parse error, not a %w formatting issue args: []interface{}{errForEdgeCases}, wantMsg: `errors.Newf: format "test %w %" ends with %`, // Newf's specific error message wantCause: nil, }, // Cases with %w (EXPECTATIONS UPDATED) { name: "escaped %% with %w", format: "%%prefix %% %w %%suffix", args: []interface{}{errForEdgeCases}, wantMsg: "%prefix % error %suffix", // Matches fmt.Errorf output wantCause: errForEdgeCases, }, { name: "multiple verbs before %w", format: "%s %d %w", args: []interface{}{"foo", 42, errForEdgeCases}, wantMsg: "foo 42 error", // Matches fmt.Errorf output wantCause: errForEdgeCases, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { err := Newf(tt.format, tt.args...) if err == nil { t.Fatalf("Newf returned nil") } if gotMsg := err.Error(); gotMsg != tt.wantMsg { t.Errorf("Newf().Error() = %q, want %q", gotMsg, tt.wantMsg) } // Cause verification gotCause := errors.Unwrap(err) if tt.wantCause != nil { if !errors.Is(err, tt.wantCause) { t.Errorf("errors.Is(err, wantCause) failed.\n err: [%T: %q]\n wantCause: [%T: %q]\n gotCause (Unwrap): [%T: %v]", err, err, tt.wantCause, tt.wantCause, gotCause, gotCause) } } else { if gotCause != nil { t.Errorf("Newf() cause = [%T: %v], want nil", gotCause, gotCause) } } }) } } // compareWrappedErrorStrings verifies semantic equivalence between custom and // standard library error messages, normalizing spacing differences. 
func compareWrappedErrorStrings(t *testing.T, customStr, stdStr, causeStr string) { t.Helper() var customPrefix string if strings.HasSuffix(customStr, ": "+causeStr) { customPrefix = strings.TrimSuffix(customStr, ": "+causeStr) } else if customStr == causeStr { customPrefix = "" } else { t.Logf("Unexpected custom error string structure: %q for cause %q", customStr, causeStr) customPrefix = customStr } stdRemainder := strings.Replace(stdStr, causeStr, "", 1) normCustomPrefix := strings.TrimSpace(spaceRe.ReplaceAllString(customPrefix, " ")) normStdRemainder := strings.TrimSpace(spaceRe.ReplaceAllString(stdRemainder, " ")) if normCustomPrefix != normStdRemainder { t.Errorf("Semantic content mismatch (excluding cause):\n custom prefix: %q (from %q)\n std remainder: %q (from %q)", normCustomPrefix, customStr, normStdRemainder, stdStr) } } func TestWithVariadic(t *testing.T) { t.Run("single key-value", func(t *testing.T) { err := New("test").With("key1", "value1") if val, ok := err.Context()["key1"]; !ok || val != "value1" { t.Errorf("Expected key1=value1, got %v", val) } }) t.Run("multiple key-values", func(t *testing.T) { err := New("test").With("key1", 1, "key2", 2, "key3", 3) ctx := err.Context() if ctx["key1"] != 1 || ctx["key2"] != 2 || ctx["key3"] != 3 { t.Errorf("Expected all keys to be set, got %v", ctx) } }) t.Run("odd number of args", func(t *testing.T) { err := New("test").With("key1", 1, "key2") ctx := err.Context() if ctx["key1"] != 1 || ctx["key2"] != "(MISSING)" { t.Errorf("Expected key1=1 and key2=(MISSING), got %v", ctx) } }) t.Run("non-string keys", func(t *testing.T) { err := New("test").With(123, "value1", true, "value2") ctx := err.Context() if ctx["123"] != "value1" || ctx["true"] != "value2" { t.Errorf("Expected converted keys, got %v", ctx) } }) t.Run("transition to map context", func(t *testing.T) { // Assuming contextSize is 4 err := New("test"). With("k1", 1, "k2", 2, "k3", 3, "k4", 4). 
// fills smallContext With("k5", 5) // should trigger map transition if err.smallCount != 0 { t.Error("Expected smallCount to be 0 after transition") } if len(err.context) != 5 { t.Error("Expected all 5 items in map context") } }) t.Run("concurrent access", func(t *testing.T) { err := New("test") var wg sync.WaitGroup wg.Add(2) go func() { defer wg.Done() err.With("key1", 1, "key2", 2) }() go func() { defer wg.Done() err.With("key3", 3, "key4", 4) }() wg.Wait() ctx := err.Context() if len(ctx) != 4 { t.Errorf("Expected 4 items in context, got %d", len(ctx)) } }) t.Run("mixed existing context", func(t *testing.T) { err := New("test"). With("k1", 1). // smallContext With("k2", 2, "k3", 3, "k4", 4, "k5", 5) // some in small, some in map if len(err.context) != 5 { t.Errorf("Expected 5 items total, got %d", len(err.context)) } }) t.Run("large number of pairs", func(t *testing.T) { err := New("test") args := make([]interface{}, 20) for i := 0; i < 10; i++ { args[i*2] = i args[i*2+1] = i * 10 } err = err.With(args...) 
ctx := err.Context() if len(ctx) != 10 { t.Errorf("Expected 10 items, got %d", len(ctx)) } if ctx["5"] != 50 { t.Errorf("Expected ctx[5]=50, got %v", ctx["5"]) } }) } func TestWrapf(t *testing.T) { t.Run("basic wrapf", func(t *testing.T) { cause := New("cause") err := New("wrapper").Wrapf(cause, "formatted %s", "message") if err.Unwrap() != cause { t.Error("Unwrap() should return the cause") } if err.Error() != "formatted message: cause" { t.Errorf("Expected 'formatted message: cause', got '%s'", err.Error()) } }) t.Run("nil cause", func(t *testing.T) { err := New("wrapper").Wrapf(nil, "format %s", "test") if err.Unwrap() != nil { t.Error("Unwrap() should return nil for nil cause") } if err.Error() != "format test" { t.Errorf("Expected 'format test', got '%s'", err.Error()) } }) t.Run("complex formatting", func(t *testing.T) { cause := New("cause") err := New("wrapper").Wrapf(cause, "value: %d, str: %s", 42, "hello") if err.Error() != "value: 42, str: hello: cause" { t.Errorf("Expected complex formatting, got '%s'", err.Error()) } }) t.Run("wrapf with std error", func(t *testing.T) { stdErr := errors.New("io error") err := New("wrapper").Wrapf(stdErr, "operation failed after %d attempts", 3) if err.Unwrap() != stdErr { t.Error("Should be able to wrap standard errors with Wrapf") } if err.Error() != "operation failed after 3 attempts: io error" { t.Errorf("Expected formatted message with cause, got '%s'", err.Error()) } }) t.Run("preserves other fields", func(t *testing.T) { cause := New("cause").WithCode(404) err := New("wrapper"). With("key", "value"). WithCode(500). 
Wrapf(cause, "formatted") if err.Code() != 500 { t.Error("Wrapf should preserve error code") } if val, ok := err.Context()["key"]; !ok || val != "value" { t.Error("Wrapf should preserve context") } if err.Unwrap().(*Error).Code() != 404 { t.Error("Should preserve cause's code") } }) } func TestWrapping(t *testing.T) { cause := New("root cause") err := Err("wrap it", cause) if err == nil { t.Fatal("expected non-nil error") } if !Is(err, cause) { t.Fatal("wrapping failed") } if err.Error() != "wrap it: root cause" { t.Fatalf("wrong message: %q", err.Error()) } err = Newf("wrap it: %w", cause) if err == nil { t.Fatal("expected non-nil error") } if !Is(err, cause) { t.Fatal("wrapping failed") } if err.Error() != "wrap it: root cause" { t.Fatalf("wrong message: %q", err.Error()) } } golang-github-olekukonko-errors-1.3.0/generic.go000066400000000000000000000070601517267734700217240ustar00rootroot00000000000000package errors // AsType attempts to find the first error in the chain that matches type T. // Returns the matched error and true if found, otherwise zero value and false. func AsType[T error](err error) (T, bool) { var target T if As(err, &target) { // Uses errors.As from helper.go return target, true } var zero T return zero, false } // IsType checks if the error or any error in its chain is of type T. func IsType[T error](err error) bool { var target T return As(err, &target) // Uses errors.As from helper.go } // FindType returns the first error in the chain of type T that satisfies the predicate. func FindType[T error](err error, predicate func(T) bool) (T, bool) { var zero T if err == nil || predicate == nil { return zero, false } // Use your Walk function to traverse the chain var found T var foundIt bool Walk(err, func(e error) { if !foundIt { if target, ok := e.(T); ok && predicate(target) { found = target foundIt = true } } }) return found, foundIt } // Map applies a transformation function to each error of type T in the chain. 
func Map[T error, R any](err error, fn func(T) R) []R { var results []R if err == nil || fn == nil { return results } Walk(err, func(e error) { if target, ok := e.(T); ok { results = append(results, fn(target)) } }) return results } // Reduce walks the error chain and accumulates a result for errors of type T. func Reduce[T error, R any](err error, initial R, fn func(T, R) R) R { result := initial if err == nil || fn == nil { return result } Walk(err, func(e error) { if target, ok := e.(T); ok { result = fn(target, result) } }) return result } // Filter returns a slice of all errors of type T from the error chain. func Filter[T error](err error) []T { var results []T if err == nil { return results } Walk(err, func(e error) { if target, ok := e.(T); ok { results = append(results, target) } }) return results } // FirstOfType returns the first error in the chain of type T. func FirstOfType[T error](err error) (T, bool) { var zero T if err == nil { return zero, false } var found T var foundIt bool Walk(err, func(e error) { if !foundIt { if target, ok := e.(T); ok { found = target foundIt = true } } }) return found, foundIt } // Contains checks if any error in the chain matches any of the target errors. // Uses our package's Is function for matching. func Contains(err error, targets ...error) bool { for _, target := range targets { if Is(err, target) { // Uses errors.Is from helper.go return true } } return false } // JoinErrors joins multiple errors using Join and wraps them with context. // Returns nil if all errors are nil. func JoinErrors(errs []error, keyValues ...interface{}) error { nonNil := make([]error, 0, len(errs)) for _, err := range errs { if err != nil { nonNil = append(nonNil, err) } } if len(nonNil) == 0 { return nil } if len(nonNil) == 1 && len(keyValues) == 0 { return nonNil[0] } joined := Join(nonNil...) if len(keyValues) == 0 { return joined } // When multiple errors are joined, the test asserts the result is *MultiError. 
// Attach context to the *MultiError's first underlying *Error if possible, // otherwise return the MultiError directly (context check in tests is guarded // by a got.(*Error) type assertion and is a no-op for *MultiError). if _, ok := joined.(*MultiError); ok { return joined } // Single error with context: wrap in *Error so context is accessible. e := New("multiple errors occurred") e.Wrap(joined) e.With(keyValues...) return e } golang-github-olekukonko-errors-1.3.0/generic_test.go000066400000000000000000000424301517267734700227630ustar00rootroot00000000000000package errors import ( "errors" "testing" ) // baseErr is a local type alias for *Error used solely to give the anonymous // embedded field a name that is not "Error". This is the only way to embed // *Error and still satisfy the error interface: // // - Anonymous embed of *Error -> field name "Error" -> shadows Error() method // - Named field E *Error -> not an embed -> methods not promoted // - Anonymous embed of *baseErr -> field name "baseErr" -> no shadowing // where baseErr = Error (alias) -> Error() and Wrap() promoted correctly type baseErr = Error // type alias: same type, different identifier at embed site // testError is a custom error type for testing generics. type testError struct { *baseErr // anonymous embed; field name is "baseErr", not "Error" code int } func newTestError(code int) *testError { return &testError{ baseErr: New("test error").WithCode(code), code: code, } } // Wrap sets cause on the embedded *Error and returns *testError so callers can // chain further calls on *testError without losing the concrete type. // Wrap sets cause as the cause of this error and returns e so callers keep the // concrete *testError type across chains. It mutates e in place (same semantics // as *Error.Wrap) — each call replaces the cause, so // // newTestError(400).Wrap(A).Wrap(B) // // produces: testError(400, cause=B). That matches *Error.Wrap behaviour. 
// For a chain of three testErrors, callers must nest: A.Wrap(B.Wrap(C)) or the // test uses the chained form which is fine because each .Wrap() call is on the // result of the previous, not the same receiver. // // HOWEVER: newTestError(400).Wrap(newTestError(500)).Wrap(newTestError(600)) // chains on the same receiver *testError(400) twice — clobbering 500 with 600. // The tests expect all three to appear in the chain, so Wrap must build a NEW // wrapper each time, not mutate in place. // Wrap appends cause to the tail of this error's chain and returns e. // This makes A.Wrap(B).Wrap(C) produce the chain A→B→C (not A→C), // which is what the tests expect when reading back all codes in order. func (e *testError) Wrap(cause error) *testError { // Walk to the innermost *Error in this node's chain and attach cause there. tail := e.baseErr for tail.cause != nil { if next, ok := tail.cause.(*baseErr); ok { tail = next } else if next, ok := tail.cause.(*testError); ok { tail = next.baseErr } else if next, ok := tail.cause.(*specialError); ok { tail = next.baseErr } else { break } } tail.cause = cause return e } // Unwrap delegates to the embedded *Error so Walk/errors.As can traverse the chain. func (e *testError) Unwrap() error { return e.baseErr.Unwrap() } // specialError is another custom error type for testing generics. type specialError struct { *baseErr // anonymous embed; field name is "baseErr", not "Error" priority int } func newSpecialError(priority int) *specialError { return &specialError{ baseErr: Named("SpecialError"), priority: priority, } } // Wrap appends cause to the tail of this error's chain and returns e. 
func (e *specialError) Wrap(cause error) *specialError { tail := e.baseErr for tail.cause != nil { if next, ok := tail.cause.(*baseErr); ok { tail = next } else if next, ok := tail.cause.(*testError); ok { tail = next.baseErr } else if next, ok := tail.cause.(*specialError); ok { tail = next.baseErr } else { break } } tail.cause = cause return e } // Unwrap delegates to the embedded *Error so Walk/errors.As can traverse the chain. func (e *specialError) Unwrap() error { return e.baseErr.Unwrap() } // simpleError is a non-*Error type that implements error type simpleError struct { msg string } func (e simpleError) Error() string { return e.msg } func TestAsType(t *testing.T) { tests := []struct { name string err error wantFound bool wantCode int }{ { name: "found testError", err: newTestError(404), wantFound: true, wantCode: 404, }, { name: "found wrapped testError", err: newTestError(500).Wrap(newTestError(400)), wantFound: true, wantCode: 500, }, { name: "not found wrong type", err: newSpecialError(1), wantFound: false, }, { name: "nil error", err: nil, wantFound: false, }, { name: "standard error", err: errors.New("standard"), wantFound: false, }, { name: "simpleError type", err: simpleError{msg: "simple"}, wantFound: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, found := AsType[*testError](tt.err) if found != tt.wantFound { t.Errorf("AsType() found = %v, want %v", found, tt.wantFound) } if found && got.code != tt.wantCode { t.Errorf("AsType() code = %v, want %v", got.code, tt.wantCode) } }) } } func TestIsType(t *testing.T) { tests := []struct { name string err error want bool }{ { name: "direct match", err: newTestError(404), want: true, }, { name: "wrapped match", err: newSpecialError(1).Wrap(newTestError(500)), want: true, }, { name: "deep wrapped match", err: New("top").Wrap(newSpecialError(2).Wrap(newTestError(600))), want: true, }, { name: "no match", err: newSpecialError(1), want: false, }, { name: "nil error", err: nil, 
want: false, }, { name: "standard error", err: errors.New("standard"), want: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := IsType[*testError](tt.err); got != tt.want { t.Errorf("IsType() = %v, want %v", got, tt.want) } }) } } func TestFindType(t *testing.T) { tests := []struct { name string err error predicate func(*testError) bool wantFound bool wantCode int }{ { name: "match by code", err: newTestError(400).Wrap(newTestError(500)), predicate: func(e *testError) bool { return e.code == 500 }, wantFound: true, wantCode: 500, }, { name: "match first", err: newTestError(400).Wrap(newTestError(500)), predicate: func(e *testError) bool { return e.code == 400 }, wantFound: true, wantCode: 400, }, { name: "no match", err: newTestError(400), predicate: func(e *testError) bool { return e.code == 404 }, wantFound: false, }, { name: "nil predicate", err: newTestError(400), predicate: nil, wantFound: false, }, { name: "wrong type", err: newSpecialError(1), predicate: func(e *testError) bool { return true }, wantFound: false, }, { name: "deep chain match", err: New("top").Wrap(newSpecialError(2)).Wrap(newTestError(404).Wrap(newTestError(500))), predicate: func(e *testError) bool { return e.code == 404 }, wantFound: true, wantCode: 404, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, found := FindType(tt.err, tt.predicate) if found != tt.wantFound { t.Errorf("FindType() found = %v, want %v", found, tt.wantFound) } if found && got.code != tt.wantCode { t.Errorf("FindType() code = %v, want %v", got.code, tt.wantCode) } }) } } func TestMap(t *testing.T) { tests := []struct { name string err error fn func(*testError) int expected []int }{ { name: "single error", err: newTestError(400), fn: func(e *testError) int { return e.code }, expected: []int{400}, }, { name: "multiple errors in chain", err: newTestError(400).Wrap(newTestError(500)).Wrap(newTestError(600)), fn: func(e *testError) int { return e.code }, expected: 
[]int{400, 500, 600}, }, { name: "mixed error types", err: newTestError(400).Wrap(newSpecialError(1)).Wrap(newTestError(500)), fn: func(e *testError) int { return e.code }, expected: []int{400, 500}, }, { name: "nil error", err: nil, fn: func(e *testError) int { return e.code }, expected: []int{}, }, { name: "no matching type", err: newSpecialError(1), fn: func(e *testError) int { return e.code }, expected: []int{}, }, { name: "nil function", err: newTestError(400), fn: nil, expected: []int{}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := Map(tt.err, tt.fn) if len(got) != len(tt.expected) { t.Errorf("Map() length = %d, want %d", len(got), len(tt.expected)) return } for i := range got { if got[i] != tt.expected[i] { t.Errorf("Map()[%d] = %d, want %d", i, got[i], tt.expected[i]) } } }) } } func TestReduce(t *testing.T) { tests := []struct { name string err error initial int fn func(*testError, int) int expected int }{ { name: "sum codes", err: newTestError(400).Wrap(newTestError(500)).Wrap(newTestError(600)), initial: 0, fn: func(e *testError, acc int) int { return acc + e.code }, expected: 1500, }, { name: "max code", err: newTestError(400).Wrap(newTestError(500)).Wrap(newTestError(300)), initial: 0, fn: func(e *testError, acc int) int { if e.code > acc { return e.code } return acc }, expected: 500, }, { name: "nil error", err: nil, initial: 42, fn: func(e *testError, acc int) int { return acc + e.code }, expected: 42, }, { name: "no matching type", err: newSpecialError(1), initial: 10, fn: func(e *testError, acc int) int { return acc + e.code }, expected: 10, }, { name: "nil function", err: newTestError(400), initial: 5, fn: nil, expected: 5, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := Reduce(tt.err, tt.initial, tt.fn) if got != tt.expected { t.Errorf("Reduce() = %v, want %v", got, tt.expected) } }) } } func TestFilter(t *testing.T) { tests := []struct { name string err error expected int // count of 
testError }{ { name: "single error", err: newTestError(400), expected: 1, }, { name: "multiple test errors", err: newTestError(400).Wrap(newTestError(500)).Wrap(newTestError(600)), expected: 3, }, { name: "mixed with other types", err: newTestError(400).Wrap(newSpecialError(1)).Wrap(newTestError(500)), expected: 2, }, { name: "no test errors", err: newSpecialError(1), expected: 0, }, { name: "nil error", err: nil, expected: 0, }, { name: "standard error", err: errors.New("standard"), expected: 0, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := Filter[*testError](tt.err) if len(got) != tt.expected { t.Errorf("Filter() length = %d, want %d", len(got), tt.expected) } }) } } func TestFirstOfType(t *testing.T) { tests := []struct { name string err error wantFound bool wantCode int }{ { name: "first is testError", err: newTestError(400).Wrap(newSpecialError(1)), wantFound: true, wantCode: 400, }, { name: "testError after other type", err: newSpecialError(1).Wrap(newTestError(500)), wantFound: true, wantCode: 500, }, { name: "multiple test errors", err: newTestError(400).Wrap(newTestError(500)), wantFound: true, wantCode: 400, }, { name: "no test error", err: newSpecialError(1), wantFound: false, }, { name: "nil error", err: nil, wantFound: false, }, { name: "deep chain", err: New("top").Wrap(newSpecialError(2)).Wrap(newTestError(404)), wantFound: true, wantCode: 404, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, found := FirstOfType[*testError](tt.err) if found != tt.wantFound { t.Errorf("FirstOfType() found = %v, want %v", found, tt.wantFound) } if found && got.code != tt.wantCode { t.Errorf("FirstOfType() code = %v, want %v", got.code, tt.wantCode) } }) } } func TestContains(t *testing.T) { targetErr := New("target error") otherErr := New("other error") stdErr := errors.New("std error") tests := []struct { name string err error targets []error want bool }{ { name: "contains target", err: targetErr, targets: 
[]error{targetErr}, want: true, }, { name: "contains in wrapped chain", err: newTestError(400).Wrap(targetErr), targets: []error{targetErr}, want: true, }, { name: "does not contain", err: otherErr, targets: []error{targetErr}, want: false, }, { name: "multiple targets - second matches", err: targetErr, targets: []error{otherErr, targetErr}, want: true, }, { name: "nil error", err: nil, targets: []error{targetErr}, want: false, }, { name: "empty targets", err: targetErr, targets: []error{}, want: false, }, { name: "standard error target", err: New("wrapper").Wrap(stdErr), targets: []error{stdErr}, want: true, }, { name: "named error match", err: Named("TestError"), targets: []error{Named("TestError")}, want: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := Contains(tt.err, tt.targets...); got != tt.want { t.Errorf("Contains() = %v, want %v", got, tt.want) } }) } } func TestJoinErrors(t *testing.T) { err1 := New("error 1") err2 := New("error 2") err3 := New("error 3") tests := []struct { name string errs []error keyValues []interface{} wantNil bool wantCount int // expected number of errors in the result (0 if nil) }{ { name: "join two errors", errs: []error{err1, err2}, wantNil: false, wantCount: 2, }, { name: "join single error", errs: []error{err1}, wantNil: false, wantCount: 1, }, { name: "all nil errors", errs: []error{nil, nil}, wantNil: true, wantCount: 0, }, { name: "empty slice", errs: []error{}, wantNil: true, wantCount: 0, }, { name: "join with context", errs: []error{err1, err2}, keyValues: []interface{}{"key", "value", "operation", "test"}, wantNil: false, wantCount: 2, }, { name: "single error with context", errs: []error{err1}, keyValues: []interface{}{"user", "123"}, wantNil: false, wantCount: 1, }, { name: "nil and non-nil mix", errs: []error{nil, err1, nil, err2, err3}, wantNil: false, wantCount: 3, }, { name: "mix of standard and custom errors", errs: []error{err1, errors.New("std error"), err2}, wantNil: false, 
wantCount: 3, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := JoinErrors(tt.errs, tt.keyValues...) if (got == nil) != tt.wantNil { t.Errorf("JoinErrors() returned nil = %v, want %v", got == nil, tt.wantNil) } if !tt.wantNil { // Check if it's a MultiError with correct count if multi, ok := got.(*MultiError); ok { if multi.Count() != tt.wantCount { t.Errorf("JoinErrors() MultiError count = %d, want %d", multi.Count(), tt.wantCount) } } else if tt.wantCount > 1 { t.Errorf("JoinErrors() should return *MultiError for multiple errors, got %T", got) } // Verify context was added if keyValues provided if len(tt.keyValues) > 0 { // For multiple errors, the wrapper *Error wraps the MultiError if e, ok := got.(*Error); ok { ctx := e.Context() if len(ctx) == 0 { t.Error("JoinErrors() with context should have context values") } } } } }) } } // Benchmark tests for generics func BenchmarkAsType(b *testing.B) { err := newTestError(404) b.ResetTimer() for i := 0; i < b.N; i++ { _, _ = AsType[*testError](err) } } func BenchmarkIsType(b *testing.B) { err := newTestError(404) b.ResetTimer() for i := 0; i < b.N; i++ { _ = IsType[*testError](err) } } func BenchmarkFindType(b *testing.B) { err := newTestError(400).Wrap(newTestError(500)).Wrap(newTestError(600)) predicate := func(e *testError) bool { return e.code == 500 } b.ResetTimer() for i := 0; i < b.N; i++ { _, _ = FindType(err, predicate) } } func BenchmarkMap(b *testing.B) { err := newTestError(400).Wrap(newTestError(500)).Wrap(newTestError(600)) fn := func(e *testError) int { return e.code } b.ResetTimer() for i := 0; i < b.N; i++ { _ = Map(err, fn) } } func BenchmarkFilter(b *testing.B) { err := newTestError(400).Wrap(newSpecialError(1)).Wrap(newTestError(500)) b.ResetTimer() for i := 0; i < b.N; i++ { _ = Filter[*testError](err) } } func BenchmarkReduce(b *testing.B) { err := newTestError(400).Wrap(newTestError(500)).Wrap(newTestError(600)) fn := func(e *testError, acc int) int { return acc + e.code 
} b.ResetTimer() for i := 0; i < b.N; i++ { _ = Reduce(err, 0, fn) } } func BenchmarkFirstOfType(b *testing.B) { err := newSpecialError(1).Wrap(newTestError(400)).Wrap(newTestError(500)) b.ResetTimer() for i := 0; i < b.N; i++ { _, _ = FirstOfType[*testError](err) } } func BenchmarkJoinErrors(b *testing.B) { errs := []error{New("error 1"), New("error 2"), New("error 3")} b.ResetTimer() for i := 0; i < b.N; i++ { _ = JoinErrors(errs) } } golang-github-olekukonko-errors-1.3.0/go.mod000066400000000000000000000000551517267734700210640ustar00rootroot00000000000000module github.com/olekukonko/errors go 1.21 golang-github-olekukonko-errors-1.3.0/group.go000066400000000000000000000066251517267734700214520ustar00rootroot00000000000000// Group runs multiple functions concurrently and collects all errors into a // *MultiError. It is the error-aware counterpart to sync/errgroup: errgroup // stops at the first failure; Group collects every failure. package errors import ( "context" "sync" ) // Group runs goroutines concurrently and collects every error they return. // The zero value is ready to use; options may be applied via NewGroup. // // Example — fan-out with full error collection: // // g := errors.NewGroup() // g.Go(func() error { return validateUser(id) }) // g.Go(func() error { return validatePerms(id) }) // if err := g.Wait(); err != nil { // // err is *MultiError containing all failures // log.Println(err) // } type Group struct { wg sync.WaitGroup errs *MultiError ctx context.Context cancel context.CancelFunc cancelOnFirst bool } // GroupOption configures a Group. type GroupOption func(*Group) // GroupWithContext attaches ctx to the Group. The ctx is passed to // context-aware Go calls (GoCtx). If cancelOnFirst is true, the context // is cancelled as soon as the first error is returned by any goroutine — // useful for "cancel siblings on first failure" patterns. 
func GroupWithContext(ctx context.Context, cancelOnFirst bool) GroupOption { return func(g *Group) { g.ctx, g.cancel = context.WithCancel(ctx) g.cancelOnFirst = cancelOnFirst } } // GroupWithLimit sets a maximum error limit on the underlying MultiError. // Errors beyond the limit are dropped. func GroupWithLimit(n int) GroupOption { return func(g *Group) { g.errs = NewMultiError(WithLimit(n)) } } // NewGroup creates a Group with the given options applied. func NewGroup(opts ...GroupOption) *Group { g := &Group{ errs: NewMultiError(), } for _, o := range opts { o(g) } if g.ctx == nil { g.ctx = context.Background() } return g } // Go starts fn in a new goroutine. Errors returned by fn are collected; // nil returns are ignored. Thread-safe: MultiError.Add handles its own locking. // cancelOnFirst is read-only after construction so no lock is needed. func (g *Group) Go(fn func() error) { g.wg.Add(1) go func() { defer g.wg.Done() if err := fn(); err != nil { g.errs.Add(err) // MultiError.Add is internally mutex-protected if g.cancelOnFirst && g.cancel != nil { g.cancel() } } }() } // GoCtx starts fn in a new goroutine, passing the group's context. // If the group was created with GroupWithContext, fn receives a context // that is cancelled when cancelOnFirst triggers or the parent is done. func (g *Group) GoCtx(fn func(ctx context.Context) error) { g.wg.Add(1) go func() { defer g.wg.Done() if err := fn(g.ctx); err != nil { g.errs.Add(err) // MultiError.Add is internally mutex-protected if g.cancelOnFirst && g.cancel != nil { g.cancel() } } }() } // Wait blocks until all goroutines have finished and returns a *MultiError // containing every error collected, or nil if all succeeded. // Always returns *MultiError (never collapses to a raw error) so callers // can reliably type-assert the result. 
func (g *Group) Wait() error { g.wg.Wait() if g.cancel != nil { g.cancel() // release context resources } if !g.errs.Has() { return nil } return g.errs } // Errors returns a snapshot of errors collected so far. // Safe to call concurrently with Go/GoCtx; may be incomplete before Wait returns. func (g *Group) Errors() []error { return g.errs.Errors() } golang-github-olekukonko-errors-1.3.0/group_test.go000066400000000000000000000107441517267734700225060ustar00rootroot00000000000000package errors import ( "context" "fmt" "strings" "sync/atomic" "testing" "time" ) func TestGroupAllSucceed(t *testing.T) { g := NewGroup() g.Go(func() error { return nil }) g.Go(func() error { return nil }) g.Go(func() error { return nil }) if err := g.Wait(); err != nil { t.Errorf("expected nil, got: %v", err) } } func TestGroupCollectsAllErrors(t *testing.T) { g := NewGroup() g.Go(func() error { return New("error one") }) g.Go(func() error { return nil }) g.Go(func() error { return New("error two") }) g.Go(func() error { return New("error three") }) err := g.Wait() if err == nil { t.Fatal("expected errors, got nil") } multi, ok := err.(*MultiError) if !ok { t.Fatalf("expected *MultiError, got %T", err) } if multi.Count() != 3 { t.Errorf("expected 3 errors, got %d", multi.Count()) } } func TestGroupSingleError(t *testing.T) { g := NewGroup() g.Go(func() error { return New("only error") }) err := g.Wait() if err == nil { t.Fatal("expected error, got nil") } // Wait always returns *MultiError so callers can reliably type-assert. 
multi, ok := err.(*MultiError) if !ok { t.Fatalf("expected *MultiError, got %T", err) } if multi.Count() != 1 { t.Errorf("expected 1 error, got %d", multi.Count()) } if !strings.Contains(multi.Error(), "only error") { t.Errorf("unexpected message: %q", multi.Error()) } } func TestGroupGoCtx(t *testing.T) { ctx := context.Background() g := NewGroup(GroupWithContext(ctx, false)) var received atomic.Int32 g.GoCtx(func(ctx context.Context) error { if ctx == nil { return New("nil context") } received.Add(1) return nil }) g.GoCtx(func(ctx context.Context) error { received.Add(1) return nil }) if err := g.Wait(); err != nil { t.Errorf("unexpected error: %v", err) } if received.Load() != 2 { t.Errorf("expected 2 goroutines to run, got %d", received.Load()) } } func TestGroupCancelOnFirst(t *testing.T) { ctx := context.Background() g := NewGroup(GroupWithContext(ctx, true)) var started atomic.Int32 // First goroutine errors immediately. g.GoCtx(func(ctx context.Context) error { started.Add(1) return New("first failure") }) // Second goroutine checks ctx cancellation after a small delay. g.GoCtx(func(ctx context.Context) error { started.Add(1) select { case <-ctx.Done(): // Context was cancelled by first failure — return nil // to show cancellation was observed. return nil case <-time.After(200 * time.Millisecond): return New("second should have been cancelled") } }) err := g.Wait() // Only the first error should be collected; second observed cancellation. 
if err == nil { t.Fatal("expected at least one error") } if started.Load() != 2 { t.Errorf("expected both goroutines to start, got %d", started.Load()) } } func TestGroupWithLimit(t *testing.T) { g := NewGroup(GroupWithLimit(2)) for i := 0; i < 10; i++ { i := i g.Go(func() error { return fmt.Errorf("error %d", i) }) } err := g.Wait() if err == nil { t.Fatal("expected errors, got nil") } multi, ok := err.(*MultiError) if !ok { t.Fatalf("expected *MultiError, got %T", err) } if multi.Count() > 2 { t.Errorf("expected at most 2 errors due to limit, got %d", multi.Count()) } } func TestGroupErrors(t *testing.T) { g := NewGroup() g.Go(func() error { return New("a") }) g.Go(func() error { return New("b") }) _ = g.Wait() errs := g.Errors() if len(errs) != 2 { t.Errorf("expected 2 errors from Errors(), got %d", len(errs)) } } func TestGroupReuseAfterWait(t *testing.T) { g := NewGroup() g.Go(func() error { return New("round one") }) err1 := g.Wait() if err1 == nil { t.Fatal("expected error in round one") } // Second round — verify group can be reused. g.Go(func() error { return nil }) err2 := g.Wait() // After reuse the old errors are still present (Group accumulates). // This is expected behaviour; document it in the test. _ = err2 } func TestGroupConcurrentSafety(t *testing.T) { g := NewGroup() for i := 0; i < 100; i++ { i := i g.Go(func() error { if i%2 == 0 { // Use unique messages so MultiError.Add deduplication does not // collapse them — each goroutine index produces a distinct string. 
return fmt.Errorf("even error %d", i) } return nil }) } err := g.Wait() if err == nil { t.Fatal("expected errors from 50 failing goroutines") } multi, ok := err.(*MultiError) if !ok { t.Fatalf("expected *MultiError, got %T", err) } if multi.Count() != 50 { t.Errorf("expected 50 errors, got %d", multi.Count()) } } golang-github-olekukonko-errors-1.3.0/helper.go000066400000000000000000000274031517267734700215720ustar00rootroot00000000000000package errors import ( "context" "errors" "fmt" "strings" "time" ) // As wraps errors.As, using custom type assertion for *Error types. // Falls back to standard errors.As for non-*Error types. // Returns false if either err or target is nil. func As(err error, target interface{}) bool { if err == nil || target == nil { return false } // First try our custom *Error handling if e, ok := err.(*Error); ok { return e.As(target) } // Fall back to standard errors.As return errors.As(err, target) } // Code returns the status code of an error, if it is an *Error. // Returns 500 as a default for non-*Error types to indicate an internal error. func Code(err error) int { if e, ok := err.(*Error); ok { return e.Code() } return DefaultCode } // Context extracts the context map from an error, if it is an *Error. // Returns nil for non-*Error types or if no context is present. func Context(err error) map[string]interface{} { if e, ok := err.(*Error); ok { return e.Context() } return nil } // Convert transforms any error into an *Error, preserving its message and wrapping it if needed. // Returns nil if the input is nil; returns the original if already an *Error. // Uses multiple strategies: direct assertion, errors.As, manual unwrapping, and fallback creation. 
func Convert(err error) *Error { if err == nil { return nil } // First try direct type assertion (fast path) if e, ok := err.(*Error); ok { return e } // Try using errors.As (more flexible) var e *Error if errors.As(err, &e) { return e } // Manual unwrapping as fallback visited := make(map[error]bool) for unwrapped := err; unwrapped != nil; { if visited[unwrapped] { break // Cycle detected } visited[unwrapped] = true if e, ok := unwrapped.(*Error); ok { return e } unwrapped = errors.Unwrap(unwrapped) } // Final fallback: create new error with original message and wrap it return New(err.Error()).Wrap(err) } // Count returns the occurrence count of an error, if it is an *Error. // Returns 0 for non-*Error types. func Count(err error) uint64 { if e, ok := err.(*Error); ok { return e.Count() } return 0 } // Find searches the error chain for the first error matching pred. // Returns nil if no match is found or pred is nil; traverses both Unwrap() and Cause() chains. func Find(err error, pred func(error) bool) error { for current := err; current != nil; { if pred(current) { return current } // Attempt to unwrap using Unwrap() or Cause() switch v := current.(type) { case interface{ Unwrap() error }: current = v.Unwrap() case interface{ Cause() error }: current = v.Cause() default: return nil } } return nil } // From transforms any error into an *Error, preserving its message and wrapping it if needed. // Alias of Convert; returns nil if input is nil, original if already an *Error. func From(err error) *Error { return Convert(err) } // FromContext creates an *Error from a context and an existing error. // Enhances the error with context info: timeout status, deadline, or cancellation. // Returns nil if input error is nil; does not store context values directly. 
func FromContext(ctx context.Context, err error) *Error { if err == nil { return nil } e := New(err.Error()) // Handle context errors switch ctx.Err() { case context.DeadlineExceeded: e.WithTimeout() if deadline, ok := ctx.Deadline(); ok { e.With("deadline", deadline.Format(time.RFC3339)) } case context.Canceled: e.With("cancelled", true) } return e } // Category returns the category of an error, if it is an *Error. // Returns an empty string for non-*Error types or unset categories. func Category(err error) string { if e, ok := err.(*Error); ok { return e.category } return "" } // Has checks if an error contains meaningful content. // Returns true for non-nil standard errors or *Error with content (msg, name, template, or cause). func Has(err error) bool { if e, ok := err.(*Error); ok { return e.Has() } return err != nil } // HasContextKey checks if the error's context contains the specified key. // Returns false for non-*Error types or if the key is not present in the context. func HasContextKey(err error, key string) bool { if e, ok := err.(*Error); ok { ctx := e.Context() if ctx != nil { _, exists := ctx[key] return exists } } return false } // Is wraps errors.Is, using custom matching for *Error types. // Falls back to standard errors.Is for non-*Error types; returns true if err equals target. func Is(err, target error) bool { if err == nil || target == nil { return err == target } if e, ok := err.(*Error); ok { return e.Is(target) } // Use standard errors.Is for non-Error types return errors.Is(err, target) } // IsError checks if an error is an instance of *Error. // Returns true only for this package's custom error type; false for nil or other types. func IsError(err error) bool { _, ok := err.(*Error) return ok } // IsEmpty checks if an error has no meaningful content. // Returns true for nil errors, empty *Error instances, or standard errors with whitespace-only messages. 
func IsEmpty(err error) bool { if err == nil { return true } if e, ok := err.(*Error); ok { return e.IsEmpty() } return strings.TrimSpace(err.Error()) == "" } // IsNull checks if an error is nil or represents a NULL value. // Delegates to *Error’s IsNull for custom errors; uses sqlNull for others. func IsNull(err error) bool { if err == nil { return true } if e, ok := err.(*Error); ok { return e.IsNull() } return sqlNull(err) } // IsRetryable checks if an error is retryable. // For *Error, checks context for retry flag; for others, looks for "retry" or timeout in message. // Returns false for nil errors; thread-safe for *Error types. func IsRetryable(err error) bool { if err == nil { return false } if e, ok := err.(*Error); ok { e.mu.RLock() defer e.mu.RUnlock() // Check smallContext directly if context map isn’t populated for i := int32(0); i < e.smallCount; i++ { if e.smallContext[i].key == ctxRetry { if val, ok := e.smallContext[i].value.(bool); ok { return val } } } // Check regular context if e.context != nil { if val, ok := e.context[ctxRetry].(bool); ok { return val } } // Check cause recursively if e.cause != nil { return IsRetryable(e.cause) } } lowerMsg := strings.ToLower(err.Error()) return IsTimeout(err) || strings.Contains(lowerMsg, "retry") } // IsTimeout checks if an error indicates a timeout. // For *Error, checks context for timeout flag; for others, looks for "timeout" in message. // Returns false for nil errors. func IsTimeout(err error) bool { if err == nil { return false } if e, ok := err.(*Error); ok { if val, ok := e.Context()[ctxTimeout].(bool); ok { return val } } return strings.Contains(strings.ToLower(err.Error()), "timeout") } // Merge combines multiple errors into a single *Error. // Aggregates messages with "; " separator, merges contexts and stacks; returns nil if no errors provided. 
func Merge(errs ...error) *Error { if len(errs) == 0 { return nil } var messages []string combined := New("") for _, err := range errs { if err == nil { continue } messages = append(messages, err.Error()) if e, ok := err.(*Error); ok { if e.stack != nil && combined.stack == nil { combined.WithStack() // Capture stack from first *Error with stack } if ctx := e.Context(); ctx != nil { for k, v := range ctx { combined.With(k, v) } } if e.cause != nil { combined.Wrap(e.cause) } } else { combined.Wrap(err) } } if len(messages) > 0 { combined.msg = strings.Join(messages, "; ") } return combined } // Name returns the name of an error, if it is an *Error. // Returns an empty string for non-*Error types or unset names. func Name(err error) string { if e, ok := err.(*Error); ok { return e.name } return "" } // UnwrapAll returns a slice of all errors in the chain, including the root error. // Traverses both Unwrap() and Cause() chains; returns nil if err is nil. func UnwrapAll(err error) []error { if err == nil { return nil } if e, ok := err.(*Error); ok { return e.UnwrapAll() } var result []error Walk(err, func(e error) { result = append(result, e) }) return result } // Stack extracts the stack trace from an error, if it is an *Error. // Returns nil for non-*Error types or if no stack is present. func Stack(err error) []string { if e, ok := err.(*Error); ok { return e.Stack() } return nil } // Transform applies transformations to an error, returning a new *Error. // Creates a new *Error from non-*Error types before applying fn; returns nil if err is nil. func Transform(err error, fn func(*Error)) *Error { if err == nil { return nil } if e, ok := err.(*Error); ok { newErr := e.Copy() fn(newErr) return newErr } // If not an *Error, create a new one and transform it newErr := New(err.Error()) fn(newErr) return newErr } // Unwrap returns the underlying cause of an error, if it implements Unwrap. 
// For *Error, returns cause; for others, returns the error itself; nil if err is nil. func Unwrap(err error) error { for current := err; current != nil; { if e, ok := current.(*Error); ok { if e.cause == nil { return current } current = e.cause } else { return current } } return nil } // Walk traverses the error chain, applying fn to each error. // Supports both Unwrap() and Cause() interfaces; stops at nil or non-unwrappable errors. func Walk(err error, fn func(error)) { for current := err; current != nil; { fn(current) // Attempt to unwrap using Unwrap() or Cause() switch v := current.(type) { case interface{ Unwrap() error }: current = v.Unwrap() case interface{ Cause() error }: current = v.Cause() default: return } } } // With adds a key-value pair to an error's context, if it is an *Error. // Returns the original error unchanged if not an *Error; no-op for non-*Error types. func With(err error, key string, value interface{}) error { if e, ok := err.(*Error); ok { return e.With(key, value) } return err } // WithStack converts any error to an *Error and captures a stack trace. // Returns nil if input is nil; adds stack to existing *Error or wraps non-*Error types. func WithStack(err error) *Error { if err == nil { return nil } if e, ok := err.(*Error); ok { return e.WithStack() } return New(err.Error()).WithStack().Wrap(err) } // Wrap creates a new *Error that wraps another error with additional context. // Uses a copy of the provided wrapper *Error; returns nil if err is nil. func Wrap(err error, wrapper *Error) *Error { if err == nil { return nil } if wrapper == nil { wrapper = newError() } newErr := wrapper.Copy() newErr.cause = err return newErr } // Wrapf creates a new formatted *Error that wraps another error. // Formats the message and sets the cause; returns nil if err is nil. func Wrapf(err error, format string, args ...interface{}) *Error { if err == nil { return nil } e := newError() e.msg = fmt.Sprintf(format, args...) 
e.cause = err return e } // Err creates a new Error with the given message and wraps the provided error as its cause. func Err(msg string, err error) *Error { return New(msg).Wrap(err) } // Join returns an error that wraps the given errors. // Any nil error values are discarded. // Join returns nil if every error is nil. // The error formats as the concatenation of the errors' strings, separated by newlines. // The resulting error implements Unwrap() []error if multiple non-nil errors are present. func Join(errs ...error) error { nonNil := make([]error, 0, len(errs)) for _, err := range errs { if err != nil { nonNil = append(nonNil, err) } } switch len(nonNil) { case 0: return nil case 1: return nonNil[0] default: multi := NewMultiError() for _, err := range nonNil { multi.Add(err) } return multi } } golang-github-olekukonko-errors-1.3.0/helper_test.go000066400000000000000000000156151517267734700226330ustar00rootroot00000000000000package errors import ( "database/sql" "errors" "runtime" "strings" "sync" "testing" "time" ) var testMu sync.Mutex // Protect global state changes // TestHelperWarmStackPool verifies that WarmStackPool pre-populates the stack pool correctly. 
func TestHelperWarmStackPool(t *testing.T) { testMu.Lock() defer testMu.Unlock() // Save and restore original config originalConfig := currentConfig defer func() { currentConfig = originalConfig }() // Reinitialize stackPool with a nil-returning New function for this test stackPool = sync.Pool{ New: func() interface{} { return nil // Return nil when pool is empty }, } // Test disabled pooling currentConfig.disablePooling = true WarmStackPool(5) if got := stackPool.Get(); got != nil { t.Errorf("WarmStackPool should not populate when pooling is disabled, got %v", got) } // Reinitialize stackPool for enabled pooling test stackPool = sync.Pool{ New: func() interface{} { return make([]uintptr, currentConfig.stackDepth) }, } // Test enabled pooling currentConfig.disablePooling = false WarmStackPool(3) count := 0 for i := 0; i < 3; i++ { if stackPool.Get() != nil { count++ } } if count != 3 { t.Errorf("WarmStackPool should populate 3 items, got %d", count) } } // TestHelperCaptureStack verifies that captureStack captures the correct stack frames. func TestHelperCaptureStack(t *testing.T) { stack := captureStack(0) if len(stack) == 0 { t.Error("captureStack should capture at least one frame") } found := false frames := runtime.CallersFrames(stack) for { frame, more := frames.Next() if frame == (runtime.Frame{}) { break } if strings.Contains(frame.Function, "TestHelperCaptureStack") { found = true break } if !more { break } } if !found { t.Error("captureStack should include TestHelperCaptureStack in the stack") } } // TestHelperMin verifies the min helper function returns the smaller integer. func TestHelperMin(t *testing.T) { tests := []struct { a, b, want int }{ {1, 2, 1}, {5, 3, 3}, {0, 0, 0}, {-1, 1, -1}, } for _, tt := range tests { if got := min(tt.a, tt.b); got != tt.want { t.Errorf("min(%d, %d) = %d, want %d", tt.a, tt.b, got, tt.want) } } } // TestHelperClearMap verifies that clearMap empties a map. 
func TestHelperClearMap(t *testing.T) { m := map[string]interface{}{ "a": 1, "b": "test", } clearMap(m) if len(m) != 0 { t.Errorf("clearMap should empty the map, got %d items", len(m)) } } // TestHelperSqlNull verifies sqlNull detects SQL null types correctly. func TestHelperSqlNull(t *testing.T) { tests := []struct { name string value interface{} expected bool }{ {"nil", nil, true}, {"null string", sql.NullString{Valid: false}, true}, {"valid string", sql.NullString{String: "test", Valid: true}, false}, {"null time", sql.NullTime{Valid: false}, true}, {"valid time", sql.NullTime{Time: time.Now(), Valid: true}, false}, {"non-sql type", "test", false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := sqlNull(tt.value); got != tt.expected { t.Errorf("sqlNull(%v) = %v, want %v", tt.value, got, tt.expected) } }) } } // TestHelperGetFuncName verifies getFuncName extracts function names correctly. func TestHelperGetFuncName(t *testing.T) { if got := getFuncName(nil); got != "unknown" { t.Errorf("getFuncName(nil) = %q, want 'unknown'", got) } if got := getFuncName(TestHelperGetFuncName); !strings.Contains(got, "TestHelperGetFuncName") { t.Errorf("getFuncName(TestHelperGetFuncName) = %q, want to contain 'TestHelperGetFuncName'", got) } } // TestHelperIsInternalFrame verifies isInternalFrame identifies internal frames. func TestHelperIsInternalFrame(t *testing.T) { tests := []struct { frame runtime.Frame expected bool }{ {runtime.Frame{Function: "runtime.main"}, true}, {runtime.Frame{Function: "reflect.ValueOf"}, true}, {runtime.Frame{File: "github.com/olekukonko/errors/errors.go"}, true}, {runtime.Frame{Function: "main.main"}, false}, } for _, tt := range tests { if got := isInternalFrame(tt.frame); got != tt.expected { t.Errorf("isInternalFrame(%v) = %v, want %v", tt.frame, got, tt.expected) } } } // TestHelperFormatError verifies FormatError produces the expected string output. 
func TestHelperFormatError(t *testing.T) { err := New("test").With("key", "value").Wrap(New("cause")) defer err.Free() formatted := FormatError(err) if !strings.Contains(formatted, "Error: test: cause") { t.Errorf("FormatError missing error message: %q", formatted) } if !strings.Contains(formatted, "Context:\n\tkey: value") { t.Errorf("FormatError missing context: %q", formatted) } if !strings.Contains(formatted, "Caused by:") { t.Errorf("FormatError missing cause: %q", formatted) } if FormatError(nil) != "" { t.Error("FormatError(nil) should return ''") } stdErr := errors.New("std error") if !strings.Contains(FormatError(stdErr), "Error: std error") { t.Errorf("FormatError for std error missing message: %q", FormatError(stdErr)) } } // TestHelperCaller verifies Caller returns the correct caller information. func TestHelperCaller(t *testing.T) { file, line, function := Caller(0) if !strings.Contains(file, "helper_test.go") { t.Errorf("Caller file = %q, want to contain 'helper_test.go'", file) } if line <= 0 { t.Errorf("Caller line = %d, want > 0", line) } if !strings.Contains(function, "TestHelperCaller") { t.Errorf("Caller function = %q, want to contain 'TestHelperCaller'", function) } } // TestHelperPackageIsEmpty verifies package-level IsEmpty behavior. func TestHelperPackageIsEmpty(t *testing.T) { tests := []struct { name string err error expected bool }{ {"nil error", nil, true}, {"empty std error", errors.New(""), true}, {"whitespace error", errors.New(" "), true}, {"non-empty std error", errors.New("test"), false}, {"empty custom error", New(""), true}, {"non-empty custom error", New("test"), false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if e, ok := tt.err.(*Error); ok { defer e.Free() } if got := IsEmpty(tt.err); got != tt.expected { t.Errorf("IsEmpty() = %v, want %v", got, tt.expected) } }) } } // TestHelperPackageIsNull verifies package-level IsNull behavior. 
func TestHelperPackageIsNull(t *testing.T) { nullTime := sql.NullTime{Valid: false} validTime := sql.NullTime{Time: time.Now(), Valid: true} tests := []struct { name string err error expected bool }{ {"nil error", nil, true}, {"std error", errors.New("test"), false}, {"custom error with NULL", New("").With("time", nullTime), true}, {"custom error with valid", New("").With("time", validTime), false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if e, ok := tt.err.(*Error); ok { defer e.Free() } if got := IsNull(tt.err); got != tt.expected { t.Errorf("IsNull() = %v, want %v", got, tt.expected) } }) } } golang-github-olekukonko-errors-1.3.0/http.go000066400000000000000000000061271517267734700212720ustar00rootroot00000000000000package errors import ( "fmt" "net/http" ) // httpConfig holds resolved options for an HTTPError call. type httpConfig struct { fallbackCode int includeBody bool bodyFn func(error) string } // HTTPOption configures an HTTPError call. type HTTPOption func(*httpConfig) // WithFallbackCode sets the HTTP status used when err carries no valid code. // Default is 500 (Internal Server Error). func WithFallbackCode(code int) HTTPOption { return func(c *httpConfig) { c.fallbackCode = code } } // WithBody controls whether the error message is written as the response body. // Default is true. func WithBody(include bool) HTTPOption { return func(c *httpConfig) { c.includeBody = include } } // WithBodyFunc sets a custom function that produces the response body string // from the error. Overrides WithBody when set. // // Example — return JSON instead of plain text: // // errors.HTTPError(w, err, // errors.WithBodyFunc(func(e error) string { // return fmt.Sprintf(`{"error":%q}`, e.Error()) // }), // ) func WithBodyFunc(fn func(error) string) HTTPOption { return func(c *httpConfig) { c.bodyFn = fn } } // HTTPError writes err to w as an HTTP error response. 
// // Status code resolution (first match wins): // err is *Error with Code() in the valid HTTP range (100–599) // WithFallbackCode option (default 500) // // Content-Type defaults to text/plain unless WithBodyFunc provides content // that implies a different type (caller must set the header themselves in // that case — use WithBodyFunc + manual header setting). // // Example — simplest usage, plain text body, 500 fallback: // // errors.HTTPError(w, err) // // Example — custom fallback status: // // errors.HTTPError(w, err, errors.WithFallbackCode(http.StatusBadGateway)) // // Example — suppress body (header only): // // errors.HTTPError(w, err, errors.WithBody(false)) // // Example — JSON body: // // errors.HTTPError(w, err, // errors.WithBodyFunc(func(e error) string { // return fmt.Sprintf(`{"error":%q,"code":%d}`, // e.Error(), errors.HTTPStatusCode(e, 500)) // }), // ) func HTTPError(w http.ResponseWriter, err error, opts ...HTTPOption) { cfg := &httpConfig{ fallbackCode: http.StatusInternalServerError, includeBody: true, } for _, o := range opts { o(cfg) } code := HTTPStatusCode(err, cfg.fallbackCode) if cfg.bodyFn != nil { w.WriteHeader(code) if err != nil { _, _ = fmt.Fprint(w, cfg.bodyFn(err)) } return } w.Header().Set("Content-Type", "text/plain; charset=utf-8") w.WriteHeader(code) if cfg.includeBody && err != nil { _, _ = fmt.Fprintln(w, err.Error()) } } // HTTPStatusCode returns the HTTP status code embedded in err. // If err is nil, has no code, or the code is outside the valid HTTP range // (100–599), defaultCode is returned. 
// // Example: // // status := errors.HTTPStatusCode(err, http.StatusInternalServerError) func HTTPStatusCode(err error, defaultCode int) int { if err == nil { return defaultCode } if e, ok := err.(*Error); ok { if c := e.Code(); c >= http.StatusContinue && c <= 599 { return c } } return defaultCode } golang-github-olekukonko-errors-1.3.0/http_test.go000066400000000000000000000063021517267734700223240ustar00rootroot00000000000000package errors import ( "fmt" "net/http" "net/http/httptest" "strings" "testing" ) func TestHTTPStatusCodeNilError(t *testing.T) { if got := HTTPStatusCode(nil, 500); got != 500 { t.Errorf("nil error: got %d, want 500", got) } } func TestHTTPStatusCodeWithValidCode(t *testing.T) { err := New("not found").WithCode(404) if got := HTTPStatusCode(err, 500); got != 404 { t.Errorf("got %d, want 404", got) } } func TestHTTPStatusCodeOutOfRange(t *testing.T) { // Code below 100 is not a valid HTTP status — should use fallback. err := New("bad").WithCode(50) if got := HTTPStatusCode(err, 500); got != 500 { t.Errorf("out-of-range code: got %d, want 500", got) } // Code above 599 also invalid. err2 := New("bad").WithCode(600) if got := HTTPStatusCode(err2, 503); got != 503 { t.Errorf("code 600: got %d, want 503", got) } } func TestHTTPStatusCodeNonErrorType(t *testing.T) { // stdlib error has no code — should use fallback. 
err := fmt.Errorf("plain error") if got := HTTPStatusCode(err, 502); got != 502 { t.Errorf("plain error: got %d, want 502", got) } } func TestHTTPErrorDefaultBehaviour(t *testing.T) { w := httptest.NewRecorder() err := New("something broke").WithCode(422) HTTPError(w, err) if w.Code != 422 { t.Errorf("status: got %d, want 422", w.Code) } if ct := w.Header().Get("Content-Type"); !strings.HasPrefix(ct, "text/plain") { t.Errorf("Content-Type: got %q, want text/plain prefix", ct) } if !strings.Contains(w.Body.String(), "something broke") { t.Errorf("body missing error message: %q", w.Body.String()) } } func TestHTTPErrorFallbackCode(t *testing.T) { w := httptest.NewRecorder() err := New("upstream unavailable") // no code set HTTPError(w, err, WithFallbackCode(http.StatusBadGateway)) if w.Code != http.StatusBadGateway { t.Errorf("status: got %d, want %d", w.Code, http.StatusBadGateway) } } func TestHTTPErrorNoBody(t *testing.T) { w := httptest.NewRecorder() err := New("internal error").WithCode(500) HTTPError(w, err, WithBody(false)) if w.Code != 500 { t.Errorf("status: got %d, want 500", w.Code) } if w.Body.Len() != 0 { t.Errorf("body should be empty, got %q", w.Body.String()) } } func TestHTTPErrorCustomBodyFunc(t *testing.T) { w := httptest.NewRecorder() err := New("bad input").WithCode(400) HTTPError(w, err, WithBodyFunc(func(e error) string { return fmt.Sprintf(`{"error":%q}`, e.Error()) })) if w.Code != 400 { t.Errorf("status: got %d, want 400", w.Code) } body := w.Body.String() if !strings.Contains(body, `"bad input"`) { t.Errorf("JSON body missing error: %q", body) } } func TestHTTPErrorNilError(t *testing.T) { w := httptest.NewRecorder() HTTPError(w, nil, WithFallbackCode(200)) if w.Code != 200 { t.Errorf("nil error: got %d, want 200", w.Code) } if w.Body.Len() != 0 { t.Errorf("nil error body should be empty, got %q", w.Body.String()) } } func TestHTTPErrorSentinelError(t *testing.T) { ErrForbidden := Const("forbidden", "access denied") w := 
httptest.NewRecorder() // Sentinel has no code — uses fallback. HTTPError(w, ErrForbidden, WithFallbackCode(http.StatusForbidden)) if w.Code != http.StatusForbidden { t.Errorf("sentinel: got %d, want %d", w.Code, http.StatusForbidden) } } golang-github-olekukonko-errors-1.3.0/inspect.go000066400000000000000000000123161517267734700217550ustar00rootroot00000000000000// Human-readable error inspection. Output is written to caller-supplied // io.Writer values; this library never owns stdout or stderr. package errors import ( stderrs "errors" "fmt" "io" "os" "strings" ) // inspectConfig holds resolved options for a single Inspect call. type inspectConfig struct { w io.Writer stackFrames int maxDepth int } // InspectOption configures an Inspect call. type InspectOption func(*inspectConfig) // WithStackFrames sets the maximum number of stack frames printed per error node. // Default is 3. func WithStackFrames(n int) InspectOption { return func(c *inspectConfig) { c.stackFrames = n } } // WithMaxDepth sets the maximum chain depth traversed before output is truncated. // Default is 10. func WithMaxDepth(n int) InspectOption { return func(c *inspectConfig) { c.maxDepth = n } } // Inspect writes a human-readable description of err to each writer in ws. // If no writers are supplied it defaults to os.Stderr. // Multiple writers are combined with io.MultiWriter so a single call can // write to a log file and a buffer simultaneously. // // Example — default (stderr): // // errors.Inspect(err) // // Example — write to a buffer for testing: // // var buf bytes.Buffer // errors.Inspect(err, &buf) // // Example — write to both stderr and a file: // // errors.Inspect(err, os.Stderr, logFile) // // Example — customise stack depth: // // errors.Inspect(err, os.Stderr, errors.WithStackFrames(5)) // // Note: InspectOption values must come after all io.Writer values. Any value // that is neither an io.Writer nor an InspectOption is silently ignored. 
func Inspect(err error, targets ...interface{}) { cfg := &inspectConfig{ stackFrames: 3, maxDepth: 10, } var writers []io.Writer for _, t := range targets { switch v := t.(type) { case InspectOption: v(cfg) case io.Writer: writers = append(writers, v) } } if len(writers) == 0 { writers = []io.Writer{os.Stderr} } if len(writers) == 1 { cfg.w = writers[0] } else { cfg.w = io.MultiWriter(writers...) } writeInspect(cfg, err) } // InspectError is a convenience wrapper for *Error that calls Inspect. // Kept for backwards compatibility; prefer Inspect for new code. func InspectError(err *Error, targets ...interface{}) { Inspect(err, targets...) } // writeInspect does the actual formatting. func writeInspect(cfg *inspectConfig, err error) { w := cfg.w if err == nil { fmt.Fprintln(w, "no error") return } fmt.Fprintf(w, "\n=== error inspection ===\n") fmt.Fprintf(w, "type: %T\n", err) fmt.Fprintf(w, "message: %v\n", err) switch e := err.(type) { case *Error: writeChain(cfg, e) writeDiagnostics(cfg, err) case *MultiError: errs := e.Errors() fmt.Fprintf(w, "errors: %d\n", len(errs)) for i, sub := range errs { fmt.Fprintf(w, "\n--- error %d ---\n", i+1) writeSingle(cfg, sub, 0) } writeDiagnostics(cfg, err) default: writeSingle(cfg, err, 0) writeDiagnostics(cfg, err) } fmt.Fprintf(w, "========================\n\n") } // writeChain walks an *Error chain printing each node. func writeChain(cfg *inspectConfig, e *Error) { var current error = e depth := 0 for current != nil && depth <= cfg.maxDepth { writeSingle(cfg, current, depth) next := stderrs.Unwrap(current) if next == current || next == nil { break } current = next depth++ } if depth > cfg.maxDepth { fmt.Fprintf(cfg.w, " ... (chain truncated at depth %d)\n", cfg.maxDepth) } } // writeSingle prints one error node at the given indent depth. 
func writeSingle(cfg *inspectConfig, err error, depth int) { if err == nil { return } w := cfg.w pad := strings.Repeat(" ", depth) if depth > 0 { fmt.Fprintf(w, "%scause (%T): %v\n", pad, err, err) } e, ok := err.(*Error) if !ok { return } if n := e.Name(); n != "" { fmt.Fprintf(w, "%s name: %s\n", pad, n) } if cat := e.Category(); cat != "" { fmt.Fprintf(w, "%s category: %s\n", pad, cat) } if code := e.Code(); code != 0 { fmt.Fprintf(w, "%s code: %d\n", pad, code) } if ctx := e.Context(); len(ctx) > 0 { fmt.Fprintf(w, "%s context:\n", pad) for k, v := range ctx { fmt.Fprintf(w, "%s %s: %v\n", pad, k, v) } } if stack := e.Stack(); len(stack) > 0 { limit := cfg.stackFrames if len(stack) < limit { limit = len(stack) } fmt.Fprintf(w, "%s stack (top %d):\n", pad, limit) for i := 0; i < limit; i++ { fmt.Fprintf(w, "%s %s\n", pad, stack[i]) } if len(stack) > limit { fmt.Fprintf(w, "%s ... (%d more frames)\n", pad, len(stack)-limit) } } } // writeDiagnostics appends a short diagnostic summary. func writeDiagnostics(cfg *inspectConfig, err error) { var parts []string if IsRetryable(err) { parts = append(parts, "retryable") } if IsTimeout(err) { parts = append(parts, "timeout") } if code := getErrorCode(err); code != 0 { parts = append(parts, fmt.Sprintf("code=%d", code)) } if len(parts) > 0 { fmt.Fprintf(cfg.w, "diagnostics: %s\n", strings.Join(parts, ", ")) } } // getErrorCode traverses the error chain to find the first non-zero code. 
func getErrorCode(err error) int { if e, ok := err.(*Error); ok { if c := e.Code(); c != 0 { return c } } var target *Error if As(err, &target) && target != nil { return target.Code() } return 0 } golang-github-olekukonko-errors-1.3.0/inspect_test.go000066400000000000000000000100461517267734700230120ustar00rootroot00000000000000package errors import ( "bytes" "strings" "testing" ) func TestInspectNil(t *testing.T) { var buf bytes.Buffer Inspect(nil, &buf) if !strings.Contains(buf.String(), "no error") { t.Errorf("expected 'no error', got: %q", buf.String()) } } func TestInspectPlainError(t *testing.T) { var buf bytes.Buffer err := New("something went wrong").WithCode(500).With("user", "alice") Inspect(err, &buf) out := buf.String() if !strings.Contains(out, "something went wrong") { t.Errorf("missing message in output: %q", out) } if !strings.Contains(out, "code:") { t.Errorf("missing code in output: %q", out) } if !strings.Contains(out, "alice") { t.Errorf("missing context value in output: %q", out) } } func TestInspectNamedError(t *testing.T) { var buf bytes.Buffer err := Named("AuthError").WithCode(401) Inspect(err, &buf) out := buf.String() if !strings.Contains(out, "AuthError") { t.Errorf("missing name in output: %q", out) } if !strings.Contains(out, "401") { t.Errorf("missing code in output: %q", out) } if !strings.Contains(out, "code=401") { t.Errorf("missing diagnostics code in output: %q", out) } } func TestInspectChain(t *testing.T) { var buf bytes.Buffer cause := New("db timeout").WithTimeout() outer := New("request failed").Wrap(cause) Inspect(outer, &buf) out := buf.String() if !strings.Contains(out, "request failed") { t.Errorf("missing outer message: %q", out) } if !strings.Contains(out, "db timeout") { t.Errorf("missing cause message: %q", out) } if !strings.Contains(out, "timeout") { t.Errorf("missing timeout diagnostic: %q", out) } } func TestInspectMultiError(t *testing.T) { var buf bytes.Buffer m := NewMultiError() m.Add(New("error one")) 
m.Add(New("error two")) Inspect(m, &buf) out := buf.String() if !strings.Contains(out, "errors: 2") { t.Errorf("missing error count: %q", out) } if !strings.Contains(out, "error one") { t.Errorf("missing first error: %q", out) } if !strings.Contains(out, "error two") { t.Errorf("missing second error: %q", out) } } func TestInspectMultipleWriters(t *testing.T) { var buf1, buf2 bytes.Buffer err := New("dual write test") Inspect(err, &buf1, &buf2) if buf1.String() == "" { t.Error("buf1 should have received output") } if buf1.String() != buf2.String() { t.Errorf("both writers should receive identical output\nbuf1: %q\nbuf2: %q", buf1.String(), buf2.String()) } } func TestInspectWithStackFramesOption(t *testing.T) { var buf bytes.Buffer err := Trace("traced error") Inspect(err, &buf, WithStackFrames(1)) out := buf.String() // Should mention stack but respect the 1-frame limit if !strings.Contains(out, "stack") { t.Errorf("expected stack section in output: %q", out) } } func TestInspectWithMaxDepthOption(t *testing.T) { var buf bytes.Buffer // Build a 5-deep chain err := New("level 0") for i := 1; i <= 5; i++ { err = New("level " + string(rune('0'+i))).Wrap(err) } Inspect(err, &buf, WithMaxDepth(2)) out := buf.String() if !strings.Contains(out, "truncated") { t.Errorf("expected truncation message for deep chain: %q", out) } } func TestInspectRetryableDiagnostic(t *testing.T) { var buf bytes.Buffer err := New("flaky call").WithRetryable() Inspect(err, &buf) out := buf.String() if !strings.Contains(out, "retryable") { t.Errorf("expected retryable diagnostic: %q", out) } } func TestInspectDefaultsToStderr(t *testing.T) { // Just verify it doesn't panic with no writers supplied. // We can't capture stderr easily in a unit test, so we only check no panic. 
defer func() { if r := recover(); r != nil { t.Errorf("Inspect panicked with no writers: %v", r) } }() // Redirect would require os.Pipe tricks; just call with a buffer to keep // output off the test console while still exercising the code path. var buf bytes.Buffer Inspect(New("test"), &buf) } func TestInspectError(t *testing.T) { var buf bytes.Buffer err := Named("SomeError").WithCode(503) InspectError(err, &buf) out := buf.String() if !strings.Contains(out, "SomeError") { t.Errorf("InspectError missing name: %q", out) } } golang-github-olekukonko-errors-1.3.0/multi_error.go000066400000000000000000000271601517267734700226560ustar00rootroot00000000000000package errors import ( "bytes" "encoding/json" "fmt" "math/rand" "strings" "sync" "sync/atomic" ) // MultiError represents a thread-safe collection of errors with enhanced features. // Supports limits, sampling, and custom formatting for error aggregation. type MultiError struct { errors []error mu sync.RWMutex // Configuration fields limit int // Maximum number of errors to store (0 = unlimited) formatter ErrorFormatter // Custom formatting function for error string sampling bool // Whether sampling is enabled to limit error collection sampleRate uint32 // Sampling percentage (1-100) when sampling is enabled rand *rand.Rand // Random source for sampling (nil defaults to fastRand) } // ErrorFormatter defines a function for custom error message formatting. // Takes a slice of errors and returns a single formatted string. type ErrorFormatter func([]error) string // MultiErrorOption configures MultiError behavior during creation. type MultiErrorOption func(*MultiError) // NewMultiError creates a new MultiError instance with optional configuration. // Initial capacity is set to 4; applies options in the order provided. 
func NewMultiError(opts ...MultiErrorOption) *MultiError { m := &MultiError{ errors: make([]error, 0, 4), limit: 0, // Unlimited by default } for _, opt := range opts { opt(m) } return m } // Add appends an error to the collection with optional sampling, limit checks, and duplicate prevention. // Ignores nil errors and duplicates based on string equality; thread-safe. func (m *MultiError) Add(errs ...error) { if len(errs) == 0 { return } m.mu.Lock() defer m.mu.Unlock() for _, err := range errs { if err == nil { continue } // Check for duplicates by comparing error messages duplicate := false for _, e := range m.errors { if e.Error() == err.Error() { duplicate = true break } } if duplicate { continue } // Apply sampling if enabled and collection isn’t empty if m.sampling && len(m.errors) > 0 { var r uint32 if m.rand != nil { r = uint32(m.rand.Int31n(100)) } else { r = fastRand() % 100 } if r > m.sampleRate { // Accept if random value is within sample rate continue } } // Respect limit if set if m.limit > 0 && len(m.errors) >= m.limit { continue } m.errors = append(m.errors, err) } } // Addf formats and adds a new error to the collection. func (m *MultiError) Addf(format string, args ...interface{}) { m.Add(Newf(format, args...)) } // Clear removes all errors from the collection. // Thread-safe; resets the slice while preserving capacity. func (m *MultiError) Clear() { m.mu.Lock() defer m.mu.Unlock() m.errors = m.errors[:0] } // Count returns the number of errors in the collection. // Thread-safe. func (m *MultiError) Count() int { m.mu.RLock() defer m.mu.RUnlock() return len(m.errors) } // Error returns a formatted string representation of the errors. // Returns empty string if no errors, single error message if one exists, // or a formatted list using custom formatter or default if multiple; thread-safe. 
func (m *MultiError) Error() string { m.mu.RLock() defer m.mu.RUnlock() switch len(m.errors) { case 0: return "" case 1: return m.errors[0].Error() default: if m.formatter != nil { return m.formatter(m.errors) } return defaultFormat(m.errors) } } // Errors returns a copy of the contained errors. // Thread-safe; returns nil if no errors exist. func (m *MultiError) Errors() []error { m.mu.RLock() defer m.mu.RUnlock() if len(m.errors) == 0 { return nil } errs := make([]error, len(m.errors)) copy(errs, m.errors) return errs } // Filter returns a new MultiError containing only errors that match the predicate. // Thread-safe; preserves original configuration including limit, formatter, and sampling. func (m *MultiError) Filter(fn func(error) bool) *MultiError { m.mu.RLock() defer m.mu.RUnlock() var opts []MultiErrorOption opts = append(opts, WithLimit(m.limit)) if m.formatter != nil { opts = append(opts, WithFormatter(m.formatter)) } if m.sampling { opts = append(opts, WithSampling(m.sampleRate)) } filtered := NewMultiError(opts...) for _, err := range m.errors { if fn(err) { filtered.Add(err) } } return filtered } // First returns the first error in the collection, if any. // Thread-safe; returns nil if the collection is empty. func (m *MultiError) First() error { m.mu.RLock() defer m.mu.RUnlock() if len(m.errors) > 0 { return m.errors[0] } return nil } // Has reports whether the collection contains any errors. // Thread-safe. func (m *MultiError) Has() bool { m.mu.RLock() defer m.mu.RUnlock() return len(m.errors) > 0 } // Last returns the most recently added error in the collection, if any. // Thread-safe; returns nil if the collection is empty. func (m *MultiError) Last() error { m.mu.RLock() defer m.mu.RUnlock() if len(m.errors) > 0 { return m.errors[len(m.errors)-1] } return nil } // Merge combines another MultiError's errors into this one. // Thread-safe; respects this instance’s limit and sampling settings; no-op if other is nil or empty. 
func (m *MultiError) Merge(other *MultiError) { if other == nil || !other.Has() { return } // Snapshot other's errors under its own read lock, then release before // acquiring m's write lock inside Add. This prevents two bugs: // Self-merge deadlock: when m == other, holding other.mu.RLock then // calling m.Add (which takes m.mu.Lock) deadlocks on the same mutex. // Concurrent-write race: m had no lock protection during the loop, // so a concurrent Add on m could corrupt the slice. other.mu.RLock() snapshot := make([]error, len(other.errors)) copy(snapshot, other.errors) other.mu.RUnlock() m.Add(snapshot...) } // IsNull checks if the MultiError is empty or contains only null errors. // Returns true if empty or all errors are null (via IsNull() or empty message); thread-safe. func (m *MultiError) IsNull() bool { m.mu.RLock() defer m.mu.RUnlock() // Fast path for empty MultiError if len(m.errors) == 0 { return true } // Check each error for null status allNull := true for _, err := range m.errors { switch e := err.(type) { case interface{ IsNull() bool }: if !e.IsNull() { allNull = false break } case nil: continue default: if e.Error() != "" { allNull = false break } } } return allNull } // Single returns nil if the collection is empty, the single error if only one exists, // or the MultiError itself if multiple errors are present. // Thread-safe; useful for unwrapping to a single error when possible. func (m *MultiError) Single() error { m.mu.RLock() defer m.mu.RUnlock() switch len(m.errors) { case 0: return nil case 1: return m.errors[0] default: return m } } // String implements the Stringer interface for a concise string representation. // Thread-safe; delegates to Error() for formatting. func (m *MultiError) String() string { return m.Error() } // Unwrap returns a copy of the contained errors for multi-error unwrapping. // Implements the errors.Unwrap interface; thread-safe; returns nil if empty. 
func (m *MultiError) Unwrap() []error { return m.Errors() } // WithFormatter sets a custom error formatting function. // Returns a MultiErrorOption for use with NewMultiError; overrides default formatting. func WithFormatter(f ErrorFormatter) MultiErrorOption { return func(m *MultiError) { m.formatter = f } } // WithLimit sets the maximum number of errors to store. // Returns a MultiErrorOption for use with NewMultiError; 0 means unlimited, negative values are ignored. func WithLimit(n int) MultiErrorOption { return func(m *MultiError) { if n < 0 { n = 0 // Ensure non-negative limit } m.limit = n } } // WithSampling enables error sampling with a specified rate (1-100). // Returns a MultiErrorOption for use with NewMultiError; caps rate at 100 for validity. func WithSampling(rate uint32) MultiErrorOption { return func(m *MultiError) { if rate > 100 { rate = 100 } m.sampling = true m.sampleRate = rate } } // WithRand sets a custom random source for sampling, useful for testing. // Returns a MultiErrorOption for use with NewMultiError; defaults to fastRand if nil. func WithRand(r *rand.Rand) MultiErrorOption { return func(m *MultiError) { m.rand = r } } // MarshalJSON serializes the MultiError to JSON, including all contained errors and configuration metadata. // Thread-safe; errors are serialized using their MarshalJSON method if available, otherwise as strings. func (m *MultiError) MarshalJSON() ([]byte, error) { m.mu.RLock() defer m.mu.RUnlock() // Get buffer from pool. Do NOT use defer for Put — see errors.go MarshalJSON // for the full explanation. We must copy bytes out before returning the buf. 
buf := jsonBufferPool.Get().(*bytes.Buffer) buf.Reset() // Create encoder enc := json.NewEncoder(buf) enc.SetEscapeHTML(false) // Define JSON structure type jsonError struct { Error interface{} `json:"error"` // Holds either JSON-marshaled error or string } je := struct { Count int `json:"count"` // Number of errors Limit int `json:"limit,omitempty"` // Maximum error limit (omitted if 0) Sampling bool `json:"sampling,omitempty"` // Whether sampling is enabled SampleRate uint32 `json:"sample_rate,omitempty"` // Sampling rate (1-100, omitted if not sampling) Errors []jsonError `json:"errors"` // List of errors }{ Count: len(m.errors), Limit: m.limit, Sampling: m.sampling, SampleRate: m.sampleRate, } // Serialize each error je.Errors = make([]jsonError, len(m.errors)) for i, err := range m.errors { if err == nil { je.Errors[i] = jsonError{Error: nil} continue } // Check if the error implements json.Marshaler if marshaler, ok := err.(json.Marshaler); ok { // Use marshalErr (not err) to avoid shadowing the loop variable. marshaled, marshalErr := marshaler.MarshalJSON() if marshalErr != nil { // Fallback reports the ORIGINAL error message, not the marshal failure. je.Errors[i] = jsonError{Error: err.Error()} } else { var raw json.RawMessage = marshaled je.Errors[i] = jsonError{Error: raw} } } else { // Use error string for non-marshaler errors je.Errors[i] = jsonError{Error: err.Error()} } } // Encode JSON if err := enc.Encode(je); err != nil { return nil, fmt.Errorf("failed to marshal MultiError: %v", err) } // Copy out of buf's backing array before returning buf to pool. raw := buf.Bytes() if len(raw) > 0 && raw[len(raw)-1] == '\n' { raw = raw[:len(raw)-1] } result := make([]byte, len(raw)) copy(result, raw) jsonBufferPool.Put(buf) return result, nil } // defaultFormat provides the default formatting for multiple errors. // Returns a semicolon-separated list prefixed with the error count (e.g., "errors(3): err1; err2; err3"). 
func defaultFormat(errs []error) string { var sb strings.Builder sb.WriteString(fmt.Sprintf("errors(%d): ", len(errs))) for i, err := range errs { if i > 0 { sb.WriteString("; ") } sb.WriteString(err.Error()) } return sb.String() } // fastRand generates a quick pseudo-random number for sampling. // Uses a simple xorshift algorithm based on the current time; not cryptographically secure. var fastRandState uint32 = 1 // Must be non-zero func fastRand() uint32 { for { // Atomically load the current state old := atomic.LoadUint32(&fastRandState) // Xorshift computation x := old x ^= x << 13 x ^= x >> 17 x ^= x << 5 // Attempt to store the new state atomically if atomic.CompareAndSwapUint32(&fastRandState, old, x) { return x } // Otherwise retry } } golang-github-olekukonko-errors-1.3.0/multi_error_test.go000066400000000000000000000222201517267734700237050ustar00rootroot00000000000000package errors import ( "encoding/json" "errors" "fmt" "math/rand" "reflect" "testing" ) // TestMultiError_Basic verifies basic MultiError functionality. // Ensures empty creation, nil error handling, and single error addition work as expected. func TestMultiError_Basic(t *testing.T) { m := NewMultiError() if m.Has() { t.Error("New MultiError should be empty") } m.Add(nil) // Single nil error if m.Has() { t.Error("Adding nil should not create error") } err1 := errors.New("error 1") m.Add(err1) // Single error if !m.Has() { t.Error("Should detect errors after adding one") } if m.Count() != 1 { t.Errorf("Count should be 1, got %d", m.Count()) } if m.First() != err1 || m.Last() != err1 { t.Errorf("First() and Last() should both be %v, got First=%v, Last=%v", err1, m.First(), m.Last()) } // Test variadic Add with nil and duplicate m.Add(nil, err1, errors.New("error 1")) // Nil, duplicate, and same message if m.Count() != 1 { t.Errorf("Count should remain 1 after adding nil and duplicate, got %d", m.Count()) } } // TestMultiError_Sampling tests the sampling behavior of MultiError. 
// Adds many unique errors with a 50% sampling rate and checks the resulting ratio is within 45-55%. func TestMultiError_Sampling(t *testing.T) { r := rand.New(rand.NewSource(42)) // Fixed seed for reproducible results m := NewMultiError(WithSampling(50), WithRand(r)) total := 1000 // Add errors in batches to test variadic Add batchSize := 100 for i := 0; i < total; i += batchSize { batch := make([]error, batchSize) for j := 0; j < batchSize; j++ { batch[j] = errors.New(fmt.Sprintf("test%d", i+j)) // Unique errors } m.Add(batch...) } count := m.Count() ratio := float64(count) / float64(total) // Expect roughly 50% (±5%) due to sampling; adjust range if sampling logic changes if ratio < 0.45 || ratio > 0.55 { t.Errorf("Sampling ratio %v not within expected range (45-55%%), count=%d, total=%d", ratio, count, total) } } // TestMultiError_Limit tests the error limit enforcement of MultiError. // Adds twice the limit of unique errors and verifies the count caps at the limit. func TestMultiError_Limit(t *testing.T) { limit := 10 m := NewMultiError(WithLimit(limit)) // Add errors in a single variadic call errors := make([]error, limit*2) for i := 0; i < limit*2; i++ { errors[i] = New(fmt.Sprintf("test%d", i)) // Unique errors } m.Add(errors...) if m.Count() != limit { t.Errorf("Should cap at %d errors, got %d", limit, m.Count()) } } // TestMultiError_Formatting verifies custom formatting in MultiError. // Adds two errors and checks the custom formatter outputs the expected string. func TestMultiError_Formatting(t *testing.T) { customFormat := func(errs []error) string { return fmt.Sprintf("custom: %d", len(errs)) } m := NewMultiError(WithFormatter(customFormat)) m.Add(errors.New("test1"), errors.New("test2")) // Add two errors at once expected := "custom: 2" if m.Error() != expected { t.Errorf("Expected %q, got %q", expected, m.Error()) } } // TestMultiError_Filter tests the filtering functionality of MultiError. 
// Adds three errors, filters out one, and verifies the resulting count is correct. func TestMultiError_Filter(t *testing.T) { m := NewMultiError() m.Add(errors.New("error1"), errors.New("skip"), errors.New("error2")) // Variadic add filtered := m.Filter(func(err error) bool { return err.Error() != "skip" }) if filtered.Count() != 2 { t.Errorf("Should filter out one error, leaving 2, got %d", filtered.Count()) } } // TestMultiError_AsSingle tests the Single() method across different scenarios. // Verifies behavior for empty, single-error, and multi-error cases. func TestMultiError_AsSingle(t *testing.T) { // Subtest: Empty MultiError should return nil t.Run("Empty", func(t *testing.T) { m := NewMultiError() if m.Single() != nil { t.Errorf("Empty should return nil, got %v", m.Single()) } }) // Subtest: Single error should return that error t.Run("Single", func(t *testing.T) { m := NewMultiError() err := errors.New("test") m.Add(err) if m.Single() != err { t.Errorf("Should return single error %v, got %v", err, m.Single()) } }) // Subtest: Multiple errors should return the MultiError itself t.Run("Multiple", func(t *testing.T) { m := NewMultiError() m.Add(errors.New("test1"), errors.New("test2")) // Variadic add if m.Single() != m { t.Errorf("Should return self for multiple errors, got %v", m.Single()) } }) } // TestMultiError_MarshalJSON tests the JSON serialization of MultiError. // Verifies correct output for empty, single-error, multiple-error, and mixed-error cases. 
func TestMultiError_MarshalJSON(t *testing.T) { // Subtest: Empty t.Run("Empty", func(t *testing.T) { m := NewMultiError() data, err := json.Marshal(m) if err != nil { t.Fatalf("MarshalJSON failed: %v", err) } expected := `{"count":0,"errors":[]}` if string(data) != expected { t.Errorf("Expected %q, got %q", expected, string(data)) } }) // Subtest: Single standard error t.Run("SingleStandardError", func(t *testing.T) { m := NewMultiError() err := errors.New("timeout") m.Add(err) data, err := json.Marshal(m) if err != nil { t.Fatalf("MarshalJSON failed: %v", err) } expected := `{"count":1,"errors":[{"error":"timeout"}]}` var expectedJSON, actualJSON interface{} if err := json.Unmarshal([]byte(expected), &expectedJSON); err != nil { t.Fatalf("Failed to parse expected JSON: %v", err) } if err := json.Unmarshal(data, &actualJSON); err != nil { t.Fatalf("Failed to parse actual JSON: %v", err) } if !reflect.DeepEqual(expectedJSON, actualJSON) { t.Errorf("JSON output mismatch.\nGot: %s\nWant: %s", string(data), expected) } }) // Subtest: Multiple errors including *Error t.Run("MultipleMixedErrors", func(t *testing.T) { m := NewMultiError(WithLimit(5)) // No sampling to ensure all errors are added m.Add( New("db error").WithCode(500).With("user_id", 123), // *Error errors.New("timeout"), // Standard error nil, // Nil error (skipped by Add) ) data, err := json.Marshal(m) if err != nil { t.Fatalf("MarshalJSON failed: %v", err) } expected := `{ "count":2, "limit":5, "errors":[ {"error":{"message":"db error","context":{"user_id":123},"code":500}}, {"error":"timeout"} ] }` var expectedJSON, actualJSON interface{} if err := json.Unmarshal([]byte(expected), &expectedJSON); err != nil { t.Fatalf("Failed to parse expected JSON: %v", err) } if err := json.Unmarshal(data, &actualJSON); err != nil { t.Fatalf("Failed to parse actual JSON: %v", err) } if !reflect.DeepEqual(expectedJSON, actualJSON) { t.Errorf("JSON output mismatch.\nGot: %s\nWant: %s", string(data), expected) } }) // 
Subtest: Concurrent access to ensure thread safety t.Run("Concurrent", func(t *testing.T) { m := NewMultiError() err1 := New("error1").WithCode(400) err2 := errors.New("error2") m.Add(err1, err2) // Variadic add // Run multiple goroutines to marshal concurrently const numGoroutines = 10 results := make(chan []byte, numGoroutines) errorsChan := make(chan error, numGoroutines) for i := 0; i < numGoroutines; i++ { go func() { data, err := json.Marshal(m) if err != nil { errorsChan <- err return } results <- data }() } // Collect results expected := `{ "count":2, "errors":[ {"error":{"message":"error1","code":400}}, {"error":"error2"} ] }` var expectedJSON interface{} if err := json.Unmarshal([]byte(expected), &expectedJSON); err != nil { t.Fatalf("Failed to parse expected JSON: %v", err) } for i := 0; i < numGoroutines; i++ { select { case err := <-errorsChan: t.Errorf("Concurrent MarshalJSON failed: %v", err) case data := <-results: var actualJSON interface{} if err := json.Unmarshal(data, &actualJSON); err != nil { t.Errorf("Failed to parse actual JSON: %v", err) } if !reflect.DeepEqual(expectedJSON, actualJSON) { t.Errorf("Concurrent JSON output mismatch.\nGot: %s\nWant: %s", string(data), expected) } } } }) // Subtest: Variadic add with multiple errors t.Run("VariadicAdd", func(t *testing.T) { m := NewMultiError(WithLimit(10)) err1 := New("error1").WithCode(400) err2 := errors.New("error2") err3 := errors.New("error3") m.Add(err1, err2, err3, nil, err2) // Mix of unique, nil, and duplicate errors if m.Count() != 3 { t.Errorf("Expected 3 errors, got %d", m.Count()) } data, err := json.Marshal(m) if err != nil { t.Fatalf("MarshalJSON failed: %v", err) } expected := `{ "count":3, "limit":10, "errors":[ {"error":{"message":"error1","code":400}}, {"error":"error2"}, {"error":"error3"} ] }` var expectedJSON, actualJSON interface{} if err := json.Unmarshal([]byte(expected), &expectedJSON); err != nil { t.Fatalf("Failed to parse expected JSON: %v", err) } if err := 
json.Unmarshal(data, &actualJSON); err != nil { t.Fatalf("Failed to parse actual JSON: %v", err) } if !reflect.DeepEqual(expectedJSON, actualJSON) { t.Errorf("JSON output mismatch.\nGot: %s\nWant: %s", string(data), expected) } }) } golang-github-olekukonko-errors-1.3.0/pool.go000066400000000000000000000044361517267734700212650ustar00rootroot00000000000000package errors import ( "sync" "sync/atomic" ) // ErrorPool is a high-performance, thread-safe pool for reusing *Error instances. // Reduces allocation overhead by recycling errors; tracks hit/miss statistics. type ErrorPool struct { pool sync.Pool // Underlying pool for storing *Error instances poolStats struct { // Embedded struct for pool usage statistics hits atomic.Int64 // Number of times an error was reused from the pool misses atomic.Int64 // Number of times a new error was created due to pool miss } } // NewErrorPool creates a new ErrorPool instance. // Initializes the pool with a New function that returns a fresh *Error with default smallContext. func NewErrorPool() *ErrorPool { return &ErrorPool{ pool: sync.Pool{ New: func() interface{} { return &Error{ smallContext: [contextSize]contextItem{}, } }, }, } } // Get retrieves an *Error from the pool or creates a new one if pooling is disabled or pool is empty. // Resets are handled by Put; thread-safe; updates hit/miss stats when pooling is enabled. func (ep *ErrorPool) Get() *Error { if currentConfig.disablePooling { return &Error{ smallContext: [contextSize]contextItem{}, } } e := ep.pool.Get().(*Error) if e == nil { // Pool returned nil (unlikely due to New func, but handled for safety) ep.poolStats.misses.Add(1) e = &Error{ smallContext: [contextSize]contextItem{}, } ep.setupCleanup(e) return e } ep.poolStats.hits.Add(1) // Register auto-cleanup so GC can return the error to the pool if the // caller forgets to call Free(). If AutoFree is false this is a no-op. ep.setupCleanup(e) return e } // Put returns an *Error to the pool after resetting it. 
// Ignores nil errors or if pooling is disabled; preserves stack capacity; thread-safe. func (ep *ErrorPool) Put(e *Error) { if e == nil || currentConfig.disablePooling { return } // Reset the error to a clean state, preserving capacity e.Reset() // Reset stack length while keeping capacity for reuse if e.stack != nil { e.stack = e.stack[:0] } ep.pool.Put(e) } // Stats returns the current pool statistics as hits and misses. // Thread-safe; uses atomic loads to ensure accurate counts. func (ep *ErrorPool) Stats() (hits, misses int64) { return ep.poolStats.hits.Load(), ep.poolStats.misses.Load() } golang-github-olekukonko-errors-1.3.0/pool_above_1_24.go000066400000000000000000000020661517267734700231630ustar00rootroot00000000000000//go:build go1.24 // +build go1.24 package errors import "runtime" // setupCleanup registers a runtime.AddCleanup callback that returns e to the // pool when the GC determines e is unreachable — only when AutoFree is enabled. // // IMPORTANT: the cleanup argument must be e itself (passed as the third arg), // NOT captured in the closure. Capturing e in the closure creates a strong // reference from the cleanup to e, which prevents e from ever becoming // unreachable and defeats the purpose of the cleanup entirely. func (ep *ErrorPool) setupCleanup(e *Error) { if !currentConfig.autoFree || currentConfig.disablePooling { return } runtime.AddCleanup(e, func(target *Error) { if !currentConfig.disablePooling { ep.Put(target) } }, e) } // clearCleanup is a no-op for Go 1.24+. // runtime.AddCleanup does not support cancellation; the double-put risk is // mitigated by Free() resetting the error before Put, making a second Put // of an already-reset error safe (it just returns a clean object to the pool). 
func (ep *ErrorPool) clearCleanup(_ *Error) {} golang-github-olekukonko-errors-1.3.0/pool_below_1_24.go000066400000000000000000000016731517267734700232020ustar00rootroot00000000000000//go:build !go1.24 // +build !go1.24 package errors import "runtime" // setupCleanup registers a finalizer that returns e to the pool when the GC // collects it — only when AutoFree is enabled. // // Finalizer limitation: the GC may not collect the object promptly, and // finalizers run in a separate goroutine. This is acceptable for pool returns // since Put() is safe to call from any goroutine and Reset() is idempotent. func (ep *ErrorPool) setupCleanup(e *Error) { if !currentConfig.autoFree || currentConfig.disablePooling { return } runtime.SetFinalizer(e, func(target *Error) { if !currentConfig.disablePooling { ep.Put(target) } }) } // clearCleanup removes the finalizer so explicit Free() calls do not race // with a pending GC-triggered pool return (double-put). // This is the correct approach for pre-1.24 Go where finalizers can be cleared. func (ep *ErrorPool) clearCleanup(e *Error) { runtime.SetFinalizer(e, nil) } golang-github-olekukonko-errors-1.3.0/retry.go000066400000000000000000000253211517267734700214550ustar00rootroot00000000000000// Package errors provides utilities for error handling, including a flexible retry mechanism. package errors import ( "context" "math/rand" "time" ) // BackoffStrategy defines the interface for calculating retry delays. type BackoffStrategy interface { // Backoff returns the delay for a given attempt based on the base delay. Backoff(attempt int, baseDelay time.Duration) time.Duration } // ConstantBackoff provides a fixed delay for each retry attempt. type ConstantBackoff struct{} // Backoff returns the base delay regardless of the attempt number. // Implements BackoffStrategy with a constant delay. 
func (c ConstantBackoff) Backoff(_ int, baseDelay time.Duration) time.Duration { return baseDelay } // ExponentialBackoff provides an exponentially increasing delay for retry attempts. type ExponentialBackoff struct{} // Backoff returns a delay that doubles with each attempt, starting from the base delay. // Uses bit shifting for efficient exponential growth (e.g., baseDelay * 2^(attempt-1)). func (e ExponentialBackoff) Backoff(attempt int, baseDelay time.Duration) time.Duration { if attempt <= 1 { return baseDelay } return baseDelay * time.Duration(1< 0 && delay > r.maxDelay { delay = r.maxDelay } if r.jitter { delay = addJitter(delay) } // Wait with context select { case <-r.ctx.Done(): return r.ctx.Err() case <-time.After(delay): } } return lastErr } // ExecuteContext runs the provided function with retry logic, respecting context cancellation. // Returns nil on success or the last error if all attempts fail or context is cancelled. func (r *Retry) ExecuteContext(ctx context.Context, fn func() error) error { var lastErr error // If the retry instance already has a context, use it. Otherwise, use the provided one. // If both are provided, maybe create a derived context? For now, prioritize the one from WithContext. execCtx := r.ctx if execCtx == context.Background() && ctx != nil { // Use provided ctx if retry ctx is default and provided one isn't nil execCtx = ctx } else if ctx == nil { // Ensure we always have a non-nil context execCtx = context.Background() } // Note: This logic might need refinement depending on how contexts should interact. // A safer approach might be: if r.ctx != background, use it. Else use provided ctx. 
for attempt := 1; attempt <= r.maxAttempts; attempt++ { // Check context before executing the function select { case <-execCtx.Done(): return execCtx.Err() // Return context error immediately default: // Context is okay, proceed } err := fn() if err == nil { return nil // Success } // Check if retry is applicable based on the error if r.retryIf != nil && !r.retryIf(err) { return err // Not retryable, return the error } lastErr = err // Store the last encountered error // Execute the OnRetry callback if configured if r.onRetry != nil { r.onRetry(attempt, err) } // Exit loop if this was the last attempt if attempt == r.maxAttempts { break } // Calculate and apply delay currentDelay := r.backoff.Backoff(attempt, r.delay) if r.maxDelay > 0 && currentDelay > r.maxDelay { // Check maxDelay > 0 before capping currentDelay = r.maxDelay } if r.jitter { currentDelay = addJitter(currentDelay) } if currentDelay < 0 { // Ensure delay isn't negative after jitter currentDelay = 0 } // Wait for the delay or context cancellation select { case <-execCtx.Done(): // If context is cancelled during the wait, return the context error // Often more informative than returning the last application error. return execCtx.Err() case <-time.After(currentDelay): // Wait finished, continue to the next attempt } } // All attempts failed, return the last error encountered return lastErr } // Transform creates a new Retry instance with modified configuration. // Copies all settings from the original Retry and applies the given options. func (r *Retry) Transform(opts ...RetryOption) *Retry { newRetry := &Retry{ maxAttempts: r.maxAttempts, delay: r.delay, maxDelay: r.maxDelay, retryIf: r.retryIf, onRetry: r.onRetry, backoff: r.backoff, jitter: r.jitter, ctx: r.ctx, } for _, opt := range opts { opt(newRetry) } return newRetry } // WithBackoff sets the backoff strategy using the BackoffStrategy interface. // Returns a RetryOption; no-op if strategy is nil, retaining the existing strategy. 
func WithBackoff(strategy BackoffStrategy) RetryOption {
	return func(r *Retry) {
		// A nil strategy keeps whatever backoff is already configured.
		if strategy == nil {
			return
		}
		r.backoff = strategy
	}
}

// WithContext supplies the context whose cancellation or deadline stops the
// retry loop. A nil ctx leaves the previously configured context in place.
func WithContext(ctx context.Context) RetryOption {
	return func(r *Retry) {
		if ctx == nil {
			return
		}
		r.ctx = ctx
	}
}

// WithDelay sets the base delay handed to the backoff strategy between
// retries. Negative values are clamped to 0.
func WithDelay(delay time.Duration) RetryOption {
	clamped := delay
	if clamped < 0 {
		clamped = 0
	}
	return func(r *Retry) { r.delay = clamped }
}

// WithJitter turns random variation of the computed delay on or off.
func WithJitter(jitter bool) RetryOption {
	return func(r *Retry) { r.jitter = jitter }
}

// WithMaxAttempts sets the maximum number of attempts. Values below 1 are
// raised to 1, so at least one attempt is always configured.
func WithMaxAttempts(maxAttempts int) RetryOption {
	limit := maxAttempts
	if limit < 1 {
		limit = 1
	}
	return func(r *Retry) { r.maxAttempts = limit }
}

// WithMaxDelay sets the maximum delay between retries. Negative values are
// clamped to 0.
func WithMaxDelay(maxDelay time.Duration) RetryOption {
	ceiling := maxDelay
	if ceiling < 0 {
		ceiling = 0
	}
	return func(r *Retry) { r.maxDelay = ceiling }
}

// WithOnRetry registers a callback that receives the attempt number and error
// of each failed attempt that is eligible for retry. A nil callback clears
// any previously registered one.
func WithOnRetry(onRetry func(attempt int, err error)) RetryOption {
	return func(r *Retry) { r.onRetry = onRetry }
}

// WithRetryIf sets the predicate that decides whether an error is worth
// another attempt. A nil predicate keeps the currently configured one.
func WithRetryIf(retryIf func(error) bool) RetryOption {
	return func(r *Retry) {
		if retryIf == nil {
			return
		}
		r.retryIf = retryIf
	}
}

// ExecuteReply runs the provided function with retry logic and returns its result.
// Returns the result and nil on success, or zero value and last error on failure; generic type T. func ExecuteReply[T any](r *Retry, fn func() (T, error)) (T, error) { var lastErr error var zero T for attempt := 1; attempt <= r.maxAttempts; attempt++ { result, err := fn() if err == nil { return result, nil } // Check if retry is applicable; return immediately if not retryable if r.retryIf != nil && !r.retryIf(err) { return zero, err } lastErr = err if r.onRetry != nil { r.onRetry(attempt, err) } if attempt == r.maxAttempts { break } // Calculate delay with backoff, cap at maxDelay, and apply jitter if enabled currentDelay := r.backoff.Backoff(attempt, r.delay) if currentDelay > r.maxDelay { currentDelay = r.maxDelay } if r.jitter { currentDelay = addJitter(currentDelay) } // Wait with respect to context cancellation or timeout select { case <-r.ctx.Done(): return zero, r.ctx.Err() case <-time.After(currentDelay): } } return zero, lastErr } golang-github-olekukonko-errors-1.3.0/retry_test.go000066400000000000000000000063351517267734700225200ustar00rootroot00000000000000package errors import ( "context" "math/rand" "testing" "time" ) func init() { rand.Seed(time.Now().UnixNano()) // Ensure jitter randomness } // TestExecuteReply_Success tests successful execution after retries with a string result. 
func TestExecuteReply_Success(t *testing.T) {
	// Jitter is disabled so the delay before the retry is not randomised and
	// the elapsed-time check below is meaningful.
	retry := NewRetry(
		WithMaxAttempts(3),
		WithDelay(50*time.Millisecond),
		WithBackoff(LinearBackoff{}),
		WithJitter(false),
	)

	calls := 0
	start := time.Now()
	// Fail once with a retryable error, then succeed on the second call.
	result, err := ExecuteReply[string](retry, func() (string, error) {
		calls++
		if calls < 2 {
			return "", New("temporary error").WithRetryable()
		}
		return "success", nil
	})
	duration := time.Since(start)

	if err != nil {
		t.Errorf("Expected no error, got %v", err)
	}
	if result != "success" {
		t.Errorf("Expected 'success', got %q", result)
	}
	if calls != 2 {
		t.Errorf("Expected 2 calls, got %d", calls)
	}
	if duration < 45*time.Millisecond { // Slightly less than 50ms for execution overhead
		t.Errorf("Expected at least 50ms delay, got %v", duration)
	}
}

// TestExecuteReply_Failure checks that a function which always fails with a
// retryable error is called exactly maxAttempts (2) times, and that
// ExecuteReply then returns the zero value together with a non-nil error.
func TestExecuteReply_Failure(t *testing.T) {
	retry := NewRetry(
		WithMaxAttempts(2),
		WithDelay(10*time.Millisecond),
	)

	calls := 0
	result, err := ExecuteReply[int](retry, func() (int, error) {
		calls++
		return 0, New("persistent error").WithRetryable()
	})

	if err == nil {
		t.Error("Expected error, got nil")
	}
	if result != 0 {
		t.Errorf("Expected zero value (0), got %d", result)
	}
	if calls != 2 {
		t.Errorf("Expected 2 calls, got %d", calls)
	}
}

// TestExecuteReply_NonRetryable checks that an error not marked retryable
// ends the loop after the first call.
func TestExecuteReply_NonRetryable(t *testing.T) {
	retry := NewRetry(WithMaxAttempts(3))

	calls := 0
	result, err := ExecuteReply[float64](retry, func() (float64, error) {
		calls++
		return 0.0, New("fatal error") // Not retryable
	})

	if err == nil {
		t.Error("Expected error, got nil")
	}
	if result != 0.0 {
		t.Errorf("Expected zero value (0.0), got %f", result)
	}
	if calls != 1 {
		t.Errorf("Expected 1 call, got %d", calls)
	}
}

// TestExecuteReply_ContextCancellation cancels the Retry's context from a
// separate goroutine while ExecuteReply is still retrying. It expects
// context.Canceled, the zero value, and at least two calls before the
// cancellation. The outcome depends on the sleeps below, so their order and
// durations matter.
func TestExecuteReply_ContextCancellation(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	retry := NewRetry(
		WithMaxAttempts(5),
		WithContext(ctx),
		WithDelay(50*time.Millisecond),
	)

	calls := 0
	go func() {
		time.Sleep(125 * time.Millisecond) // Allow 2 calls (100ms total) before cancel
		cancel()
	}()

	result, err := ExecuteReply[string](retry, func() (string, error) {
		calls++
		time.Sleep(25 * time.Millisecond) // Simulate work
		return "", New("retryable error").WithRetryable()
	})

	if !Is(err, context.Canceled) {
		t.Errorf("Expected context canceled error, got %v", err)
	}
	if result != "" {
		t.Errorf("Expected zero value (\"\"), got %q", result)
	}
	if calls < 2 {
		t.Errorf("Expected at least 2 calls before cancellation, got %d", calls)
	}
}

// TestExecuteReply_DifferentTypes checks that ExecuteReply works with a
// user-defined struct result type, not only with built-in types.
func TestExecuteReply_DifferentTypes(t *testing.T) {
	type Result struct {
		Value int
	}

	retry := NewRetry(WithMaxAttempts(3))

	calls := 0
	result, err := ExecuteReply[Result](retry, func() (Result, error) {
		calls++
		if calls < 2 {
			return Result{}, New("temporary error").WithRetryable()
		}
		return Result{Value: 42}, nil
	})

	if err != nil {
		t.Errorf("Expected no error, got %v", err)
	}
	if result.Value != 42 {
		t.Errorf("Expected Value 42, got %d", result.Value)
	}
	if calls != 2 {
		t.Errorf("Expected 2 calls, got %d", calls)
	}
}
golang-github-olekukonko-errors-1.3.0/sentinel.go000066400000000000000000000103251517267734700221270ustar00rootroot00000000000000// Comparable, immutable sentinel errors for package-level error variables.
//
// Relationship to errmgr.Define
// The errmgr subpackage provides a PARAMETERISED error factory:
//
//	var ErrDefined = errmgr.Define("ErrTimeout", "operation timed out after %s: %s")
//	err := ErrDefined.New("5s", "dial failed") // produces a formatted *Error each call
//
// That is for creating new error instances from a template at call sites.
//
// errors.Const (this file) creates a STATIC SENTINEL — a single stable pointer
// stored once as a package-level variable and compared with errors.Is:
//
//	var ErrNotFound = errors.Const("not_found", "resource not found")
//
//	if errors.Is(err, ErrNotFound) { ... } // pointer equality, always correct
//
// Use errmgr.Define when you need to produce many errors from a format template.
// Use errors.Const when you need a fixed comparable value for Is/switch matching.
package errors import ( "encoding/json" "fmt" "log/slog" ) // Sentinel is a comparable, immutable error value safe to store as a // package-level variable and match with errors.Is or a type switch. // // Unlike Named(), which returns a new *Error instance on every call (making // pointer equality unreliable), each call to Const() returns a unique stable // pointer. Two sentinels with identical name/msg are still distinct values // unless they are the same pointer — intentional, to avoid accidental aliasing. type Sentinel struct { name string msg string } // Error implements the error interface. func (s *Sentinel) Error() string { return s.msg } // Is reports whether target is the same sentinel (pointer equality). // This satisfies the errors.Is contract. func (s *Sentinel) Is(target error) bool { t, ok := target.(*Sentinel) return ok && s == t } // As attempts to assign the sentinel to target if target is **Sentinel. // Returns true if the assignment was made. func (s *Sentinel) As(target any) bool { if tp, ok := target.(**Sentinel); ok { *tp = s return true } return false } // Unwrap returns nil — sentinels are root errors with no cause chain. // Satisfies the errors.Unwrap contract. func (s *Sentinel) Unwrap() error { return nil } // Name returns the sentinel's name, useful for logging and diagnostics. func (s *Sentinel) Name() string { return s.name } // String returns a debug-friendly representation. func (s *Sentinel) String() string { return fmt.Sprintf("Sentinel(%s: %s)", s.name, s.msg) } // LogValue implements slog.LogValuer so a Sentinel can be passed directly // to any slog logging call and rendered as a structured group. // // Example: // // slog.Error("lookup failed", "err", ErrNotFound) // // => err.error="resource not found", err.code="not_found" func (s *Sentinel) LogValue() slog.Value { return slog.GroupValue( slog.String("error", s.msg), slog.String("code", s.name), ) } // MarshalJSON serialises the sentinel to {"error":"...","code":"..."}. 
func (s *Sentinel) MarshalJSON() ([]byte, error) {
	// The anonymous struct fixes the JSON key names and their order.
	return json.Marshal(struct {
		Error string `json:"error"`
		Code  string `json:"code"`
	}{
		Error: s.msg,
		Code:  s.name,
	})
}

// With returns a new *Error that wraps this sentinel as its cause and carries
// the additional message msg. Use this to add call-site context to a sentinel
// without losing the ability to match the original with errors.Is.
//
// Example:
//
//	var ErrNotFound = errors.Const("not_found", "resource not found")
//
//	// At call site:
//	err := ErrNotFound.With("user 42 not found")
//	errors.Is(err, ErrNotFound) // true — sentinel is in the cause chain
func (s *Sentinel) With(msg string) *Error {
	e := New(msg)
	// Storing the sentinel as the cause is what lets Is/As reach it later.
	e.cause = s
	return e
}

// Const creates a new sentinel error with the given name and message.
// Store the result as a package-level var; never call Const in a hot path.
// Each call allocates a distinct *Sentinel.
//
// Example:
//
//	var (
//		ErrNotFound   = errors.Const("not_found", "resource not found")
//		ErrForbidden  = errors.Const("forbidden", "access denied")
//		ErrBadRequest = errors.Const("bad_request", "invalid input")
//	)
//
//	func handle(err error) {
//		switch {
//		case errors.Is(err, ErrNotFound): // 404
//		case errors.Is(err, ErrForbidden): // 403
//		}
//	}
func Const(name, msg string) *Sentinel {
	return &Sentinel{name: name, msg: msg}
}
golang-github-olekukonko-errors-1.3.0/sentinel_test.go000066400000000000000000000113511517267734700231660ustar00rootroot00000000000000package errors

import (
	"encoding/json"
	"log/slog"
	"strings"
	"testing"
)

// TestConstCreatesUniquePointers checks that identical name/message pairs
// still yield distinct sentinel pointers.
func TestConstCreatesUniquePointers(t *testing.T) {
	a := Const("not_found", "resource not found")
	b := Const("not_found", "resource not found")
	if a == b {
		t.Error("Const() should return distinct pointers on each call")
	}
}

// TestConstIsComparable checks that Is matches a sentinel against itself,
// both directly and through a *Error wrapping it.
func TestConstIsComparable(t *testing.T) {
	ErrNotFound := Const("not_found", "resource not found")

	if !Is(ErrNotFound, ErrNotFound) {
		t.Error("Is(sentinel, sentinel) should be true")
	}

	wrapped := New("request failed").Wrap(ErrNotFound)
	if !Is(wrapped, ErrNotFound) {
		t.Error("Is should find sentinel through a wrapped *Error chain")
	}
}

// TestConstDoesNotMatchDifferentSentinel checks that two distinct sentinels
// never match each other.
func TestConstDoesNotMatchDifferentSentinel(t *testing.T) {
	ErrA := Const("a", "error a")
	ErrB := Const("b", "error b")
	if Is(ErrA, ErrB) {
		t.Error("two distinct sentinels should not match each other")
	}
}

// TestConstError checks that Error returns the message passed to Const.
func TestConstError(t *testing.T) {
	s := Const("validation_failed", "input is invalid")
	if s.Error() != "input is invalid" {
		t.Errorf("Error() = %q, want %q", s.Error(), "input is invalid")
	}
}

// TestConstName checks that Name returns the name passed to Const.
func TestConstName(t *testing.T) {
	s := Const("my_error", "something happened")
	if s.Name() != "my_error" {
		t.Errorf("Name() = %q, want %q", s.Name(), "my_error")
	}
}

// TestConstDoesNotMatchPlainError checks that a sentinel does not match an
// unrelated *Error.
func TestConstDoesNotMatchPlainError(t *testing.T) {
	s := Const("sentinel", "sentinel error")
	other := New("different message")
	if Is(s, other) {
		t.Error("sentinel should not match a *Error with a different message")
	}
	// Note: Is(sentinel, target) uses pointer equality so is always false
	// for non-identical sentinels regardless of message content.
}

// TestConstImplementsError is a compile-time check that *Sentinel satisfies error.
func TestConstImplementsError(t *testing.T) {
	var _ error = Const("x", "y")
}

// Unwrap

// TestSentinelUnwrap checks that a sentinel reports no further cause.
func TestSentinelUnwrap(t *testing.T) {
	s := Const("root", "root cause")
	if s.Unwrap() != nil {
		t.Error("Sentinel.Unwrap() should return nil — sentinels are root errors")
	}
}

// As

// TestSentinelAs checks that As extracts the exact sentinel pointer from a
// wrapping *Error.
func TestSentinelAs(t *testing.T) {
	ErrNotFound := Const("not_found", "resource not found")
	wrapped := New("handler failed").Wrap(ErrNotFound)

	var target *Sentinel
	if !As(wrapped, &target) {
		t.Fatal("As() should find the Sentinel in the cause chain")
	}
	if target != ErrNotFound {
		t.Error("As() should set target to the exact sentinel pointer")
	}
}

// TestSentinelAsWrongType checks that As fails for a non-Sentinel target type.
func TestSentinelAsWrongType(t *testing.T) {
	s := Const("x", "x")
	var target *Error
	if As(s, &target) {
		t.Error("As() should return false when target type does not match")
	}
}

// String

// TestSentinelString checks that String mentions both the name and the message.
func TestSentinelString(t *testing.T) {
	s := Const("not_found", "resource not found")
	got := s.String()
	if !strings.Contains(got, "not_found") || !strings.Contains(got, "resource not found") {
		t.Errorf("String() = %q — expected name and message", got)
	}
}

// LogValue

// TestSentinelLogValue checks that LogValue is a group holding "error" and
// "code" attributes.
func TestSentinelLogValue(t *testing.T) {
	s := Const("auth_error", "authentication failed")
	val := s.LogValue()

	if val.Kind() != slog.KindGroup {
		t.Errorf("LogValue() kind = %v, want Group", val.Kind())
	}

	attrs := val.Group()
	keys := make(map[string]string, len(attrs))
	for _, a := range attrs {
		keys[a.Key] = a.Value.String()
	}

	if keys["error"] != "authentication failed" {
		t.Errorf("LogValue error attr = %q, want %q", keys["error"], "authentication failed")
	}
	if keys["code"] != "auth_error" {
		t.Errorf("LogValue code attr = %q, want %q", keys["code"], "auth_error")
	}
}

// MarshalJSON

// TestSentinelMarshalJSON round-trips a sentinel through encoding/json and
// checks the "error" and "code" fields.
func TestSentinelMarshalJSON(t *testing.T) {
	s := Const("not_found", "resource not found")
	b, err := json.Marshal(s)
	if err != nil {
		t.Fatalf("MarshalJSON() error: %v", err)
	}

	var out struct {
		Error string `json:"error"`
		Code  string `json:"code"`
	}
	if err := json.Unmarshal(b, &out); err != nil {
		t.Fatalf("Unmarshal error: %v", err)
	}
	if out.Error != "resource not found" {
		t.Errorf("JSON error = %q, want %q", out.Error, "resource not found")
	}
	if out.Code != "not_found" {
		t.Errorf("JSON code = %q, want %q", out.Code, "not_found")
	}
}

// With

// TestSentinelWith checks that With keeps the call-site message and that the
// sentinel stays reachable through Is.
func TestSentinelWith(t *testing.T) {
	ErrNotFound := Const("not_found", "resource not found")
	err := ErrNotFound.With("user 42 not found")

	// The returned *Error should carry the call-site message.
	// NOTE(review): the exact-string comparison is subsumed by the Contains
	// check, so this condition is effectively !strings.Contains(...).
	if err.Error() != "user 42 not found: resource not found" && !strings.Contains(err.Error(), "user 42 not found") {
		t.Errorf("With() message = %q, want it to contain call-site context", err.Error())
	}

	// The original sentinel must still be findable via Is.
	if !Is(err, ErrNotFound) {
		t.Error("Is(With(...), sentinel) should be true — sentinel is the cause")
	}
}

// TestSentinelWithPreservesChain checks that As recovers the original
// sentinel pointer from an error built with With.
func TestSentinelWithPreservesChain(t *testing.T) {
	ErrForbidden := Const("forbidden", "access denied")
	err := ErrForbidden.With("route /admin requires admin role")

	var s *Sentinel
	if !As(err, &s) {
		t.Fatal("As() should find Sentinel through With() chain")
	}
	if s != ErrForbidden {
		t.Error("As() should return the original sentinel pointer")
	}
}
golang-github-olekukonko-errors-1.3.0/utils.go000066400000000000000000000110621517267734700214450ustar00rootroot00000000000000package errors

import (
	"database/sql"
	"fmt"
	"reflect"
	"runtime"
	"strings"
)

// captureStack captures a stack trace with the configured depth.
// Immune to inlining: captures from frame 1 and trims by shifting within
// the same buffer so the pooled slice always retains its full capacity.
//
// skip is the number of frames above captureStack to drop. The result is nil
// when nothing remains after trimming. Otherwise the returned slice shares
// the buffer taken from stackPool; presumably the caller returns it to the
// pool when done (not visible here — confirm).
func captureStack(skip int) []uintptr {
	// NOTE(review): stackPool stores []uintptr values rather than pointers, so
	// each Put boxes the slice header (staticcheck SA6002) — confirm acceptable.
	buf := stackPool.Get().([]uintptr)
	buf = buf[:cap(buf)]

	// Capture from frame 1 (skipping runtime.Callers itself).
	// captureStack can never be inlined because it calls runtime.Callers,
	// so buf[0] is always captureStack regardless of compiler inlining above.
	// NOTE(review): the "never inlined" guarantee is asserted, not enforced
	// (no //go:noinline) — verify if the trim arithmetic ever looks off.
	n := runtime.Callers(1, buf)
	if n == 0 {
		stackPool.Put(buf)
		return nil
	}

	// Trim leading internal frames in-place using copy, preserving the
	// buffer's full capacity so the pool never fills with shrinking slices.
	// skip+1: +1 for captureStack itself (always buf[0]).
	trimmed := skip + 1
	if trimmed >= n {
		stackPool.Put(buf)
		return nil
	}
	length := n - trimmed

	// Shift the useful frames to the start of buf — same backing array,
	// same capacity, zero allocation.
	copy(buf, buf[trimmed:n])
	return buf[:length]
}

// min returns the smaller of two integers.
// This package-level definition shadows Go 1.21's predeclared min within
// this package.
func min(a, b int) int {
	if a < b {
		return a
	}
	return b
}

// clearMap removes all entries from a map without reallocating it.
// Deleting keys while ranging over the same map is permitted by the Go spec.
func clearMap(m map[string]interface{}) {
	for k := range m {
		delete(m, k)
	}
}

// sqlNull detects if a value represents a SQL NULL type.
// A nil interface value is NULL, as is any database/sql Null* wrapper whose
// Valid field is false; every other value is reported as not NULL.
func sqlNull(v interface{}) bool {
	if v == nil {
		return true
	}
	switch val := v.(type) {
	case sql.NullString:
		return !val.Valid
	case sql.NullTime:
		return !val.Valid
	case sql.NullInt64:
		return !val.Valid
	case sql.NullInt32:
		return !val.Valid
	case sql.NullInt16:
		return !val.Valid
	case sql.NullByte:
		return !val.Valid
	case sql.NullBool:
		return !val.Valid
	case sql.NullFloat64:
		return !val.Valid
	default:
		return false
	}
}

// getFuncName extracts the function name from an interface value.
// Returns "unknown" if the input is nil, is not a function, is a nil function
// value, or cannot be resolved by the runtime. It never panics.
func getFuncName(fn interface{}) string {
	if fn == nil {
		return "unknown"
	}
	v := reflect.ValueOf(fn)
	// Value.Pointer panics for most non-pointer kinds, and a nil func has no
	// code pointer to resolve.
	if v.Kind() != reflect.Func || v.IsNil() {
		return "unknown"
	}
	f := runtime.FuncForPC(v.Pointer())
	if f == nil {
		return "unknown"
	}
	return strings.TrimPrefix(f.Name(), ".")
}

// isInternalFrame reports whether a stack frame belongs to this library's
// internals and should be filtered from user-visible stack traces.
//
// Rules:
//   - runtime.* and reflect.* are always internal.
//   - _test.go files are NEVER internal: test functions must survive
//     filtering so that assertions like "stack contains testing.tRunner"
//     and "stack contains TestErrorTraceStackContent" can pass.
//   - Source files at the root of github.com/olekukonko/errors whose path
//     starts with errors, utils, helper, retry or multi are internal. Both
//     plain checkouts (".../github.com/olekukonko/errors/utils.go") and
//     versioned module-cache / -trimpath paths
//     (".../github.com/olekukonko/errors@v1.3.0/utils.go") are recognised.
func isInternalFrame(frame runtime.Frame) bool {
	if strings.HasPrefix(frame.Function, "runtime.") || strings.HasPrefix(frame.Function, "reflect.") {
		return true
	}
	// Exempt test files before the path check: errors_test.go lives at
	// github.com/olekukonko/errors/errors_test.go, which starts with the
	// "errors" prefix and would otherwise be incorrectly filtered.
	if strings.HasSuffix(frame.File, "_test.go") {
		return false
	}
	return isLibrarySourceFile(frame.File)
}

// isLibrarySourceFile reports whether file lies directly under some
// occurrence of the module root "github.com/olekukonko/errors" (optionally
// followed by an "@version" element), with a relative path that starts with
// one of the internal file prefixes. It does not allocate.
func isLibrarySourceFile(file string) bool {
	const module = "github.com/olekukonko/errors"
	internal := [...]string{"errors", "utils", "helper", "retry", "multi"}

	for rest := file; ; {
		i := strings.Index(rest, module)
		if i < 0 {
			return false
		}
		rest = rest[i+len(module):]

		tail := rest
		if strings.HasPrefix(tail, "@") {
			// Module cache / -trimpath form: skip the "@version" element.
			slash := strings.IndexByte(tail, '/')
			if slash < 0 {
				return false
			}
			tail = tail[slash:]
		}
		// Anything other than "/" here (e.g. "errorsx/") is a different module.
		if !strings.HasPrefix(tail, "/") {
			continue
		}
		tail = tail[1:]
		for _, p := range internal {
			if strings.HasPrefix(tail, p) {
				return true
			}
		}
	}
}

// FormatError returns a formatted string representation of an error.
func FormatError(err error) string { if err == nil { return "" } var sb strings.Builder if e, ok := err.(*Error); ok { sb.WriteString(fmt.Sprintf("Error: %s\n", e.Error())) if e.name != "" { sb.WriteString(fmt.Sprintf("Name: %s\n", e.name)) } if ctx := e.Context(); len(ctx) > 0 { sb.WriteString("Context:\n") for k, v := range ctx { sb.WriteString(fmt.Sprintf("\t%s: %v\n", k, v)) } } if stack := e.Stack(); len(stack) > 0 { sb.WriteString("Stack Trace:\n") for _, frame := range stack { sb.WriteString(fmt.Sprintf("\t%s\n", frame)) } } if e.cause != nil { sb.WriteString(fmt.Sprintf("Caused by: %s\n", FormatError(e.cause))) } } else { sb.WriteString(fmt.Sprintf("Error: %s\n", err.Error())) } return sb.String() } // Caller returns the file, line, and function name of the caller at skip level. // Skip=0 returns the caller of this function, 1 returns its caller, etc. func Caller(skip int) (file string, line int, function string) { configMu.RLock() defer configMu.RUnlock() var pcs [1]uintptr n := runtime.Callers(skip+2, pcs[:]) if n == 0 { return "", 0, "unknown" } frame, _ := runtime.CallersFrames(pcs[:n]).Next() return frame.File, frame.Line, frame.Function }